Comments (18)
Sounds like you want a grouped GEMM that supports gather/scatter? Have you taken a look at example 52 for inspiration? Happy to help with the design; this is fairly natural to express in CuTe with the extensions shown in the Hopper gather/scatter GEMM example.
Sadly I only have access to SM89, so I haven't looked into the Hopper examples too much. To be honest, I'm still rather new to GPU computing and I'm having a hard time understanding how CuTe/CUTLASS works. In particular, the tiny sizes of my matrices don't seem to match the supported layouts.
Thinking about it, a grouped GEMM with scatter/gather that runs on SM89 would pretty much allow me to do everything my trainer needs, except for the first layer, which is sparse. Would you be willing to discuss my network over Discord and help me make a plan for how to tackle it?
I'm on vacation for a couple of weeks but can help asynchronously on this thread. The concepts presented in example 52 apply pretty much 1:1 here; I'm specifically referring to how the gather tensor layouts are transcribed. You can use the same custom indexed gather strides from that example for your use case. I recommend starting with a single-file CuTe kernel like the ones in the CuTe tutorial.
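If it helps to see the mechanism in isolation, below is a condensed, host-only sketch of the custom indexed-gather stride idea. `IndexedGatherStride` is a simplified stand-in for the `IndexedGather`/`CustomStride` helpers in example 52's `gather_tensor.hpp` (example code, not library API), so expect to adapt the details to your CuTe version:

```cpp
// Compile with CUTLASS's include/ directory on the include path (C++17).
#include <cstdio>
#include <cute/tensor.hpp>

// Stride component that routes a row coordinate through an index array
// before scaling by the leading dimension, so ordinary layout arithmetic
// performs the gather.
struct IndexedGatherStride {
  int const* indices;  // gather indices
  int ld;              // leading dimension of the source matrix

  // CuTe's coordinate-to-index mapping reduces coordinate/stride pairs with
  // operator*, so this overload redirects row i to row indices[i].
  friend int operator*(int i, IndexedGatherStride const& s) {
    return s.indices[i] * s.ld;
  }
};

int main() {
  using namespace cute;

  float data[4 * 3];  // 4x3 row-major source matrix
  for (int i = 0; i < 12; ++i) data[i] = float(i);
  int gather_idx[2] = {2, 0};  // gather rows 2 and 0

  // A 2x3 view whose row stride performs the indirection.
  auto layout = make_layout(make_shape(2, 3),
                            make_stride(IndexedGatherStride{gather_idx, 3}, 1));
  auto A = make_tensor(make_gmem_ptr(&data[0]), layout);

  for (int m = 0; m < 2; ++m)
    for (int k = 0; k < 3; ++k)
      std::printf("A(%d,%d) = %g\n", m, k, A(m, k));  // rows 2 and 0 of data
  return 0;
}
```

Because the indirection lives in the layout itself, the resulting tensor can still be tiled and partitioned with the usual CuTe algebra, which is why the same trick slots into a GEMM mainloop in example 52.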
Does https://github.com/NVIDIA/cutlass/blob/main/examples/24_gemm_grouped/gemm_grouped.cu meet your need to run multiple GEMMs in parallel?
Does https://github.com/NVIDIA/cutlass/blob/main/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu meet your need to fuse gather/scatter into a GEMM?
If yes to both, you could just merge example 36 into example 24, which is fairly easy. Gather/scatter adds three template parameters to the highest-level templates and three arguments to the top-level arguments, so what you need to do is add these to the grouped GEMM interface.
Grouped GEMM and gather/scatter use the same underlying GEMM implementation, so you just need to do some plumbing at the top-level device and kernel layers.
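For concreteness, here is what those three trailing template parameters look like on `device::GemmUniversal`, modeled on example 36. This is a sketch: the element types, shapes, and stage count below are illustrative placeholders, and the exact trailing parameter list can vary across CUTLASS 2.x versions:

```cpp
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = float;
using ElementAccumulator = float;

using Gemm = cutlass::gemm::device::GemmUniversal<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,     // tensor core instruction
    cutlass::epilogue::thread::LinearCombination<
        ElementC, 4, ElementAccumulator, ElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
    4,     // stages
    8, 8,  // alignment of A and B
    cutlass::arch::OpMultiplyAdd,
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone,
    true,   // GatherA: fetch rows of A through an index vector
    false,  // GatherB
    true>;  // ScatterD: write rows of D through an index vector
```

The three matching arguments are the trailing `gather_A_indices`, `gather_B_indices`, and `scatter_D_indices` pointers of `Gemm::Arguments` (pass `nullptr` for modes you leave disabled); for the grouped interface these would become per-problem arrays, as discussed below.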
I'm also interested in combining grouped GEMMs with gather/scatter fusion for Ampere (non-Hopper) architectures.
I see that both examples use the same underlying kernel, GemmUniversal.
Other than adapting the top-level arguments and templates per your previous answer, what other changes do I need to make to the underlying code in GemmGrouped so that it properly constructs, initializes, and passes the arguments for the individual gather/scatter kernels?
I.e., other than instantiating the underlying GEMM kernel for the grouped GEMM per the above and adding the tensor indices to the grouped GEMM arguments struct, what else needs to be done to pass these indices to the gather/scatter GEMMs? What I'm not clear on is how the top-level GemmGrouped arguments interact with the underlying kernels (GemmUniversal).
@jackkosaian ^ ^
It sounds like you've already figured out where new Arguments should be placed: here.
You'll also need to add them to the kernel's Params struct here, similar to how they are added for GemmUniversal::Params here (but note that for grouped GEMM you'll have pointers to the gather/scatter index pointers, since you want one list of indices per problem in the group).
Then, you need to augment kernel::GemmGrouped::operator() to (1) determine which gather/scatter pointers to use for a given tile being processed by the thread block, and (2) pass these pointers down to lower levels of the kernel.
(1) can be achieved by indexing into the gather/scatter pointers in Params using problem_idx, similar to what grouped GEMM does for the A/B pointers here.
(2) can be done by following the pattern of use of the gather/scatter indices in GemmUniversal::operator() (e.g., here).
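A hypothetical, compilable sketch of steps (1) and (2); the struct and function names here are illustrative, not existing CUTLASS identifiers:

```cpp
// Hypothetical names throughout; this only illustrates the shape of the plumbing.
struct GroupedGatherScatterParams {
  // One device pointer per problem in the group: entry p is the index list
  // used to gather rows of A (or scatter rows of D) for problem p.
  int const* const* gather_A_indices = nullptr;
  int const* const* scatter_D_indices = nullptr;
};

// Step (1): once the grouped-GEMM problem visitor yields problem_idx, pick
// this problem's index lists the same way kernel::GemmGrouped picks ptr_A
// and ptr_B for the current problem.
inline void select_problem_indices(GroupedGatherScatterParams const& params,
                                   int problem_idx,
                                   int const*& gather_A,
                                   int const*& scatter_D) {
  gather_A  = params.gather_A_indices  ? params.gather_A_indices[problem_idx]  : nullptr;
  scatter_D = params.scatter_D_indices ? params.scatter_D_indices[problem_idx] : nullptr;
  // Step (2): pass gather_A / scatter_D to the A-operand and D-output tile
  // iterator constructors, mirroring how GemmUniversal::operator() forwards
  // its gather/scatter index pointers.
}
```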
I hope this helps!
Thanks for the clear explanation!
For the gather/scatter kernels, if I'm gathering rows of A (and scattering into rows of C/D), are there additional checks that need to be implemented to ensure that each GEMM size is valid?
For example, if I'm using TensorOps, the minimum M is 16 as predetermined by the tensor core instruction, and thus the warp and tile shapes need to be multiples of 16.
In the case that the gathered M doesn't meet this requirement, would can_implement need to be changed to enforce this condition by, e.g., padding to a minimum block size? Is this implemented in the codebase, or are there related examples?
Any sort of padding would need to be handled externally to can_implement. You would need to pad your tensors, problem shapes, etc. before setting them in the Arguments struct.
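A minimal sketch of such external padding, assuming (per the discussion above) the gathered M must be rounded up to some granularity; `pad_gather_indices` and `pad_row` are hypothetical helpers, and you would also need to append matching scratch rows to the padded tensors (and treat the scatter index list for D the same way):

```cpp
#include <cstddef>
#include <vector>

// Round the gathered row count up to the required granularity.
inline int round_up(int x, int granularity) {
  return (x + granularity - 1) / granularity * granularity;
}

// Extend a gather index list to padded_m entries; the extra entries all point
// at pad_row, a scratch row appended to the tensor for this purpose.
inline std::vector<int> pad_gather_indices(std::vector<int> indices,
                                           int padded_m, int pad_row) {
  indices.resize(static_cast<std::size_t>(padded_m), pad_row);
  return indices;
}

// Example: 70 gathered rows with a granularity of 16 pads up to 80 rows; pass
// the padded M in the problem shape and the padded index list in Arguments.
```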
Quick question:
When filling a tensor with BlockFillSequential as follows:

```cpp
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
cutlass::reference::host::BlockFillSequential(
    tensor_a.host_data(), problem_size.m() * problem_size.k());
```

for ElementInputA = cutlass::half_t and m = k = 128, this gives a sequence only up to 2048, after which it repeats 2048 for the remaining values.
Guessing that I'm doing something wrong here...
There is no 2049 in fp16 (above 2048, fp16 can only represent even integers), so you are doing the correct thing.
Thanks @hwu36.
When I repeat the above using host::TensorFillSequential, I get an increasing sequence up to 16384 for m = k = 128 as expected, whereas for BlockFillSequential, the sequence stops at 2048 (and repeats 2048 thereafter).
If you check the result returned by host::TensorFillSequential, you won't see 2049.
host::TensorFillSequential uses a slightly different way to calculate the values, but after 2048 both are limited by fp16. In fp16:
2048 + 1 = 2048 <- the BlockFillSequential way (accumulate in fp16)
fp16(2049) = 2048 <- the TensorFillSequential way (convert the counter)
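A quick way to see both effects on the host with `cutlass::half_t`:

```cpp
// Compile with CUTLASS's include/ directory on the include path.
#include <iostream>
#include "cutlass/half.h"

int main() {
  // BlockFillSequential-style: accumulate in fp16. Above 2048 = 2^11, fp16
  // resolves only even integers, and 2048 + 1 rounds back down to 2048.
  cutlass::half_t acc(2048.0f);
  acc = acc + cutlass::half_t(1.0f);
  std::cout << "2048 + 1 in fp16 = " << float(acc) << "\n";       // 2048

  // TensorFillSequential-style: convert the running integer counter to fp16.
  cutlass::half_t converted(2049.0f);
  std::cout << "fp16(2049)       = " << float(converted) << "\n"; // 2048
  return 0;
}
```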
Are there any examples of gather/scatter fusion and grouped GEMM specifically for Ampere architectures using CUTLASS 3.0+ and CuTe?
How would one implement the above (combining gather/scatter and grouped GEMM) using the 3.0 API as opposed to the legacy 2.0 interface?
> Are there any examples of gather/scatter fusion and grouped GEMM specifically for Ampere architectures using CUTLASS 3.0+ and CuTe?

We do not have examples of this.

> How would one implement the above (combining gather/scatter and grouped GEMM) using the 3.0 API as opposed to the legacy 2.0 interface?

My suggestion would be to take a look at the CUTLASS 3 examples for gather/scatter and grouped GEMM (each of which currently targets Hopper). You could consider adapting these to use SM80 CUTLASS 3 mainloops (similar to unit tests like this one). Note, however, that GEMMs produced via the CUTLASS 3 API for CC < 90 are not currently as well optimized as those produced via the CUTLASS 2 API.
> GEMMs produced via the CUTLASS 3 API for CC < 90 are not currently as well optimized as those produced via the CUTLASS 2 API

That said, do not be discouraged: carefully crafted data and thread layouts can hit 95% of peak performance with the 3.x mainloops as well :)
@jeromeku, has your question been answered?