
Comments (36)

hwu36 avatar hwu36 commented on June 10, 2024

the conv epilogue reuses the gemm epilogue, which assumes the output is a row-major dense matrix. the column count is K, which is 1 in your case, and the row count is NPQ, which is 100 in your case. a dense matrix only has one stride, but your case actually needs 2 strides.
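
To make the mismatch concrete, here is a small sketch in plain C++ (not CUTLASS code) that compares the offset produced by a dense row-major NPQ x K matrix with the offset a padded image row would need. The function names and the pitch of 16 elements used below are hypothetical and only serve as an illustration.

```cpp
#include <cassert>
#include <cstdint>

// Offset assumed by a dense row-major matrix with a single row stride:
// GEMM row = flattened output pixel index, GEMM column = output channel k.
inline std::int64_t denseOffset(std::int64_t gemmRow, std::int64_t gemmCol,
                                std::int64_t rowStride) {
  return gemmRow * rowStride + gemmCol;
}

// Offset needed for a single-channel image whose rows are padded to
// `pitch` elements: the flattened pixel index is split into (p, q) first.
inline std::int64_t pitchedOffset(std::int64_t gemmRow, std::int64_t width,
                                  std::int64_t pitch) {
  assert(width > 0 && pitch >= width);
  const std::int64_t p = gemmRow / width;  // image row
  const std::int64_t q = gemmRow % width;  // image column
  return p * pitch + q;
}
```

With a width of 10, K = 1 and a hypothetical pitch of 16 elements, pixel 9 lands at offset 9 in both schemes, but pixel 10 (the first pixel of the second image row) lands at offset 10 in the dense scheme and at offset 16 in the pitched one.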

the hacky way to meet your need is to convince the epilogue that your output matrix is 10x10, not 1x100. stride info is passed to the epilogue here: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/kernel/implicit_gemm_convolution.h#L238 which is using https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/epilogue/threadblock/output_iterator_parameter.h#L76 you can change this number to the stride you want rather than K which is what is used now. then in the epilogue iterator, try to hack this memory address calculation (https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L397) to be what is right for you.

the non hacky way is to use this affine2 epilogue https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h. affine2 allows two strides in the output. an affine2 gemm example is here: https://github.com/NVIDIA/cutlass/tree/main/examples/18_ampere_fp64_tensorop_affine2_gemm

from cutlass.

chacha21 avatar chacha21 commented on June 10, 2024

OK, I'll give it a try, but can you confirm that
"your case would need actually 2 strides" means it is the same as interpreting my 10x10 matrix as 10 matrices of size 1x10? (It should work since it is a row filter.)
But from what you say, I understand that there is a "default" stride of 1 to advance to the next pixel and that, whatever H or N is, this would require a second stride. Correct?


hwu36 avatar hwu36 commented on June 10, 2024

i would say that you would have to hack whatever is not right.

the code is like

  for (col)
    *memory_ptr(row, col) = 

you essentially need to change the mapping function from (row, col) to your new address calculation, which uses your 2D stride [1, dst_stride]
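
As a rough host-side illustration of such a mapping function (plain C++, not the CUTLASS iterator itself), the store loop could look like the sketch below. `Stride2D`, `mapAddress` and `storeTile` are hypothetical names, and (row, col) are taken here to be image coordinates.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Two strides, in elements: [1, dst_stride] in the notation used above.
struct Stride2D {
  std::size_t col;  // step to the next pixel in a row (1 here)
  std::size_t row;  // step to the next row (dst_stride)
};

// Maps a (row, col) coordinate to a linear element offset.
inline std::size_t mapAddress(std::size_t row, std::size_t col, Stride2D s) {
  return row * s.row + col * s.col;
}

// Stores a densely packed height x width tile into a padded destination.
inline void storeTile(const std::vector<float>& values, std::size_t height,
                      std::size_t width, Stride2D stride,
                      std::vector<float>& out) {
  assert(values.size() >= height * width);
  for (std::size_t row = 0; row < height; ++row) {
    for (std::size_t col = 0; col < width; ++col) {
      const std::size_t offset = mapAddress(row, col, stride);
      assert(offset < out.size());
      out[offset] = values[row * width + col];
    }
  }
}
```

For a 2x3 tile stored with stride {1, 4}, the six values end up at offsets 0, 1, 2, 4, 5 and 6, and the padding elements at offsets 3 and 7 are left untouched.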


chacha21 avatar chacha21 commented on June 10, 2024

It does not work (now, compilation fails).

By migrating to the Affine2 layout, my tensor_src[ker|dst] type changes from cutlass::TensorRef<float, cutlass::layout::TensorNHWC> to

  • cutlass::TensorRef<float, cutlass::layout::AffineRank2RowMajor> (for A=src)
  • cutlass::TensorRef<float, cutlass::layout::AffineRank2RowMajor> (for B=filter)
  • cutlass::TensorRef<float, cutlass::layout::AffineRankN<2> > (for C=D=dst)

Thus, Conv2dFprops::Arguments cannot be built since it requires the same type for A/B/C.

If I try to force dst to be cutlass::TensorRef<float, cutlass::layout::AffineRank2RowMajor> it is not better, because

  • Conv2dFprops::Arguments can be built but
  • implicit_gemm_op.initialize(...) fails to compile with: no instance of constructor "cutlass::MatrixCoord::MatrixCoord" matches the argument list (no member n, h, w, w, c)...

I don't want to use the "hacky way", I can't understand why CUTLASS makes it so hard to express a row stride in a basic 2D convolution.


hwu36 avatar hwu36 commented on June 10, 2024

you need to plumb all the way through. A and B still nhwc. C and D is affine2. you have to map nhwc strides to affine2 strides somewhere.


chacha21 avatar chacha21 commented on June 10, 2024

you need to plumb all the way through. A and B still nhwc. C and D is affine2. you have to map nhwc strides to affine2 strides somewhere.

I don't really understand what you are implying. The CUTLASS design is very hard to understand, and the template errors are a nightmare to analyze.
Here is the sample code migrated to cutlass::layout::AffineRankN<2> for the output, but it does not compile.
I can't understand why Arguments does not accept that configuration: it still wants a cutlass::layout::TensorNHWC for C and D even though I explicitly defined it as a cutlass::layout::AffineRankN<2>.

I don't have a clue to "remap" that.

//src, dst, kernelData are all in device memory
void convolutionCUTLASSRow(const float* src, size_t srcStride, float* dst, size_t dstStride, int width, int height, const float* kernelData, int kernelRadius, cudaStream_t stream)
{
  using ElementA = float;
  using ElementB = float;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  using Epilogue = cutlass::epilogue::thread::LinearCombination<
    ElementC,
    1,
    ElementAccumulator,
    ElementCompute
  >;

  using LayoutInputA = cutlass::layout::TensorNHWC;
  using LayoutInputB = cutlass::layout::TensorNHWC;
  using LayoutOutput = cutlass::layout::AffineRankN<2>;

  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop <
    ElementA, LayoutInputA,
    ElementB, LayoutInputB,
    ElementC, LayoutOutput,
    ElementAccumulator,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 8>,
    cutlass::gemm::GemmShape<64, 64, 8>,
    cutlass::gemm::GemmShape<1, 1, 1>,
    Epilogue,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic
  >::Kernel;

  using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  Conv2dFprop implicit_gemm_op;

  const int kernelDiameter = 2 * kernelRadius + 1;
  cutlass::Tensor4DCoord input_size(1, height, width, 1);
  cutlass::Tensor4DCoord filter_size(1, 1, kernelDiameter, 1);
  cutlass::Tensor4DCoord output_size(1, height, width, 1);

  cutlass::conv::Conv2dProblemSize problem_size(
    input_size,
    filter_size,
    cutlass::Tensor4DCoord(0, 0, kernelRadius, 0),
    cutlass::MatrixCoord(1, 1),
    cutlass::MatrixCoord(1, 1),
    output_size,
    cutlass::conv::Mode::kConvolution,
    1
  );

  const int srcStrideInElements = static_cast<int>(srcStride / sizeof(float));
  cutlass::layout::TensorNHWC src_layout(1, srcStrideInElements, height * srcStrideInElements);
  auto tensor_src = cutlass::make_TensorRef(const_cast<float*>(src), src_layout);

  cutlass::layout::TensorNHWC ker_layout(1, kernelDiameter, kernelDiameter);
  auto tensor_ker = cutlass::make_TensorRef(const_cast<float*>(kernelData), ker_layout);

  const int dstStrideInElements = static_cast<int>(dstStride / sizeof(float));
  //cutlass::layout::TensorNHWC dst_layout(1, dstStrideInElements, height * dstStrideInElements);
  typename LayoutOutput::Stride::Index stride_factor_C[] = { 1, dstStrideInElements };
  auto dst_layout = cutlass::layout::Affine2Layout_Factory<LayoutOutput>::layout_factory(input_size, stride_factor_C);
  auto tensor_dst = cutlass::make_TensorRef(dst, dst_layout);

  using Arguments = typename Conv2dFprop::Arguments;
  Arguments arguments = Arguments(
    problem_size,
    tensor_src,
    tensor_ker,
    tensor_dst, //<-- DOES NOT COMPILE HERE
    tensor_dst,
    { 1.f, 0.f },
    cutlass::conv::SplitKMode::kSerial
  );

  cutlass::Status status;
  status = implicit_gemm_op.can_implement(arguments);

  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  status = implicit_gemm_op.initialize(arguments, workspace.get(), stream);

  status = implicit_gemm_op();
}
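
For reference, the TensorNHWC layouts above are built from three strides. Assuming the constructor takes them in (w, h, n) order and the channel stride is implicitly 1, the element offsets can be modeled in plain C++ as in the sketch below. `NhwcStrides`, `nhwcOffset` and `pitchedSingleChannel` are hypothetical names, not CUTLASS API.

```cpp
#include <cstdint>

// Strides of an NHWC tensor in elements; the channel stride is assumed to be 1.
struct NhwcStrides {
  std::int64_t w;
  std::int64_t h;
  std::int64_t n;
};

// Linear element offset of coordinate (n, h, w, c) under the strides above.
inline std::int64_t nhwcOffset(std::int64_t n, std::int64_t h, std::int64_t w,
                               std::int64_t c, NhwcStrides s) {
  return c + w * s.w + h * s.h + n * s.n;
}

// Strides of a single-channel image whose rows are padded to
// `pitchInElements`, mirroring the (1, stride, height * stride) triple
// passed to the layouts in the code above.
inline NhwcStrides pitchedSingleChannel(std::int64_t pitchInElements,
                                        std::int64_t height) {
  return NhwcStrides{1, pitchInElements, height * pitchInElements};
}
```

Under this model, stepping to the next image row moves by the pitch rather than by the width, which is the second stride the dense output path does not carry.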


hwu36 avatar hwu36 commented on June 10, 2024

we can add this in next release or next next release. @kerrmudgeon


chacha21 avatar chacha21 commented on June 10, 2024

I have a new issue, but I don't know if it's related. Let me know if I have to open a new Issue.
CUTLASS will unexpectedly hang after several calls.

Regarding the problem of "dst stride ignored", I have decided to work around it with an "unpacking" post-process step that restores the stride.
What I mean: let CUTLASS work with a "compact" dst (ignoring the stride) and fill memory sequentially. It won't overwrite anything useful anyway. Then I move the rows in dst to restore the expected stride.

...
status = implicit_gemm_op.run(stream);
...
if (dstStride != width * sizeof(float))//dstStride is here in bytes, not in elements
{
  for (int rowIndex = 1; rowIndex < height; ++rowIndex)
  {
    const int reverseRowIndex = height - rowIndex;//we process rows in reverse order to move data and avoid overlapping
    const float* dstCompactedRow = dst + reverseRowIndex * width;
    float* dstStridedRow = reinterpret_cast<float*>(reinterpret_cast<unsigned char*>(dst) + reverseRowIndex * dstStride);
    if (!stream)
      cudaMemcpy(dstStridedRow, dstCompactedRow, width * sizeof(float), cudaMemcpyDeviceToDevice);
    else
      cudaMemcpyAsync(dstStridedRow, dstCompactedRow, width * sizeof(float), cudaMemcpyDeviceToDevice, stream);
  }//end for each row
}//end if (dstStride != width * sizeof(float))
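
The same unpacking idea can be checked on the host with a small sketch in plain C++ (no CUDA involved). `unpackRowsInPlace` is a hypothetical helper. It uses std::memmove because, for the first rows, the compact source range and the strided destination range of a row can overlap whenever the padding is smaller than the row width.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Expands `height` rows of `width` floats, stored back to back at the start
// of `buffer`, into rows that start every `pitch` elements. Works in place.
inline void unpackRowsInPlace(std::vector<float>& buffer, std::size_t width,
                              std::size_t height, std::size_t pitch) {
  assert(pitch >= width);
  assert(height == 0 || buffer.size() >= (height - 1) * pitch + width);
  // Rows are moved from last to first so that a strided row never lands on
  // a compact row that has not been moved yet. Row 0 is already in place.
  for (std::size_t row = height; row-- > 1;) {
    const float* compactRow = buffer.data() + row * width;
    float* stridedRow = buffer.data() + row * pitch;
    // memmove tolerates overlapping source and destination ranges.
    std::memmove(stridedRow, compactRow, width * sizeof(float));
  }
}
```

With width 3, height 3 and pitch 4, a buffer starting as 1..9 ends with rows {1,2,3}, {4,5,6} and {7,8,9} at offsets 0, 4 and 8. The padding elements keep stale values.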

Once I was happy with how CUTLASS processed the data (I checked that the results were correct), I tried to use it to process several inputs. What I observe is that it hangs.

Below is a code that simulates several calls by adding a loop.
It hangs on my machine (latest CUDA, NVidia driver, on a GeForce RTX 3050)
With an input of 1111x1024 (stride 4608 bytes), kernel radius = 10

I can't see anything wrong. is this a stride-related issue or something else ?

void convolutionCUTLASSRow(const float* src, size_t srcStride, float* dst, size_t dstStride, int width, int height, const float* kernelData, int kernelRadius, cudaStream_t stream)
{
  Conv2dFprop implicit_gemm_op;

  const int kernelDiameter = 2 * kernelRadius + 1;
  cutlass::Tensor4DCoord input_size(1, height, width, 1);
  cutlass::Tensor4DCoord filter_size(1, 1, kernelDiameter, 1);
  cutlass::Tensor4DCoord output_size(1, height, width, 1);

  cutlass::conv::Conv2dProblemSize problem_size(
    input_size,
    filter_size,
    cutlass::Tensor4DCoord(0, 0, kernelRadius, 0),
    cutlass::MatrixCoord(1, 1),
    cutlass::MatrixCoord(1, 1),
    output_size,
    cutlass::conv::Mode::kConvolution,
    1
  );

  const int srcStrideInElements = static_cast<int>(srcStride / sizeof(float));
  cutlass::layout::TensorNHWC src_layout(1, srcStrideInElements, height * srcStrideInElements);
  auto tensor_src = cutlass::make_TensorRef(const_cast<float*>(src), src_layout);

  cutlass::layout::TensorNHWC ker_layout(1, kernelDiameter, kernelDiameter);
  auto tensor_ker = cutlass::make_TensorRef(const_cast<float*>(kernelData), ker_layout);

  const int dstStrideInElements = static_cast<int>(dstStride / sizeof(float));
  cutlass::layout::TensorNHWC dst_layout(1, dstStrideInElements, height * dstStrideInElements);
  auto tensor_dst = cutlass::make_TensorRef(dst, dst_layout);

  using Arguments = typename Conv2dFprop::Arguments;
  Arguments arguments = Arguments(
    problem_size,
    tensor_src,
    tensor_ker,
    tensor_dst,
    tensor_dst,
    { 1.f, 0.f },
    cutlass::conv::SplitKMode::kSerial
  );

  cutlass::Status status;
  status = implicit_gemm_op.can_implement(arguments);

  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  printf("\r\n==============\r\n");
  status = implicit_gemm_op.initialize(arguments, workspace.get(), stream);
  for (int i = 0; i < 1000; ++i)
  {
    printf(".");
    fflush(stdout);
    status = implicit_gemm_op.update(arguments, workspace.get());
    status = implicit_gemm_op.run(stream);

    if (dstStride != width * sizeof(float))
    {
      for (int rowIndex = 1; rowIndex < height; ++rowIndex)
      {
        const int reverseRowIndex = height - rowIndex;
        const float* dstCompactedRow = dst + reverseRowIndex * width;
        float* dstStridedRow = dst + reverseRowIndex * dstStride / sizeof(float);
        cudaMemcpyAsync(dstStridedRow, dstCompactedRow, width * sizeof(float), cudaMemcpyDeviceToDevice, stream);
      }
    }
    //WILL HANG SOMEWHERE AROUND HERE
    cudaStreamSynchronize(stream);
  }//end for each i
  printf("\r\n==============\r\n");
}
//end convolutionCUTLASSRow()


hwu36 avatar hwu36 commented on June 10, 2024

you can try your problem size within cutlass profilers.

you can also check the return value of every cuda call to see what is wrong. it might just run out of device memory.


chacha21 avatar chacha21 commented on June 10, 2024

OK, I got it.
I have been tricked by using a stream. The hang is just delayed : it is just the call to implicit_gemm_op.run(stream); that takes a few minutes to complete.
Did I misunderstand CUTLASS and is it really designed to handle a 21x1 2D convolution on a 1280x1024 image ? Or is this something that won't map efficiently on all those N, P, K, Q, R, S... from GEMM ?


hwu36 avatar hwu36 commented on June 10, 2024

your C is only 1 which is bad for vectorized operation including mma. you can try to use fixed channel (https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/convolution.h#L114). example is in https://github.com/NVIDIA/cutlass/blob/main/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu
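
To see what this problem looks like as a GEMM, here is a small arithmetic sketch in plain C++. It assumes the usual forward-convolution implicit GEMM mapping (M = N*P*Q, N = K, K = C*R*S) and the 128x128x8 threadblock tile from the kernel definition earlier in this thread. `GemmDims`, `fpropGemmDims` and `tileUtilization` are hypothetical names, and the utilization figure is only a rough indicator, not a performance measurement.

```cpp
#include <cstdint>

struct GemmDims {
  std::int64_t m;
  std::int64_t n;
  std::int64_t k;
};

// Implicit GEMM extents of a forward convolution:
// M = batch * outH * outW, N = filters, K = channels * filterH * filterW.
inline GemmDims fpropGemmDims(std::int64_t batch, std::int64_t outH,
                              std::int64_t outW, std::int64_t filters,
                              std::int64_t channels, std::int64_t filterH,
                              std::int64_t filterW) {
  return GemmDims{batch * outH * outW, filters, channels * filterH * filterW};
}

inline std::int64_t ceilDiv(std::int64_t a, std::int64_t b) {
  return (a + b - 1) / b;
}

// Fraction of the tiled extent that holds real data along one dimension.
inline double tileUtilization(std::int64_t extent, std::int64_t tile) {
  return static_cast<double>(extent) /
         static_cast<double>(ceilDiv(extent, tile) * tile);
}
```

For a single-channel image 1111 pixels wide and 1024 pixels high with a 21-tap row filter, this gives M = 1137664, N = 1 and K = 21. Along N, only 1 of the 128 tile columns carries data (about 0.8%), while along K the 21 taps fill 21 of 24 slots.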


chacha21 avatar chacha21 commented on June 10, 2024

Ok. I have also been tricked by using a Debug build. For my 1111x1024 convolved by 21x1 :
Release : 17ms
Debug : 153166 ms (9000 times slower)
Ouch.


chacha21 avatar chacha21 commented on June 10, 2024

your C is only 1 which is bad for vectorized operation including mma. you can try to use fixed channel (https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/convolution.h#L114). example is in https://github.com/NVIDIA/cutlass/blob/main/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu

Does not seem to be available for ElementA=ElementB=ElementC=float ("incomplete type" at compilation)


hwu36 avatar hwu36 commented on June 10, 2024

what is after cutlass::conv::kernel::DefaultConv2dFprop?


chacha21 avatar chacha21 commented on June 10, 2024

what is after cutlass::conv::kernel::DefaultConv2dFprop?

I really need floats, because accuracy tests on my datasets showed me that half/bf16/tf32 were not precise enough for my data ranges.
So my cutlass::conv::kernel::DefaultConv2dFprop is the following one (the same one since the beginning of this issue) inferred from the tables of https://github.com/NVIDIA/cutlass/blob/main/media/docs/functionality.md

  using ElementA = float;
  using ElementB = float;
  using ElementC = float;
  using ElementAccumulator = float;
  using ElementCompute = float;

  using Epilogue = cutlass::epilogue::thread::LinearCombination<
    ElementC,
    1,
    ElementAccumulator,
    ElementCompute
  >;
  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop <
    ElementA,
    cutlass::layout::TensorNHWC,
    ElementB,
    cutlass::layout::TensorNHWC,
    ElementC,
    cutlass::layout::TensorNHWC,
    ElementAccumulator,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 8>,
    cutlass::gemm::GemmShape<64, 64, 8>,
    cutlass::gemm::GemmShape<1, 1, 1>,
    Epilogue,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic
  >::Kernel;

All the code that comes after is still visible in the first post of the current issue.


hwu36 avatar hwu36 commented on June 10, 2024

2 -> 3 (the stage count)
cutlass::conv::IteratorAlgorithm::kAnalytic -> cutlass::conv::IteratorAlgorithm::kFixedChannels

add below to the bottom of the list

    cutlass::conv::StrideSupport::kStrided,
    1,
    1


chacha21 avatar chacha21 commented on June 10, 2024

Sure, that's exactly what I tried, but then the compilation fails.
I think it's related to the fact that I use float,float,float that currently implies cutlass::arch::OpClassSimt and not cutlass::arch::OpClassTensorOp like in the link you pointed here

1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu(21): error : incomplete type is not allowed
1>  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop <
1>                                     ^
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(57): error : name followed by "::" must be a class or namespace name
1>    using ElementA = typename UnderlyingKernel::ElementA;
1>                              ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(58): error : name followed by "::" must be a class or namespace name
1>    using LayoutA = typename UnderlyingKernel::LayoutA;
1>                             ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(59): error : name followed by "::" must be a class or namespace name
1>    using ElementB = typename UnderlyingKernel::ElementB;
1>                              ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(60): error : name followed by "::" must be a class or namespace name
1>    using LayoutB = typename UnderlyingKernel::LayoutB;
1>                             ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(61): error : name followed by "::" must be a class or namespace name
1>    using ElementC = typename UnderlyingKernel::ElementC;
1>                              ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(62): error : name followed by "::" must be a class or namespace name
1>    using LayoutC = typename UnderlyingKernel::LayoutC;
1>                             ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(63): error : name followed by "::" must be a class or namespace name
1>    using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;
1>                                        ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(64): error : name followed by "::" must be a class or namespace name
1>    using ElementCompute = typename UnderlyingKernel::ElementCompute;
1>                                    ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(65): error : name followed by "::" must be a class or namespace name
1>    using OperatorClass = typename UnderlyingKernel::OperatorClass;
1>                                   ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(66): error : name followed by "::" must be a class or namespace name
1>    using ArchTag = typename UnderlyingKernel::ArchTag;
1>                             ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(67): error : name followed by "::" must be a class or namespace name
1>    using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape;
1>                                      ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(68): error : name followed by "::" must be a class or namespace name
1>    using WarpShape = typename UnderlyingKernel::WarpShape;
1>                               ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(69): error : name followed by "::" must be a class or namespace name
1>    using InstructionShape = typename UnderlyingKernel::InstructionShape;
1>                                      ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(70): error : name followed by "::" must be a class or namespace name
1>    using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle;
1>                                        ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(71): error : name followed by "::" must be a class or namespace name
1>    using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp;
1>                                      ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(72): error : name followed by "::" must be a class or namespace name
1>    static int const kStages = UnderlyingKernel::kStages;
1>                               ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(73): error : name followed by "::" must be a class or namespace name
1>    static int const kConvDim = UnderlyingKernel::kConvDim;
1>                                ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(74): error : name followed by "::" must be a class or namespace name
1>    using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator;
1>                                     ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(75): error : name followed by "::" must be a class or namespace name
1>    using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator;
1>                                     ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(76): error : name followed by "::" must be a class or namespace name
1>    using MathOperator = typename UnderlyingKernel::MathOperator;
1>                                  ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(78): error : name followed by "::" must be a class or namespace name
1>    static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator;
1>                                                                  ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(79): error : name followed by "::" must be a class or namespace name
1>    static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm;
1>                                                                       ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(80): error : name followed by "::" must be a class or namespace name
1>    static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport;
1>                                                               ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(81): error : name followed by "::" must be a class or namespace name
1>    static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode;
1>                                                       ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(84): error : name followed by "::" must be a class or namespace name
1>      (ThreadblockShape::kM / WarpShape::kM) *
1>       ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(84): error : name followed by "::" must be a class or namespace name
1>      (ThreadblockShape::kM / WarpShape::kM) *
1>                              ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(85): error : name followed by "::" must be a class or namespace name
1>      (ThreadblockShape::kN / WarpShape::kN) *
1>       ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(85): error : name followed by "::" must be a class or namespace name
1>      (ThreadblockShape::kN / WarpShape::kN) *
1>                              ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(86): error : name followed by "::" must be a class or namespace name
1>      (ThreadblockShape::kK / WarpShape::kK);
1>       ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(86): error : name followed by "::" must be a class or namespace name
1>      (ThreadblockShape::kK / WarpShape::kK);
1>                              ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(89): error : name followed by "::" must be a class or namespace name
1>    using Arguments = typename UnderlyingKernel::Arguments;
1>                               ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu
1>
1>C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cutlass\include\cutlass/conv/device/implicit_gemm_convolution.h(94): error : name followed by "::" must be a class or namespace name
1>    typename UnderlyingKernel::Params params_;
1>             ^
1>          detected during instantiation of class "cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_> [with ImplicitGemmKernel_=Conv2dFpropKernel]" at line 103 of C:\Users\pierrechatelier\Desktop\TestOpenCV\TestOpenCV\cudaCUTLASS.cu

hwu36 avatar hwu36 commented on June 10, 2024

you can add simt specialization to https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/kernel/default_conv2d_fprop.h

it should be almost the same as

. Change IteratorAlgorithm::kOptimized to IteratorAlgorithm::kFixedChannels and use Conv2dFpropActivationTileAccessIteratorFixedChannels instead of Conv2dFpropActivationTileAccessIteratorOptimized

chacha21 avatar chacha21 commented on June 10, 2024

Nope.
On one hand :

  • specialization works and compilation succeeds
  • speed is significantly improved

but on the other hand, far more critical :

  • now numerical results are wrong. Some groups of numbers can be spotted but misplaced. I can't identify a simple shift/roll/reverse/stride problem, it looks more complex.

There must be some deep inconsistency between the specialization and the current Conv2dFprop configuration in how indices are computed.

hwu36 avatar hwu36 commented on June 10, 2024

another thing you can try is to use tf32x3 to emulate fp32, thus you can use tensor cores.

https://github.com/NVIDIA/cutlass/blob/main/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu
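For readers unfamiliar with the idea, here is a host-side sketch in plain C++ of how three TF32 products can approximate one FP32 product. It is not CUTLASS code: the helpers `to_tf32`, `split_tf32`, `mul_1xtf32` and `mul_3xtf32` are hypothetical names, and the conversion truncates the mantissa, which is a simplifying assumption (hardware conversion may round instead).

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Keep only the 10 explicit mantissa bits that TF32 retains.
// Assumption: plain truncation of the low 13 mantissa bits.
inline float to_tf32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

// An FP32 value split into a TF32 "hi" part and a TF32 residual "lo" part.
struct Tf32Pair {
  float hi;
  float lo;
};

inline Tf32Pair split_tf32(float x) {
  assert(std::isfinite(x));
  const float hi = to_tf32(x);
  const float lo = to_tf32(x - hi);  // what the first conversion dropped
  return {hi, lo};
}

// Single TF32 product: both operands lose their low mantissa bits.
inline float mul_1xtf32(float a, float b) {
  return to_tf32(a) * to_tf32(b);
}

// Three TF32 products; the lo*lo term is dropped as negligible.
inline float mul_3xtf32(float a, float b) {
  const Tf32Pair sa = split_tf32(a);
  const Tf32Pair sb = split_tf32(b);
  return sa.hi * sb.hi + (sa.hi * sb.lo + sa.lo * sb.hi);
}
```

Comparing `mul_1xtf32(a, b)` and `mul_3xtf32(a, b)` against a double-precision product for a few inputs gives a feel for how much accuracy the two extra products recover.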

hwu36 avatar hwu36 commented on June 10, 2024

the key difference is this line https://github.com/NVIDIA/cutlass/blob/main/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu#L122

you may not see any benefit because of your C=1 though

chacha21 avatar chacha21 commented on June 10, 2024

Interesting technique, but I think I am doomed.

In the end, CUTLASS does not seem to be the right tool for basic 2D convolution. I expected a speedup from the promised loop unrolling and clever shared-memory usage, but I can't even run the thing properly, so...

    // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
    // all pointers, strides, and tensor extents must be divisible by 8 elements.
    //
    int const kAlignment = 4;

    if ((input_size.c() % kAlignment) || //<----- HERE MY C=1 WON'T WORK
      (filter_size.n() % kAlignment)) {

      // misaligned tensors
      return false;
    }
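To make the constraint in the snippet above concrete, here is a small standalone sketch (not CUTLASS code; `elements_per_access` and `extents_aligned` are hypothetical helper names). It shows where an alignment of 4 comes from for 32-bit elements with 128-bit accesses, and why C=1 only passes the check with an alignment of 1.

```cpp
#include <cassert>

// Number of elements covered by one vectorized memory access.
inline int elements_per_access(int access_bits, int element_bits) {
  assert(element_bits > 0 && access_bits % element_bits == 0);
  return access_bits / element_bits;
}

// Same shape as the check quoted above: both the input channel count and the
// filter count must be multiples of the alignment (expressed in elements).
inline bool extents_aligned(int channels, int filters, int alignment) {
  assert(alignment > 0);
  return (channels % alignment == 0) && (filters % alignment == 0);
}
```

With these helpers, `elements_per_access(128, 32)` is 4 for float and `elements_per_access(128, 16)` is 8 for half; `extents_aligned(1, 1, 4)` is false, while `extents_aligned(1, 1, 1)` is true.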

hwu36 avatar hwu36 commented on June 10, 2024

have you tried to set alignment to be 1? your data is fp32, i think alignment 1 is okay.

chacha21 avatar chacha21 commented on June 10, 2024

A new failed attempt to produce a correct result :-(

What I did :

  • replace my previous Kernel with the 3xTF32 kernel example
  • in the Epilogue, use 1 instead of 128 / cutlass::sizeof_bits<ElementC>::value, (otherwise it will crash)

That's all ; as far as I understand, the 1xTF32 and double versions are just for comparison

  • it runs, but produces wrong results

Then in the new kernel I tried to use cutlass::conv::IteratorAlgorithm::kAnalytic instead of cutlass::conv::IteratorAlgorithm::kOptimized, but the results are just as wrong.

hwu36 avatar hwu36 commented on June 10, 2024

you also need to change the input alignment to be 1 too. https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/kernel/default_conv2d_fprop.h#L81-L83

chacha21 avatar chacha21 commented on June 10, 2024

you also need to change the input alignment to be 1 too

You are right, and now it works.
Time for a recap.

The goal is a basic 2D convolution of a single-channel float 2D image with stride; the convolution is separable into row + column filters

  • initial topic of this thread: stride is not currently easily supported for the dst tensor; you made a proposal to work on it for future releases
  • using kFixedChannels has two problems: it currently needs an additional kernel specialization, and it gives wrong results. Some investigation might be needed
  • using 3xTF32 with alignment=1 works and gives correct results (I did not try advanced accuracy tests but it seems ok at first sight)

Unfortunately, even with correct results, CUTLASS is much slower than any other 2D convolution algorithm I usually rely on


input is 1111x1024, kernel diameter is 2*10+1

OpenCV:
cv::Ptr<cv::cuda::Filter>row->apply:12.233000 ms
cv::Ptr<cv::cuda::Filter>col->apply:0.920200 ms

NPP
nppiFilterRowBorder_32f_C1R_Ctx:4.228600 ms
nppiFilterColumnBorder_32f_C1R_Ctx:2.226300 ms

CUTLASS
with the overhead of "uncompacting" dst to get correct stride
convolutionCUTLASSRow:17.082700 ms (TensorNHWC, OpClassSimt, Sm80, OpMultiplyAdd, kAnalytic, kStrided)
convolutionCUTLASSCol:16.348200 ms (TensorNHWC, OpClassSimt, Sm80, OpMultiplyAdd, kAnalytic, kStrided)

let dst compact:
convolutionCUTLASSRow:13.177100 ms (TensorNHWC, OpClassSimt, Sm80, OpMultiplyAdd, kAnalytic, kStrided)
convolutionCUTLASSCol:13.244500 ms (TensorNHWC, OpClassSimt, Sm80, OpMultiplyAdd, kAnalytic, kStrided)

with the overhead of "uncompacting" dst to get correct stride
convolutionCUTLASS3xTF32Row:26.002000 ms (TensorNHWC, OpClassTensorOp, Sm80, OpMultiplyAddFastF32, kOptimized, kStrided)
convolutionCUTLASS3xTF32Col:21.467700 ms (TensorNHWC, OpClassTensorOp, Sm80, OpMultiplyAddFastF32, kOptimized, kStrided)

let dst compact:
convolutionCUTLASSRow:19.143100 ms (TensorNHWC, OpClassSimt, Sm80, OpMultiplyAdd, kAnalytic, kStrided)
convolutionCUTLASSCol:18.277300 ms (TensorNHWC, OpClassSimt, Sm80, OpMultiplyAdd, kAnalytic, kStrided)


So far my conclusion is that CUTLASS just does not fit my basic use case and is designed for different purposes.
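For checking the numerical results of any of the variants above, a plain CPU reference of the separable row + column passes can help. The sketch below is an assumption-laden illustration, not code from this thread: `row_filter_ref` and `col_filter_ref` are hypothetical names, borders are zero-padded, the taps are applied without flipping (correlation form), and pitches are given in elements so that dst may be non-compact.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Horizontal pass: dst(y, x) = sum_k taps[k] * src(y, x + k - radius).
// Samples outside the image count as zero. Pitches are row strides in
// elements, so dst rows may be wider than `width`.
void row_filter_ref(const float* src, std::size_t src_pitch,
                    float* dst, std::size_t dst_pitch,
                    int width, int height, const std::vector<float>& taps) {
  assert(taps.size() % 2 == 1);
  const int radius = static_cast<int>(taps.size() / 2);
  for (int y = 0; y < height; ++y) {
    for (int x = 0; x < width; ++x) {
      float acc = 0.0f;
      for (int k = -radius; k <= radius; ++k) {
        const int xs = x + k;
        if (xs >= 0 && xs < width) {
          acc += taps[k + radius] * src[y * src_pitch + xs];
        }
      }
      dst[y * dst_pitch + x] = acc;
    }
  }
}

// Vertical pass: dst(y, x) = sum_k taps[k] * src(y + k - radius, x).
void col_filter_ref(const float* src, std::size_t src_pitch,
                    float* dst, std::size_t dst_pitch,
                    int width, int height, const std::vector<float>& taps) {
  assert(taps.size() % 2 == 1);
  const int radius = static_cast<int>(taps.size() / 2);
  for (int y = 0; y < height; ++y) {
    for (int x = 0; x < width; ++x) {
      float acc = 0.0f;
      for (int k = -radius; k <= radius; ++k) {
        const int ys = y + k;
        if (ys >= 0 && ys < height) {
          acc += taps[k + radius] * src[ys * src_pitch + x];
        }
      }
      dst[y * dst_pitch + x] = acc;
    }
  }
}
```

Running the row pass into a temporary buffer and the column pass from that buffer into dst gives the full separable result; elements between `width` and the pitch are never written, which makes it easy to spot a kernel that ignores the dst stride.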

hwu36 avatar hwu36 commented on June 10, 2024

when you use kFixedChannels, did you set all alignments to 1?

chacha21 avatar chacha21 commented on June 10, 2024

when you use kFixedChannels, did you set all alignments to 1?

Absolutely :

using ElementA = float;
using ElementB = float;
using ElementC = float;
using ElementAccumulator = float;
using ElementCompute = float;
constexpr const int kChannelsCount = 1;

using Epilogue = cutlass::epilogue::thread::LinearCombination<
  ElementC,
  1,
  ElementAccumulator,
  ElementCompute
>;


using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop <
  ElementA,
  cutlass::layout::TensorNHWC,
  ElementB,
  cutlass::layout::TensorNHWC,
  ElementC,
  cutlass::layout::TensorNHWC,
  ElementAccumulator,
  cutlass::arch::OpClassSimt,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 8>,
  cutlass::gemm::GemmShape<64, 64, 8>,
  cutlass::gemm::GemmShape<1, 1, 1>,
  Epilogue,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  3,//2
  cutlass::arch::OpMultiplyAdd,
  cutlass::conv::IteratorAlgorithm::kFixedChannels,
  cutlass::conv::StrideSupport::kStrided,
  kChannelsCount,
  kChannelsCount
> ::Kernel;
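One possible way to reason about the timings reported earlier is to look at how this problem maps onto the threadblock tile configured above. The sketch below is a rough count, not a performance model: it assumes the usual fprop implicit-GEMM mapping (M = N*P*Q, N = K output channels, K = C*R*S), an output the same size as the input, and hypothetical helper names `fprop_gemm_extent` and `tile_utilization`.

```cpp
#include <cassert>
#include <cstdint>

struct GemmExtent {
  std::int64_t m;
  std::int64_t n;
  std::int64_t k;
};

// Fprop as an implicit GEMM: M = N*P*Q, N = K (filters), K = C*R*S.
inline GemmExtent fprop_gemm_extent(int N, int P, int Q, int K,
                                    int C, int R, int S) {
  return {static_cast<std::int64_t>(N) * P * Q,
          static_cast<std::int64_t>(K),
          static_cast<std::int64_t>(C) * R * S};
}

inline std::int64_t round_up(std::int64_t value, std::int64_t multiple) {
  assert(multiple > 0);
  return (value + multiple - 1) / multiple * multiple;
}

// Fraction of the multiply-adds issued by whole tiles that fall inside
// the real problem extent (1.0 means no padding waste).
inline double tile_utilization(const GemmExtent& p,
                               int tile_m, int tile_n, int tile_k) {
  const double useful = static_cast<double>(p.m) * p.n * p.k;
  const double issued = static_cast<double>(round_up(p.m, tile_m)) *
                        round_up(p.n, tile_n) * round_up(p.k, tile_k);
  return useful / issued;
}
```

For a 1111x1024 single-channel image, one filter and a 1x21 kernel, `fprop_gemm_extent(1, 1111, 1024, 1, 1, 1, 21)` has N = 1, so with a 128x128x8 threadblock tile `tile_utilization` comes out at about 0.7%: nearly every column of each tile is padding. Under these assumptions that would be consistent with CUTLASS trailing the dedicated row/column filters, though it does not prove it is the only cause.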

mnicely avatar mnicely commented on June 10, 2024

@chacha21 has your question been answered?

chacha21 avatar chacha21 commented on June 10, 2024

@chacha21 has your question been answered?

Not really. You tell me.
See the recap a few messages above (reproduced just below) :

The goal is a basic 2D convolution of a single-channel float 2D image with stride; the convolution is separable into row + column filters

  • initial topic of this thread: stride is not currently easily supported for the dst tensor; you made a proposal to work on it for future releases
  • using kFixedChannels has two problems: it currently needs an additional kernel specialization, and it gives wrong results. Some investigation might be needed
  • using 3xTF32 with alignment=1 works and gives correct results (I did not try advanced accuracy tests but it seems ok at first sight)

Unfortunately, even with correct results, CUTLASS is much slower than any other 2D convolution algorithm I usually rely on

in other words :

  • I raised a bug/feature request (about dst stride), it might be resolved in the future
  • The kFixedChannels usage is not resolved and does not work as advised. I don't know if it's a bug.
  • I mentioned the performance of CUTLASS in the convolution2D scenario being far below what I expected. I don't know if you consider that a normal behavior or not.

So if you consider that those points are "not a bug/won't fix", we can close this thread.
If you consider that those points are relevant for some PR, we can leave it open.

github-actions avatar github-actions commented on June 10, 2024

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

chacha21 avatar chacha21 commented on June 10, 2024

so what's the status of these (multiple) issues ?

github-actions avatar github-actions commented on June 10, 2024

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

chacha21 avatar chacha21 commented on June 10, 2024

Are there any fixes in sight ?

github-actions avatar github-actions commented on June 10, 2024

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

chacha21 avatar chacha21 commented on June 10, 2024

Are there any fixes in sight ?
