
Comments (18)

thakkarV avatar thakkarV commented on June 10, 2024

Sounds like you want a grouped GEMM that supports gather/scatter? Have you taken a look at example 52 for inspiration? Happy to help with the design, but using CuTe for this is fairly natural given the extensions in the Hopper gather/scatter GEMM example.


akamiru avatar akamiru commented on June 10, 2024

Sadly I only have access to SM89, so I haven't looked into the Hopper examples too much. To be honest, I'm still rather new to GPU computing and I'm kinda having a hard time understanding how CuTe/CUTLASS works. In particular, the tiny sizes of my matrices don't seem to match the supported layouts.

Thinking about it, a grouped GEMM with scatter/gather that runs on SM89 would pretty much allow me to do everything my trainer needs, except for the first layer, which is sparse. Would you be willing to discuss my network and help me make a plan for how to tackle it via Discord?


thakkarV avatar thakkarV commented on June 10, 2024

I'm on vacation for a couple of weeks but can help asynchronously on this thread. The concepts presented in example 52 apply pretty much 1:1 here; I'm specifically referring to how the gather tensor layouts are expressed. You can use the same custom indexed gather strides from that example for your use case. I recommend starting with a single-file CuTe kernel like the CuTe tutorial.
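For intuition, here is a minimal host-side sketch of what the gather/scatter indirection in example 52 computes. The example itself folds the indirection into the tensors' strides via CuTe custom strides; the names gather_idx and scatter_idx below are illustrative, not CUTLASS API:

#include <cstdint>
#include <vector>

// Reference loop: read row gather_idx[m] of A instead of row m, and write the
// m-th output row to row scatter_idx[m] of D. All matrices are row-major.
void reference_gather_gemm_scatter(
    int M, int N, int K,
    std::vector<float> const &A,              // source rows, gathered through gather_idx
    std::vector<float> const &B,              // K x N
    std::vector<float> &D,                    // destination rows, placed through scatter_idx
    std::vector<int32_t> const &gather_idx,   // size M
    std::vector<int32_t> const &scatter_idx)  // size M
{
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[gather_idx[m] * K + k] * B[k * N + n];   // gather on A
      }
      D[scatter_idx[m] * N + n] = acc;                    // scatter on D
    }
  }
}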


hwu36 avatar hwu36 commented on June 10, 2024

does https://github.com/NVIDIA/cutlass/blob/main/examples/24_gemm_grouped/gemm_grouped.cu meet your need to run multiple GEMMs in parallel?

does https://github.com/NVIDIA/cutlass/blob/main/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu meet your need to fuse gather/scatter into a GEMM?

If yes to both, you could just merge ex36 into ex24, which is fairly easy. Gather/scatter adds three extra template parameters to the highest-level templates, and three extra arguments to the top-level Arguments. So what you need to do is add these to the grouped GEMM interface.

Grouped GEMM and gather/scatter use the same underlying GEMM implementation, so you just need to do some plumbing at the top-level device and kernel layers.
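For concreteness, a rough sketch of that plumbing, assuming the GatherA / GatherB / ScatterD template flags and per-tensor index arguments from example 36. The grouped variant below is hypothetical and does not exist in CUTLASS:

// Hypothetical extension of the grouped GEMM Arguments (names illustrative).
// Example 36 takes one index array per tensor; for a group you need one index
// array per tensor *per problem*, hence the pointer-to-pointer members.
struct GemmGroupedGatherScatterArguments {
  // ... all existing grouped GEMM arguments: problem_sizes, problem_count,
  //     ptr_A, ptr_B, ptr_C, ptr_D, lda, ldb, ldc, ldd, epilogue, ...
  int const * const *ptr_gather_A_indices;   // ptr_gather_A_indices[i]: indices for problem i
  int const * const *ptr_gather_B_indices;   // may be nullptr if B is not gathered
  int const * const *ptr_scatter_D_indices;  // may be nullptr if D is not scattered
};
// The underlying kernel would be instantiated with the same three boolean
// template parameters (GatherA, GatherB, ScatterD) that example 36 adds to its GEMM.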

@shangz-ai


jeromeku avatar jeromeku commented on June 10, 2024

@hwu36

I'm also interested in combining grouped gemms with gather / scatter fusion for Ampere (non-Hopper) architectures.

I see that both examples use the same underlying kernel GemmUniversal.

Other than adapting the top-level arguments and templates per your previous answer, what other changes do I need to make to the underlying code in GemmGrouped such that it properly constructs, initializes and passes the args for the individual gather / scatter kernels?

I.e., other than instantiating the underlying GEMM kernel for the grouped GEMM per the above and adding the tensor indices to the grouped GEMM Arguments struct, what else needs to be done to pass these indices to the gather / scatter GEMMs? What I'm not clear on is how the top-level GemmGrouped arguments interact with the underlying kernels (GemmUniversal).


hwu36 avatar hwu36 commented on June 10, 2024

@jackkosaian ^ ^


jackkosaian avatar jackkosaian commented on June 10, 2024

@jeromeku,

It sounds like you've already figured out where new Arguments should be placed: here.

You'll also need to add them to the kernel's Params struct here, similar to how they are added for GemmUniversal::Params here (but noting that you'll have pointers to the gather/scatter index pointers for grouped GEMM since you want one list of indices per problem in the group).

Then, you need to augment the kernel::GemmGrouped::operator() to (1) determine which gather/scatter pointer to use for a given tile being processed by the thread block, and (2) pass these pointers down to lower levels of the kernel.

(1) can be achieved by indexing into the gather/scatter pointers in Params using problem_idx similar to what grouped GEMM does for A/B pointers here.

(2) can be done by following the pattern of use of the gather/scatter indices that is used by GemmUniversal::operator() (e.g., here.)
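Putting (1) and (2) together, a hypothetical fragment of kernel::GemmGrouped::operator() might look like the following. The ptr_gather_A_indices / ptr_scatter_D_indices members follow the sketch earlier in this thread and example 36's naming; this is not existing CUTLASS code:

// Inside kernel::GemmGrouped::operator(), after the problem visitor has picked
// the tile to work on:
int32_t problem_idx = problem_visitor.problem_index();

// Existing grouped GEMM behavior: select per-problem A/B pointers.
ElementA *ptr_A = reinterpret_cast<ElementA *>(params.ptr_A[problem_idx]);
ElementB *ptr_B = reinterpret_cast<ElementB *>(params.ptr_B[problem_idx]);

// (1) New: select the per-problem gather/scatter index arrays the same way.
int const *gather_A_indices  =
    params.ptr_gather_A_indices  ? params.ptr_gather_A_indices[problem_idx]  : nullptr;
int const *scatter_D_indices =
    params.ptr_scatter_D_indices ? params.ptr_scatter_D_indices[problem_idx] : nullptr;

// (2) Pass them down when constructing the A tile iterator and the epilogue,
// mirroring how GemmUniversal::operator() forwards its gather/scatter indices.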

I hope this helps!


jeromeku avatar jeromeku commented on June 10, 2024

@jackkosaian

Thanks for the clear explanation!

For the gather / scatter kernels, if I'm gathering rows of A (and scattering into rows of C / D), are there additional checks that need to be implemented to ensure that each gemm size is valid?

For example, if I'm using TensorOps, the minimum M is 16 as predetermined by the tensor core instruction and thus the warp and tile shapes need to be multiples of 16.

In the case that the gathered M doesn't meet this requirement, would can_implement need to be changed to enforce this condition, e.g., by padding to a minimum block size? Is this implemented in the codebase, or are there related examples?


jackkosaian avatar jackkosaian commented on June 10, 2024

Any sort of padding would need to be handled externally to can_implement. You would need to pad your tensors, problem shapes, etc. before setting them in the Arguments struct.
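For example, a minimal sketch of padding the gathered row count before building the problem size and Arguments. kTileM and round_up are illustrative names; the only CUTLASS type used is cutlass::gemm::GemmCoord:

#include <cutlass/gemm_coord.h>

// Minimum M granularity required by the chosen tile / instruction shapes.
constexpr int kTileM = 16;

inline int round_up(int x, int multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

// Pad M externally, before can_implement ever sees the problem. The extra
// entries of the gather index array should point at valid (e.g. zero-filled)
// rows so the padded rows contribute nothing meaningful to the output.
inline cutlass::gemm::GemmCoord make_padded_problem(int gathered_m, int n, int k) {
  return cutlass::gemm::GemmCoord(round_up(gathered_m, kTileM), n, k);
}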


jeromeku avatar jeromeku commented on June 10, 2024

Quick question:

When filling a tensor with BlockFillSequential as such:

cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
        problem_size.mk()); 
cutlass::reference::host::BlockFillSequential(
        tensor_a.host_data(), problem_size.m() * problem_size.k());

For ElementInputA = cutlass::half_t and m = k = 128, this gives a sequence only up to 2048, after which it repeats 2048 for the remaining values.

Guessing that I'm doing something wrong here...


hwu36 avatar hwu36 commented on June 10, 2024

There is no 2049 in fp16. You are doing the correct thing.


jeromeku avatar jeromeku commented on June 10, 2024

Thanks @hwu36

When I repeat the above using host::TensorFillSequential, I get an increasing sequence up to 16384 for m = k = 128 as expected, whereas for BlockFillSequential, the sequence stops at 2048 (and repeats 2048 thereafter).


hwu36 avatar hwu36 commented on June 10, 2024

If you check the result returned by host::TensorFillSequential, you won't see 2049.


hwu36 avatar hwu36 commented on June 10, 2024

host::TensorFillSequential uses a slightly different way to calculate the values, but after 2048 both are limited by fp16.

In fp16:

2048 + 1 = 2048 <- the BlockFillSequential way

fp16(2049) = 2048 <- the TensorFillSequential way
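A quick way to see this using cutlass::half_t directly (just an illustration of the rounding, not part of either fill routine):

#include <iostream>
#include <cutlass/numeric_types.h>

int main() {
  cutlass::half_t x(2048.0f);
  cutlass::half_t one(1.0f);

  // Accumulating in fp16, as BlockFillSequential effectively does:
  std::cout << float(x + one) << std::endl;                    // prints 2048

  // Converting the fp32 value 2049 to fp16, as TensorFillSequential effectively does:
  std::cout << float(cutlass::half_t(2049.0f)) << std::endl;   // prints 2048
  return 0;
}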


jeromeku avatar jeromeku commented on June 10, 2024

@jackkosaian

Are there any examples of gather / scatter fusion and grouped_gemm specifically for Ampere architectures using Cutlass 3.0+ and CuTe?

How would one implement the above (combining gather / scatter and grouped_gemm) using the 3.0 API as opposed to the legacy 2.0 interface?


jackkosaian avatar jackkosaian commented on June 10, 2024

Are there any examples of gather / scatter fusion and grouped_gemm specifically for Ampere architectures using Cutlass 3.0+ and CuTe?

We do not have examples of this.

How would one implement the above (combining gather / scatter and grouped_gemm) using the 3.0 API as opposed to the legacy 2.0 interface?

My suggestion would be to take a look at the CUTLASS 3 examples for gather/scatter and grouped GEMM (each of which currently targets Hopper). You could consider adapting these to use SM80 CUTLASS 3 mainloops (similar to unit tests like this one). Note, however, that GEMMs produced via the CUTLASS 3 API for CC < 90 are not currently as well optimized as those produced via the CUTLASS 2 API.


thakkarV avatar thakkarV commented on June 10, 2024

that GEMMs produced via the CUTLASS 3 API for CC < 90 are not currently as well optimized as those produced via the CUTLASS 2 API

That said, do not be discouraged. Carefully crafted data and thread layouts can hit 95% of peak perf on the 3.x mainloop as well :)


mnicely avatar mnicely commented on June 10, 2024

@jeromeku has your question been answered?

