Comments (17)
@manman-ren If there are any ways to rewrite the kernel to improve perf we'd also be interested in that.
from pytorch.
Looking at the NCU profiles, I'm seeing that the thread-occupancy metrics are very close for the two kernels, so I guess register pressure doesn't matter much here.
With the masking we have extra data movement, but it is all in registers. I am seeing different memory throughput, though.
I'm also seeing that the instruction execution counts are quite different:
By disabling LSR and moving the arange instructions
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
we can get to 0.007877.
bad.py --> bad2.py --> bad4.py + disable LSR
0.009078 --> 0.008732 --> 0.007877 (-13% overall)
Actually, if we only disable LSR (keeping the actual masks), we get:
FA2: 6.420323848724365
templated attention: 7.937378406524658
With LSR enabled:
FA2: 6.392755508422852
templated attention: 9.197888374328613
Looks like LSR plays a significant role when the register usage is high.
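For intuition on why LSR can hurt here: it strength-reduces `base + i * stride` addressing into an extra running pointer, which removes a multiply per iteration but keeps one more induction variable live. A hypothetical plain-Python sketch of the trade-off (not actual LLVM output):

```python
# Strength reduction, in spirit: trade a per-iteration multiply for an
# extra live induction variable. With many strengthened addresses, those
# extra live values add register pressure in an already-tight kernel.
def sum_strided(a, base, stride, n):
    total = 0
    for i in range(n):
        total += a[base + i * stride]   # mul + add each iteration
    return total

def sum_strided_lsr(a, base, stride, n):
    p = base                            # extra induction variable stays live
    total = 0
    for _ in range(n):
        total += a[p]
        p += stride                     # strength-reduced: add only
    return total
```

Both return the same result; the only difference is how the address is formed each iteration, which is exactly the kind of choice that matters when the kernel is already at 255 registers.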
Patch to disable LSR:
diff --git a/python/src/llvm.cc b/python/src/llvm.cc
index 3b7f8fa30..a21ee2860 100644
--- a/python/src/llvm.cc
+++ b/python/src/llvm.cc
@@ -55,6 +55,14 @@ std::string translateLLVMIRToASM(llvm::Module &module,
*optPtr = true;
}
}
+ {
+ auto optIt = options.find("disable-lsr");
+ if (optIt != options.end()) {
+ auto optPtr = static_cast<llvm::cl::opt<bool> *>(optIt->second);
+ *optPtr = true;
+ llvm::errs() << "disable LSR ----------------\n";
+ }
+ }
// inline everything
for (llvm::Function &f : module.functions())
We will take a look!
Which version of pytorch should I use?
I tried master, but hit
return torch.where(patch_bidirectional | causal_mask, score, -float("inf"))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: where() received an invalid combination of arguments - got (bool, Tensor, float), but expected one of:
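For what it's worth, this error usually means the condition ended up as a plain Python bool instead of a Tensor; a minimal repro/fix sketch (hypothetical shapes, not the actual script):

```python
import torch

score = torch.arange(9, dtype=torch.float32).reshape(3, 3)
causal_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool))  # Tensor mask
patch_bidirectional = torch.tensor(False)  # 0-d bool Tensor, broadcasts

# Works: the condition is a Tensor, so the (Tensor, Tensor, float)
# overload of torch.where applies
out = torch.where(patch_bidirectional | causal_mask, score, -float("inf"))

# Raises the TypeError above: the condition is a plain Python bool
# torch.where(False, score, -float("inf"))
```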
@manman-ren sorry, my script was a little bit wrong - I tried to manually patch in some comments but forgot to actually comment them out. I've updated the script.
Can repro this on A100:
FA2: 6.392755508422852
templated attention: 9.197888374328613
With patch that sets causal_mask to True:
FA2: 6.416097640991211
templated attention: 7.2350616455078125
I don't see any obvious issue with the optimizations. For the two ttgir files attached in the description, the only difference I notice is that a mulf is done outside the loop in the good case and inside the loop in the bad case. But when I repro locally, this mulf doesn't exist in either version. The remaining differences are just calculating the mask and using it.
When looking at the NCU profiles, the bad case has more stalls.
The register usage also increases from 224 to 255, but the occupancy is similar. So this looks like an issue with register pressure and scheduling? By scheduling, I mean the possibility of moving the mask calculation so that it is not blocking.
Comparing the stalls for good vs. bad:
It seems this kernel doesn't have autotuning; it performs 2 dot operations with [m, n, k] of [128, 64, 64].
@htyu We can take a look at the profiles together and see if you can spot something that I may have missed :]
Let me see if I can modify the source to make it faster.
We can try to move the mask calculation ahead:
diff bad.py bad2.py
--- bad.py 2024-04-18 16:36:19.578471966 -0700
+++ bad2.py 2024-04-18 16:35:14.303964448 -0700
@@ -131,8 +131,18 @@
# loop over k, v and update accumulator
lo = 0
hi = N_CTX
+ tmp0 = tl.full([1], 1024, tl.int64)
+ tmp1 = (offs_m[:, None]) <= tmp0 # BLOCK_M, 1
+
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
+ tmp2 = (start_n + offs_n[None, :]) <= tmp0 # 1, BLOCK_N
+ tmp3 = tmp1 & tmp2
+ tmp4 = (offs_m[:, None]) >= (start_n + offs_n[None, :])
+ tmp5 = tmp3 | tmp4
+ tmp6 = float("-inf")
+ tmp7 = tmp6.to(tl.float32)
+
# -- load k, v --
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
@@ -141,14 +151,14 @@
qk += tl.dot(q, k.to(MATMUL_PRECISION))
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
- tmp0 = tl.full([1], 1024, tl.int64)
- tmp1 = (offs_m[:, None]) <= tmp0
- tmp2 = (start_n + offs_n[None, :]) <= tmp0
- tmp3 = tmp1 & tmp2
- tmp4 = (offs_m[:, None]) >= (start_n + offs_n[None, :])
- tmp5 = tmp3 | tmp4
- tmp6 = float("-inf")
- tmp7 = tmp6.to(tl.float32)
+ #tmp0 = tl.full([1], 1024, tl.int64)
+ #tmp1 = (offs_m[:, None]) <= tmp0
+ #tmp2 = (start_n + offs_n[None, :]) <= tmp0 # 1, BLOCK_N
+ #tmp3 = tmp1 & tmp2
+ #tmp4 = (offs_m[:, None]) >= (start_n + offs_n[None, :])
+ #tmp5 = tmp3 | tmp4
+ #tmp6 = float("-inf")
+ #tmp7 = tmp6.to(tl.float32)
tmp8 = tl.where(tmp5, (qk), tmp7)
qk = tmp8
Seems to be a bit faster:
CUDA_VISIBLE_DEVICES=6 python3 bad2.py
0.008732
CUDA_VISIBLE_DEVICES=6 python3 bad.py
0.009078
We can also try to rematerialize address calculations to see if that affects register pressure. You are already using tl.make_block_ptr, which has some weird interactions with register pressure (in most cases, a good interaction).
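To illustrate what rematerialization buys: instead of computing offsets once and keeping them live across the whole loop, recompute them at each use site. A plain-Python sketch of the idea (hypothetical, not Triton semantics):

```python
# Keeping the offset vector live across the loop costs registers for its
# whole lifetime; rematerializing recomputes it at the use site instead,
# trading a little arithmetic for shorter live ranges.
def add_offsets_live(xs, stride):
    offs = [i * stride for i in range(len(xs))]  # computed once, lives long
    return [x + offs[i] for i, x in enumerate(xs)]

def add_offsets_remat(xs, stride):
    # i * stride is recomputed right where it is needed
    return [x + i * stride for i, x in enumerate(xs)]
```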
> By disabling LSR and moving the arange instructions we can get to 0.007877.
Nice find! What does the PTX code and the profile look like after disabling LSR? Wondering how it affects code quality.
@manman-ren btw, I merged this PR today: #124356, which improves the baseline numbers. There is still a significant gap, but this might change some of the analysis.
I reran the tests (good vs. bad) after rebasing pytorch to latest
CUDA_VISIBLE_DEVICES=6 python bad.py
disable LSR ----------------
FA2: 6.40308141708374
templated attention: 7.860274791717529
Do not disable LSR:
FA2: 6.46448278427124
templated attention: 7.766917705535889
good.py (mask removed), do not disable LSR:
FA2: 6.471131801605225
templated attention: 6.987882137298584
The gap is much smaller now: 6.987882137298584 vs 7.766917705535889
And the impact of disabling LSR went from -17% to +1%.
I will try to figure out why LSR sometimes has such a big negative perf impact.
About this change in the template:
- acc += tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION))
+ acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)
it shouldn't change anything for bf16 or fp16, right? @htyu
> About this change in the template:
> - acc += tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION))
> + acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)
> it shouldn't change anything for bf16 or fp16, right? @htyu
Embedding acc in the dot operation, which is basically an FMA, actually loses precision. We can think of it as a fast-math mode.
> Embedding acc in dot operation, which is basically an FMA, loses precision actually.
Loses precision or gains precision? FMA usually gains precision, no?
> Embedding acc in dot operation, which is basically an FMA, loses precision actually.
> Loses precision or gains precision? FMA usually gains precision, no?
IIUC, FMA loses precision compared to separate mul and add instructions.
@htyu I don't think that's true. See https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation:
> Fused multiply–add can usually be relied on to give more accurate results.
It can certainly cause issues at times (see https://dev-discuss.pytorch.org/t/fmas-and-softmax-and-floating-point-considered-harmful/1953), but specifically, the instruction itself should lead to "increased precision".
> @htyu I don't think that's true. See https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation
> Fused multiply–add can usually be relied on to give more accurate results.
> It can certainly cause issues at times (see https://dev-discuss.pytorch.org/t/fmas-and-softmax-and-floating-point-considered-harmful/1953), but specifically, the instruction itself should lead to "increased precision".
You are right on FMA increasing accuracy in general, as it avoids one extra rounding.
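The one-rounding-vs-two point can be checked numerically. Below, the single final rounding of an FMA is emulated with exact rational arithmetic (a sketch; `Fraction` stands in for the hardware's exact intermediate product):

```python
from fractions import Fraction

e = 2.0 ** -52
a, b, c = 1.0 + e, 1.0 - e, -1.0     # a * b is exactly 1 - e**2

# Separate mul + add: the product 1 - 2**-104 rounds to 1.0 first,
# so the addition cancels to exactly 0.0
separate = a * b + c

# FMA-style: keep the exact product and round only once at the end,
# preserving the tiny -2**-104 term the two-rounding version lost
fused = float(Fraction(a) * Fraction(b) + Fraction(c))
```

So the fused form keeps the term the separate mul/add loses, matching the "one fewer rounding" argument.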
It also looks like it depends on the hardware implementation. Found some code pointers related to the fp8 dot:
https://github.com/openai/triton/blob/792315bb76211ae0dcef535f426d531ebc0fe2ff/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp#L434
- acc += tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION))
+ acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)
As for the codegen: without acc, separate fadd instructions will be generated.