
Comments (17)

Chillee commented on May 13, 2024

@manman-ren If there are any ways to rewrite the kernel to improve perf we'd also be interested in that.


htyu commented on May 13, 2024

Looking at the NCU profiles, I'm seeing that the thread occupancy metrics are very close for the two kernels, so I guess register pressure doesn't matter much here.

With the masking we have extra data movement, but it is all in registers. Still, I'm seeing that the memory throughput differs.

I'm also seeing that the instruction execution counts are quite different:

good:
[screenshot: instruction execution counts, good kernel]

bad:
[screenshot: instruction execution counts, bad kernel]


manman-ren commented on May 13, 2024

By disabling LSR and moving the arange instructions

offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)

we can get to 0.007877.

bad.py --> bad2.py --> bad4.py + disable LSR
0.009078 --> 0.008732 --> 0.007877 (about -13% overall)

Actually, if we only disable LSR (keeping the actual masks), we get:
FA2: 6.420323848724365
templated attention: 7.937378406524658

With LSR enabled:
FA2: 6.392755508422852
templated attention: 9.197888374328613

Looks like LSR plays a significant role when the register usage is high.

Patch to disable LSR:

diff --git a/python/src/llvm.cc b/python/src/llvm.cc
index 3b7f8fa30..a21ee2860 100644
--- a/python/src/llvm.cc
+++ b/python/src/llvm.cc
@@ -55,6 +55,14 @@ std::string translateLLVMIRToASM(llvm::Module &module,
       *optPtr = true;
     }
   }
+  {
+      auto optIt = options.find("disable-lsr");
+      if (optIt != options.end()) {
+        auto optPtr = static_cast<llvm::cl::opt<bool> *>(optIt->second);
+        *optPtr = true;
+        llvm::errs() << "disable LSR ----------------\n";
+      }
+  }
 
   // inline everything
   for (llvm::Function &f : module.functions())


manman-ren commented on May 13, 2024

We will take a look!


manman-ren commented on May 13, 2024

Which version of pytorch should I use?
I tried master, but hit
return torch.where(patch_bidirectional | causal_mask, score, -float("inf"))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: where() received an invalid combination of arguments - got (bool, Tensor, float), but expected one of:


Chillee commented on May 13, 2024

@manman-ren Sorry, my script was a little bit wrong - I tried to manually patch in some comments but forgot to actually comment them out. Updated the script.


manman-ren commented on May 13, 2024

Can repro this on A100:
FA2: 6.392755508422852
templated attention: 9.197888374328613

With patch that sets causal_mask to True:
FA2: 6.416097640991211
templated attention: 7.2350616455078125

I don't see any obvious issue with the optimizations. For the two ttgir files attached in the description, the only thing I notice is that a mulf is done outside the loop in the good case and inside the loop in the bad case. But when I repro locally, this mulf doesn't exist in either version. The other differences are just calculating the mask and using it.

When looking at ncu profiles, the bad case has more stalls.
[screenshot: warp stall breakdown]

The register usage also increases from 224 to 255, but the occupancy is similar. So this looks like an issue with register pressure and scheduling? By scheduling, I mean the possibility of moving the mask calculation so that it is not blocking.

Comparing the stalls for good vs. bad:
[screenshot: stall breakdown, good kernel]
[screenshot: stall breakdown, bad kernel]

It seems this kernel doesn't have autotuning; it performs 2 dot operations with [m, n, k] of [128, 64, 64]. (A toy sketch of the Triton autotuning mechanism is included at the end of this comment.)
@htyu We can take a look at the profiles together and see if you can spot something that I may have missed :]
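
For reference, here is a toy sketch of what Triton-level autotuning looks like. This is not the real attention template (which is generated by inductor, so its tuning knobs live elsewhere); the kernel name, block sizes, and key below are made-up assumptions, purely to illustrate the @triton.autotune mechanism, and it assumes the matrix dimensions are multiples of the block sizes.

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=4),
    ],
    key=["N_CTX"],  # re-tune whenever the sequence length changes
)
@triton.jit
def toy_row_sum(X, Out, N_CTX, stride_xm,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Each program sums BLOCK_M rows of an (M, N_CTX) matrix, walking the
    # columns in BLOCK_N-wide tiles, much like the attention loop over K/V.
    # No masking, for brevity: assumes divisible dimensions.
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    acc = tl.zeros([BLOCK_M], dtype=tl.float32)
    for start_n in range(0, N_CTX, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        x = tl.load(X + offs_m[:, None] * stride_xm + offs_n[None, :]).to(tl.float32)
        acc += tl.sum(x, axis=1)
    tl.store(Out + offs_m, acc)

The launch grid then has to be a function of the selected config, e.g. grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),), since BLOCK_M is only fixed after autotuning.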


manman-ren commented on May 13, 2024

Let me see if I can modify the source to make it faster.


manman-ren commented on May 13, 2024

We can try to move the mask calculation ahead:

diff bad.py bad2.py 
--- bad.py	2024-04-18 16:36:19.578471966 -0700
+++ bad2.py	2024-04-18 16:35:14.303964448 -0700
@@ -131,8 +131,18 @@
     # loop over k, v and update accumulator
     lo = 0
     hi = N_CTX
+    tmp0 = tl.full([1], 1024, tl.int64)
+    tmp1 = (offs_m[:, None]) <= tmp0 # BLOCK_M, 1
+
     for start_n in range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
+        tmp2 = (start_n + offs_n[None, :]) <= tmp0 # 1, BLOCK_N
+        tmp3 = tmp1 & tmp2
+        tmp4 = (offs_m[:, None]) >= (start_n + offs_n[None, :])
+        tmp5 = tmp3 | tmp4
+        tmp6 = float("-inf")
+        tmp7 = tmp6.to(tl.float32)
+
         # -- load k, v --
         k = tl.load(K_block_ptr)
         v = tl.load(V_block_ptr)
@@ -141,14 +151,14 @@
         qk += tl.dot(q, k.to(MATMUL_PRECISION))
         # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
 
-        tmp0 = tl.full([1], 1024, tl.int64)
-        tmp1 = (offs_m[:, None]) <= tmp0
-        tmp2 = (start_n + offs_n[None, :]) <= tmp0
-        tmp3 = tmp1 & tmp2
-        tmp4 = (offs_m[:, None]) >= (start_n + offs_n[None, :])
-        tmp5 = tmp3 | tmp4
-        tmp6 = float("-inf")
-        tmp7 = tmp6.to(tl.float32)
+        #tmp0 = tl.full([1], 1024, tl.int64)
+        #tmp1 = (offs_m[:, None]) <= tmp0
+        #tmp2 = (start_n + offs_n[None, :]) <= tmp0 # 1, BLOCK_N
+        #tmp3 = tmp1 & tmp2
+        #tmp4 = (offs_m[:, None]) >= (start_n + offs_n[None, :])
+        #tmp5 = tmp3 | tmp4
+        #tmp6 = float("-inf")
+        #tmp7 = tmp6.to(tl.float32)
         tmp8 = tl.where(tmp5, (qk), tmp7)
         qk = tmp8

Seems to be a bit faster
CUDA_VISIBLE_DEVICES=6 python3 bad2.py
0.008732
CUDA_VISIBLE_DEVICES=6 python3 bad.py
0.009078
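
For reference, a standalone timing like this can be collected with triton.testing.do_bench (it is not clear that bad.py measures it this way; the SDPA call below is only a stand-in workload for the kernel under test):

import torch
import torch.nn.functional as F
from triton.testing import do_bench

q = torch.randn(16, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Stand-in workload; replace the lambda with a launch of the kernel under test.
ms = do_bench(lambda: F.scaled_dot_product_attention(q, k, v))
print(ms)  # runtime in milliseconds (mean or median, depending on the Triton version)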

We can also try to rematerialize address calculations to see how that affects register pressure. You are already using tl.make_block_ptr, which has some weird interactions with register pressure (in most cases, a good interaction).
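
To make the two addressing styles concrete, here is a minimal sketch using a hypothetical tile-copy kernel rather than the attention template; copy_block_ptr and copy_remat are made-up names, the grid is assumed to be (M // BLOCK_M,), and the dimensions are assumed to be multiples of the block sizes (no boundary checks, for brevity).

import triton
import triton.language as tl

@triton.jit
def copy_block_ptr(X, Y, M, N, stride_m, stride_n,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Block-pointer style: carry one pointer object across the loop and
    # advance it, so the address state stays live in registers.
    x_ptr = tl.make_block_ptr(base=X, shape=(M, N), strides=(stride_m, stride_n),
                              offsets=(tl.program_id(0) * BLOCK_M, 0),
                              block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    y_ptr = tl.make_block_ptr(base=Y, shape=(M, N), strides=(stride_m, stride_n),
                              offsets=(tl.program_id(0) * BLOCK_M, 0),
                              block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    for _ in range(0, N, BLOCK_N):
        tl.store(y_ptr, tl.load(x_ptr))
        x_ptr = tl.advance(x_ptr, (0, BLOCK_N))
        y_ptr = tl.advance(y_ptr, (0, BLOCK_N))

@triton.jit
def copy_remat(X, Y, M, N, stride_m, stride_n,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Rematerialized style: rebuild the full address expression from
    # offs_m/offs_n every iteration instead of carrying live pointers.
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    for start_n in range(0, N, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        ptrs = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        tl.store(Y + ptrs, tl.load(X + ptrs))

Whether carrying pointers or recomputing them is cheaper register-wise depends on how the backend (and passes like LSR) lowers the address updates, so it seems worth trying both.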


htyu commented on May 13, 2024

> By disabling LSR and moving the arange instructions
>
> offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
> offs_n = tl.arange(0, BLOCK_N)
>
> we can get to 0.007877. [...full timings and the LSR patch quoted from the comment above...]

Nice find! What does the PTX code and the profile look like after disabling LSR? Wondering how it affects code quality.


Chillee commented on May 13, 2024

@manman-ren btw, I merged this PR today: #124356, which improves the baseline numbers. There is still a significant gap, but this might change some of the analysis.


manman-ren commented on May 13, 2024

I reran the tests (good vs. bad) after rebasing pytorch to latest
CUDA_VISIBLE_DEVICES=6 python bad.py
disable LSR ----------------
FA2: 6.40308141708374
templated attention: 7.860274791717529

do not disable LSR
FA2: 6.46448278427124
templated attention: 7.766917705535889

good.py (remove mask) do not disable LSR
FA2: 6.471131801605225
templated attention: 6.987882137298584

The gap is much smaller now: 6.987882137298584 vs 7.766917705535889
And the impact of disabling LSR went from -17% to +1%.

I will try to figure out why LSR sometimes has such a big negative perf impact.

About this change in the template

-        acc += tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION))
+       acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)

it shouldn't change anything for bf16 or fp16, right? @htyu


htyu commented on May 13, 2024

> About this change in the template
>
> -        acc += tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION))
> +       acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)
>
> it shouldn't change anything for bf16 or fp16, right? @htyu

Embedding acc in the dot operation, which is basically an FMA, actually loses precision. We can think of it as a fast-math mode.


Chillee commented on May 13, 2024

> Embedding acc in the dot operation, which is basically an FMA, actually loses precision.

Loses precision or gains precision? FMA usually gains precision, no?


htyu commented on May 13, 2024

> > Embedding acc in the dot operation, which is basically an FMA, actually loses precision.
>
> Loses precision or gains precision? FMA usually gains precision, no?

IIUC, FMA loses precision compared to separate mul and add instructions.


Chillee commented on May 13, 2024

@htyu I don't think that's true. See https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation

> Fused multiply–add can usually be relied on to give more accurate results.

It can certainly cause issues at times (see https://dev-discuss.pytorch.org/t/fmas-and-softmax-and-floating-point-considered-harmful/1953), but specifically, the instruction itself should lead to "increased precision".


htyu commented on May 13, 2024

> @htyu I don't think that's true. See https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation
>
> > Fused multiply–add can usually be relied on to give more accurate results.
>
> It can certainly cause issues at times (see https://dev-discuss.pytorch.org/t/fmas-and-softmax-and-floating-point-considered-harmful/1953), but specifically, the instruction itself should lead to "increased precision".

You are right that FMA increases accuracy in general, as it avoids one extra rounding.
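
A tiny numeric illustration of that single-rounding effect (plain numpy; the float64 product emulates what a true float32 FMA would keep before the final rounding):

import numpy as np

a = np.float32(1.0 + 2.0**-23)
b = np.float32(1.0 + 2.0**-23)
c = np.float32(-(1.0 + 2.0**-22))
# Exact product: a * b = 1 + 2**-22 + 2**-46.

# Separate mul then add: the product is rounded to float32 first, which drops
# the 2**-46 tail, and the add then cancels everything to zero.
separate = np.float32(a * b) + c

# An FMA keeps the exact product and rounds only once, after the add. float64
# holds this particular product exactly, so it emulates the fused path.
fused = np.float32(np.float64(a) * np.float64(b) + np.float64(c))

print(separate)  # 0.0
print(fused)     # ~1.42e-14 (= 2**-46), the correctly rounded answer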

It also looks like it depends on the hardware implementation. I found some code pointers related to the fp8 dot:
https://github.com/openai/triton/blob/792315bb76211ae0dcef535f426d531ebc0fe2ff/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp#L434

  • acc += tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION))
  • acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)

As for the codegen, without acc, there will be separate fadd instructions generated.
