Comments (9)
Karthikeyan Manivannan can take a look at what's going on with Triton.
Repros on A100.
I am seeing the same number of registers in both cases:
[~/triton (main)]$ TRITON_CACHE_DIR=$HOME/triton-cache/dump1 python ~/work/issues/126463/softmax_bwd_k1.py
14efd82a05f3d6020f6e683159d9f173ce520e99f2797052912a4e8a7c182d60 rblock=2048 kernel.n_regs=40 ms=10.224
a979388b7845426248da805049b53870c5afb5499771d04b7b3b953214069c7d rblock=1024 kernel.n_regs=32 ms=9.156
[~/triton (main)]$ TRITON_CACHE_DIR=$HOME/triton-cache/dump2 python ~/work/issues/126463/softmax_bwd_k2.py
365f353e83c5ddc165a630bcdc9f4ca005601ec6d3d8147823186d9854336675 rblock=2048 kernel.n_regs=40 ms=11.312
718152d0e5ad5222fa796d108f80924ea227699ca5b2edc6b8f6e36eca5bdc80 rblock=1024 kernel.n_regs=32 ms=9.141
[~/triton (main)]$ ptxas --gpu-name=sm_90a -v ~/triton-cache/dump1/14efd82a05f3d6020f6e683159d9f173ce520e99f2797052912a4e8a7c182d60/triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
ptxas info : 0 bytes gmem
ptxas info : Compiling entry function 'triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2' for 'sm_90a'
ptxas info : Function properties for triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers
[~/triton (main)]$ ptxas --gpu-name=sm_90a -v ~/triton-cache/dump2/365f353e83c5ddc165a630bcdc9f4ca005601ec6d3d8147823186d9854336675/triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
ptxas info : 0 bytes gmem
ptxas info : Compiling entry function 'triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2' for 'sm_90a'
ptxas info : Function properties for triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers
This output is from commit 4faa131964a93d5022e284ceb99957f714e2d4e5, and I see the same behavior on 8e48e4fa454e652438f41ff00cbb9ed38485e0f8.
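The softmax_bwd_k1.py / softmax_bwd_k2.py scripts aren't attached here, so for anyone reproducing from scratch, here is a minimal sketch of the kind of harness that produces the rblock / kernel.n_regs / ms readout above. The kernel body is a stand-in row reduction of my own, not the actual fused log-softmax-backward kernel, and the shapes are assumptions:

import torch
import triton
import triton.language as tl

@triton.jit
def row_sum_kernel(x_ptr, out_ptr, n_cols, RBLOCK: tl.constexpr):
    # Stand-in reduction kernel, NOT the fused kernel from the issue.
    row = tl.program_id(0)
    acc = tl.zeros([RBLOCK], dtype=tl.float32)
    # Serial reduction over the row in RBLOCK-sized chunks, similar to
    # the rblock loop in the inductor-generated reduction kernels.
    for off in range(0, n_cols, RBLOCK):
        cols = off + tl.arange(0, RBLOCK)
        mask = cols < n_cols
        acc += tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0)
    tl.store(out_ptr + row, tl.sum(acc, axis=0))

x = torch.randn(4096, 32768, device="cuda")  # placeholder shapes
out = torch.empty(x.shape[0], device="cuda")
for rblock in (2048, 1024):
    # Launching a @triton.jit kernel returns the compiled kernel,
    # which exposes the register count as .n_regs.
    kernel = row_sum_kernel[(x.shape[0],)](x, out, x.shape[1], RBLOCK=rblock)
    ms = triton.testing.do_bench(
        lambda: row_sum_kernel[(x.shape[0],)](x, out, x.shape[1], RBLOCK=rblock)
    )
    print(f"rblock={rblock} kernel.n_regs={kernel.n_regs} ms={ms:.3f}")

Setting TRITON_CACHE_DIR, as in the runs above, makes Triton write its compiled artifacts into that directory, which is where the .ptx files inspected with ptxas came from.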
Can you check the currently pinned Triton in PyTorch, 45fff310c891f5a92d55445adf8cc9d29df5841e?
I am seeing the same number of registers across both kernels on the pinned hash (45fff310) too.
[~/triton (45fff310)]$ ptxas --gpu-name=sm_90a -v ~//triton-cache/dump1/9b6fa6e98c5d58e5c3ff5d322a33db7a09b12531893a0ecf516dc9ddafc58cff/triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
ptxas info : 0 bytes gmem
ptxas info : Compiling entry function 'triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2' for 'sm_90a'
ptxas info : Function properties for triton_red_fused__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers
[~/triton (45fff310)]$ ptxas --gpu-name=sm_90a -v ~//triton-cache/dump2/224ba495c345498737865c572e18a1c439b5b194a7d53044dc2bf7d82aaac31d/triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2.ptx
ptxas info : 0 bytes gmem
ptxas info : Compiling entry function 'triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2' for 'sm_90a'
ptxas info : Function properties for triton_red_fused__log_softmax__log_softmax_backward_data_nll_loss_backward_nll_loss_forward_2
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 40 registers
Hmm, very interesting. One thing I can think of: the tests mentioned in the issue summary were run on an A100, while you seem to be seeing different behavior on an H100.
Can you try on an A100 to see if you can repro?
The issue seems to arise from how ptxas allocates registers on A100. For k2, Triton produces the same PTX file for both H100 and A100; for whatever reason, ptxas uses fewer registers on A100.
> For whatever reason, ptxas uses fewer registers on A100.
Not sure how far we can go. But what versions of ptxas are you using on the A100 and H100? I'm wondering whether it's due to different ptxas versions or to the different GPUs.
I ran with the same ptxas version (V12.1.105) on both H100 and A100. Setting --gpu-name=sm_80 reduces register usage compared to setting --gpu-name=sm_90 for the same k2 bs=2048 PTX file.
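Since the same ptxas binary is used on both machines, a quick way to confirm that the target alone changes allocation is to push one cached PTX file through ptxas for both targets. A hypothetical helper (the script and its argument handling are mine, not from the thread):

# Hypothetical helper: run one cached .ptx through ptxas for both the
# A100 (sm_80) and H100 (sm_90) targets and report "Used N registers".
import re
import subprocess
import sys

ptx_path = sys.argv[1]  # e.g. the k2 .ptx file from TRITON_CACHE_DIR
for arch in ("sm_80", "sm_90"):
    result = subprocess.run(
        ["ptxas", f"--gpu-name={arch}", "-v", ptx_path, "-o", "/dev/null"],
        capture_output=True, text=True, check=True,
    )
    # ptxas -v prints its resource-usage summary to stderr.
    match = re.search(r"Used (\d+) registers", result.stderr)
    print(arch, f"{match.group(1)} registers" if match else "no register info")

With the ptxas build and the input PTX held identical, any difference between the two printed lines comes purely from the --gpu-name target.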