Comments (9)
Adding oncall: pt2; perhaps we can chat a bit more about Volta coverage for torch.compile...
@RobertCsordas have you filed an issue against Triton? (As it does not have much to do with PyTorch.)
Triage review to discuss Volta coverage for torch.compile; this issue implies that it might not work.
Hi! @malfet: no, I haven't. It is unclear to me who is maintaining the PyTorch Triton version, as it seems to be different from the current official upstream. But I will open an issue there as well.
This example worked perfectly with 2.1. There was even a 2.2 dev version where it worked with my custom Triton kernel together with torch.compile().
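For context, the setup is roughly this shape: a custom @triton.jit kernel called from a Python wrapper, with the wrapper used inside a torch.compile'd function. A minimal self-contained sketch (with a trivial elementwise kernel standing in for the real one, so it does not itself exercise the broken tt.dot path):

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, out_ptr, n, scale, BLOCK: tl.constexpr):
    # Trivial stand-in kernel: out = x * scale.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * scale, mask=mask)

def scale(x, s):
    out = torch.empty_like(x)
    n = x.numel()
    scale_kernel[(triton.cdiv(n, 1024),)](x, out, n, s, BLOCK=1024)
    return out

@torch.compile
def f(x):
    return scale(x, 2.0) + 1.0

x = torch.randn(4096, device="cuda")
print(torch.allclose(f(x), x * 2.0 + 1.0))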
cc @bertmaher
OK, I've tried it on Colab and cannot reproduce the failure using 2.3 and the verbatim example from https://github.com/triton-lang/triton/blob/release/2.3.x/python/tutorials/03-matrix-multiplication.py
On T4: https://colab.research.google.com/drive/1GnTzDW7aKWofIYjKE0X5XgkzBETO3KhT?usp=sharing
Interesting. I'm able to reproduce this on 4 different machines (every single machine I tested on). I'm running drivers 530.41.03 with CUDA 12.1.1, 535.104.12 with CUDA 12.2.140, 535.113.01 with CUDA 11.5, and 535.129.03 with 11.5.117. One of the machines has a Titan V; the rest are V100s, both the 32 GB and 16 GB versions. I started from a clean Python install with just Triton installed, and the situation is identical. The Python versions are 3.10.10 and 3.10.12.
Do you have any suggestions on what I should try or how we can continue debugging?
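For reference, a quick way to collect the relevant toolchain and GPU details on each machine (a minimal sketch, assuming both torch and triton are installed in the environment):

import torch
import triton

print("torch:", torch.__version__)
print("CUDA (torch build):", torch.version.cuda)
print("triton:", triton.__version__)
for i in range(torch.cuda.device_count()):
    name = torch.cuda.get_device_name(i)
    cap = torch.cuda.get_device_capability(i)  # (7, 0) = Volta, (7, 5) = Turing
    print(f"GPU {i}: {name}, compute capability {cap}")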
I simplified the matmul code a bit (removed the leaky ReLU) and left just one config to guarantee equivalence, then dumped the PTX with the working 2.1 and the broken 2.2 Triton. Maybe this can help with debugging. Code: https://pastebin.com/FAL22dH1
Triton 2.1 PTX (working): https://pastebin.com/6E0wiVbb
Triton 2.2 PTX (broken): https://pastebin.com/XMNJgZYB
I don't speak PTX, but to me it looks like the Triton 2.2 PTX is completely missing the code that should use the tensor cores (the 2.1 PTX has a bunch of mma.sync.aligned.m8n8k4 instructions, while the 2.2 one has none). The broken 2.2 code is also significantly shorter.
The Triton 2.3.1 PTX is identical to that of 2.2.
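For what it's worth, the count can be checked mechanically from the dumped PTX; a small sketch, with hypothetical file names for the two dumps:

# Count tensor-core MMA instructions in the dumped PTX files (hypothetical names).
for path in ("matmul_2.1.ptx", "matmul_2.2.ptx"):
    with open(path) as f:
        print(path, f.read().count("mma.sync.aligned"))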
EDIT: updated the code to dump the TTIR and TTGIR as well (a sketch of the dumping approach follows the links below): https://pastebin.com/3TtEEPiG
2.1 TTIR: https://pastebin.com/Lz6r02Ft
2.1 TTGIR: https://pastebin.com/0ukha35H
2.1 LLIR: https://pastebin.com/xzZSDs06
2.2 TTIR: https://pastebin.com/06mP1j1j
2.2 TTGIR: https://pastebin.com/k5FfpT7K
2.2 LLIR: https://pastebin.com/FAL22dH1
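For anyone who wants to reproduce these dumps, a minimal sketch of the approach, assuming the Triton 2.x behavior where launching a @triton.jit kernel returns a compiled-kernel handle whose .asm dict maps stage names to the generated code (the tiny tl.dot kernel below is only a placeholder for the tutorial matmul):

import torch
import triton
import triton.language as tl

@triton.jit
def tiny_dot(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    # Minimal BLOCK x BLOCK matmul so the IR contains a tt.dot.
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    c = tl.dot(a, b)  # accumulates in fp32
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)

a = torch.randn((16, 16), device="cuda", dtype=torch.float16)
b = torch.randn((16, 16), device="cuda", dtype=torch.float16)
c = torch.empty((16, 16), device="cuda", dtype=torch.float32)

compiled = tiny_dot[(1,)](a, b, c, BLOCK=16)
for stage in ("ttir", "ttgir", "llir", "ptx"):
    with open(f"dump.{stage}", "w") as f:
        f.write(compiled.asm[stage])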
The TTIRs seem identical, except for the register numbers and two instructions:
2.2 has:
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = arith.select %9, %8, %c8_i32 : i32
while 2.1 has:
%9 = arith.minsi %8, %c8_i32 : i32
The ordering and register numbers in the TTGIR are different, but the general gist seems similar (the cmpi/select pair is just an expanded form of minsi, so the two TTIRs compute the same thing). Whatever drops the tt.dot seems to happen after these IRs.
The 2.2 LLIR doesn't have any mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 instructions, while the 2.1 LLIR has them.
EDIT 2:
I ran the unit tests (branch release/2.2.x from GitHub); they fail after some time with a segmentation fault, but the results collected up to that point are here: https://pastebin.com/PSHDbUs0
The GEMM tests tend to fail.
The results of the lit tests: https://pastebin.com/MrNRwyqT
I can't run the Ninja tests because they look for CMake in /tmp and can't find it, and I haven't figured out how to fix that yet.
EDIT 3: added missing LLIRs
By the way, the success on the T4 is not a good test, because even though its compute capability is 7.5, it's actually Turing. Triton uses a different dot implementation for Turing and Volta:
It's even a different file: https://github.com/triton-lang/triton/blob/release/2.2.x/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv1.cpp is used only for Volta, while everything else (including Turing) uses https://github.com/triton-lang/triton/blob/release/2.2.x/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp
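A quick way to check which path a given GPU would hit, based on its compute capability (Volta is sm_70, Turing is sm_75; the mapping to the two files follows the description above):

import torch

major, minor = torch.cuda.get_device_capability(0)
if (major, minor) == (7, 0):
    print("Volta (sm_70): MMAv1.cpp dot lowering (V100, Titan V)")
elif (major, minor) >= (7, 5):
    print("Turing or newer (sm_75+): MMAv2.cpp dot lowering (T4, ...)")
else:
    print("pre-Volta: no tensor-core MMA path")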
Actually, the bug is in this line: https://github.com/triton-lang/triton/blob/0e7b97bd47fc4beb21ae960a516cd9a7ae9bc060/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv1.cpp#L74
ALayout should be BLayout. I then found that this is already fixed on main (triton-lang/triton@cc3aed4).
I can't build main because of the following error:
downloading and extracting https://tritonlang.blob.core.windows.net/llvm-builds/llvm-657ec732-ubuntu-x64.tar.gz ...
error: HTTP Error 403: This request is not authorized to perform this operation.
But if I change that character in 2.2, it works just fine.