Comments (6)
Related issue: #116333
This issue is a little different: we don't currently do the last-dim padding automatically (maybe we should), but the following should work for you:
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend

query = torch.tensor([[[[1, 2]]]], dtype=torch.float32)
query = query.transpose(-1, -2).to("cuda")
key = torch.tensor([[[[1]]]], dtype=torch.float32).to("cuda")
value = torch.tensor([[[[1]]]], dtype=torch.float32).to("cuda")

with sdpa_kernel(SDPBackend.MATH):
    out_unpadded = scaled_dot_product_attention(query, key, value)  # Works fine

# Zero-pad the head dim up to 4 so the efficient kernel accepts it
q_padding_amount = 4 - query.size(-1)
kv_padding_amount = 4 - key.size(-1)
query_padded = torch.nn.functional.pad(query, (0, q_padding_amount))
key_padded = torch.nn.functional.pad(key, (0, kv_padding_amount))
value_padded = torch.nn.functional.pad(value, (0, kv_padding_amount))

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out_padded = scaled_dot_product_attention(query_padded, key_padded, value_padded)  # Works after padding

# Slice the padding back off and check against the unpadded result
out_sliced = out_padded[..., : query.size(-1)]
torch.testing.assert_close(out_unpadded, out_sliced)
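One subtlety: zero-padding the head dim changes SDPA's default softmax scale from 1/sqrt(head_dim) to 1/sqrt(padded_dim). It happens not to matter above because the key/value length is 1, so the softmax is trivially 1. Here is a generic sketch of the pattern that pins the scale to the original head dim (the helper name sdpa_padded and the multiple-of-8 target are illustrative assumptions, not a PyTorch API):

import torch
import torch.nn.functional as F

def sdpa_padded(query, key, value, pad_to=8, **kwargs):
    # Illustrative sketch: assumes query/key/value share the same head dim.
    head_dim = query.size(-1)
    # Pin the softmax scale to the original head dim; otherwise the default
    # 1/sqrt(padded_dim) would change the attention weights.
    kwargs.setdefault("scale", head_dim ** -0.5)
    pad = (-head_dim) % pad_to
    if pad == 0:
        return F.scaled_dot_product_attention(query, key, value, **kwargs)
    out = F.scaled_dot_product_attention(
        F.pad(query, (0, pad)),  # zero channels contribute 0 to q.k scores
        F.pad(key, (0, pad)),
        F.pad(value, (0, pad)),  # zero channels become output padding
        **kwargs,
    )
    return out[..., :head_dim]  # slice the padded channels back off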
Thanks @drisspg!
Sorry, I may have mixed something up while reading the other issue. Here's the actual problem I am facing: it relates to tensors not having stride 1 in the last dimension, which does not get fixed with .contiguous().
import torch

query = torch.ones(2, 16, 1, 64, device="cuda:0")
key = torch.ones(2, 16, 1, 64, device="cuda:0")
values = torch.ones(2, 16, 1, 64, device="cuda:0")
attn_mask = torch.ones(1, 1, 16, device="cuda:0")
attn_mask = attn_mask.permute([2, 0, 1]).unsqueeze(0).contiguous()

with torch.backends.cuda.sdp_kernel(
    enable_math=False, enable_flash=False, enable_mem_efficient=True
):  # works when you set enable_math=True
    torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        values,
        attn_mask=attn_mask,
    )
I am facing this issue when adding SDPA support to T5 models in transformers: huggingface/transformers#30375 (comment)
@drisspg any thoughts on the issue above?
A workaround:
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

query = torch.ones(2, 16, 1, 64, device="cuda:0")
key = torch.ones(2, 16, 1, 64, device="cuda:0")
values = torch.ones(2, 16, 1, 64, device="cuda:0")
attn_mask = torch.ones(1, 1, 16, device="cuda:0")
attn_mask = attn_mask.permute([2, 0, 1]).unsqueeze(0)

# Force a dense copy so the mask's last dim has stride 1
if attn_mask.stride(-1) != 1:
    attn_mask = torch.empty_like(attn_mask, memory_format=torch.contiguous_format).copy_(attn_mask)

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        values,
        attn_mask=attn_mask,
    )
The reason we don't enable this by default is that it can cause an unexpected memory spike, depending on the size of the attn mask, so we came to the conclusion that it is better for users to do this explicitly.
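To make the memory concern concrete, a sketch with made-up sizes (the numbers are assumptions, not from this issue): a broadcast mask kept as an expanded view costs almost nothing, while materializing a dense copy allocates the full (batch, heads, q_len, kv_len) buffer.

import torch

# Illustrative sizes, chosen only to show the scaling
batch, heads, q_len, kv_len = 8, 16, 4096, 4096

# Expanded view: broadcast dims have stride 0, so no extra memory is used
mask = torch.zeros(1, 1, q_len, kv_len).expand(batch, heads, q_len, kv_len)
print(mask.stride())  # (0, 0, 4096, 1)

# A dense copy allocates batch * heads * q_len * kv_len floats
dense = mask.contiguous()
print(dense.numel() * dense.element_size() / 2**30, "GiB")  # 8.0 GiB here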
Got it, thanks!