abdulfatir commented on July 2, 2024

Related issue: #116333

from pytorch.

drisspg commented on July 2, 2024

This issue is a little different: we don't currently pad the last dim automatically (maybe we should). However, this should work for you:

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend


query = torch.tensor([[[[1, 2]]]], dtype=torch.float32)
query = query.transpose(-1, -2).to("cuda")
key = torch.tensor([[[[1]]]], dtype=torch.float32).to("cuda")
value = torch.tensor([[[[1]]]], dtype=torch.float32).to("cuda")


with sdpa_kernel(SDPBackend.MATH):
    out_unpadded = scaled_dot_product_attention(query, key, value)  # Works fine

# Zero-pad the last (head) dim up to 4
q_padding_amount = 4 - query.size(-1)
kv_padding_amount = 4 - key.size(-1)
query_padded = torch.nn.functional.pad(query, (0, q_padding_amount))
key_padded = torch.nn.functional.pad(key, (0, kv_padding_amount))
value_padded = torch.nn.functional.pad(value, (0, kv_padding_amount))

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    # Padding changes the head dim, so pass the unpadded default scale explicitly
    out_padded = scaled_dot_product_attention(
        query_padded, key_padded, value_padded, scale=query.size(-1) ** -0.5
    )

# The padded output columns are zero; slice back to the value head dim
out_sliced = out_padded[..., : value.size(-1)]

torch.testing.assert_close(out_unpadded, out_sliced)

abdulfatir commented on July 2, 2024

Thanks @drisspg!

Sorry, I may have mixed something up while reading the other issue. Here's the actual problem I'm facing: the attention mask does not have stride 1 in its last dimension, and calling .contiguous() does not fix it.

import torch

query = torch.ones(2, 16, 1, 64, device="cuda:0")
key = torch.ones(2, 16, 1, 64, device="cuda:0")
values = torch.ones(2, 16, 1, 64, device="cuda:0")
attn_mask = torch.ones(1, 1, 16, device="cuda:0")
attn_mask = attn_mask.permute([2, 0, 1]).unsqueeze(0).contiguous()

with torch.backends.cuda.sdp_kernel(
    enable_math=False, enable_flash=False, enable_mem_efficient=True
):  # works when you set enable_math=True
    torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        values,
        attn_mask=attn_mask,
    )

I am facing this issue when adding SDPA support to T5 models in transformers: huggingface/transformers#30375 (comment)
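To see why .contiguous() does not help, the mask construction from the repro can be inspected on CPU. This is a small sketch that is not part of the original report. It relies on PyTorch's contiguity check ignoring the strides of size-1 dimensions, so the permuted mask already counts as contiguous and .contiguous() hands back the same tensor with the same strides.

```python
import torch

# Same mask construction as in the repro above, but on CPU.
attn_mask = torch.ones(1, 1, 16).permute([2, 0, 1]).unsqueeze(0)
print(attn_mask.shape)            # torch.Size([1, 16, 1, 1])
print(attn_mask.stride())         # (16, 1, 16, 16): last-dim stride is 16, not 1
print(attn_mask.is_contiguous())  # True, because size-1 dims are skipped in the check

# .contiguous() therefore returns the same storage with the same strides.
same = attn_mask.contiguous()
print(same.data_ptr() == attn_mask.data_ptr(), same.stride())  # True (16, 1, 16, 16)

# An explicit copy into freshly allocated contiguous memory resets the strides.
fixed = torch.empty_like(attn_mask, memory_format=torch.contiguous_format).copy_(attn_mask)
print(fixed.stride())             # (16, 1, 1, 1)
```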

abdulfatir commented on July 2, 2024

@drisspg any thoughts on the issue above?

drisspg commented on July 2, 2024

A workaround:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

query = torch.ones(2, 16, 1, 64, device="cuda:0")
key = torch.ones(2, 16, 1, 64, device="cuda:0")
values = torch.ones(2, 16, 1, 64, device="cuda:0")
attn_mask = torch.ones(1, 1, 16, device="cuda:0")
attn_mask = attn_mask.permute([2, 0, 1]).unsqueeze(0)

# .contiguous() is a no-op for this mask, so force a copy with standard strides
if attn_mask.stride(-1) != 1:
    attn_mask = torch.empty_like(attn_mask, memory_format=torch.contiguous_format).copy_(attn_mask)

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        values,
        attn_mask=attn_mask,
    )

We don't do this by default because the copy can cause an unexpected memory spike, depending on the size of the attn mask. We concluded that it is better for users to do it explicitly.
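To make the memory trade-off concrete, here is a small sketch (not from the thread) that wraps the same copy in a conditional helper and applies it to a mask built with expand. The helper names `ensure_last_dim_stride_one` and `storage_bytes`, and the shapes, are hypothetical and chosen only for illustration. An expanded view stores only the original elements, but the copy allocates one element per entry of the full broadcast shape.

```python
import torch


def ensure_last_dim_stride_one(mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: copy `mask` only if its last-dim stride is not 1."""
    if mask.stride(-1) == 1:
        return mask  # no copy, no extra memory
    # Same copy as in the workaround: allocates mask.numel() fresh elements.
    return torch.empty_like(mask, memory_format=torch.contiguous_format).copy_(mask)


def storage_bytes(t: torch.Tensor) -> int:
    """Bytes of the storage backing `t` (shared by all views of it)."""
    return t.untyped_storage().nbytes()


# Illustrative shapes: a per-query mask broadcast over heads and keys with expand.
batch, heads, q_len, kv_len = 4, 8, 512, 512
mask = torch.ones(batch, 1, q_len, 1, dtype=torch.bool).expand(batch, heads, q_len, kv_len)
print(mask.stride())         # (512, 0, 1, 0): expanded dims get stride 0
print(storage_bytes(mask))   # 2048: only batch * q_len bools are stored

fixed = ensure_last_dim_stride_one(mask)
print(fixed.stride(-1))      # 1
print(storage_bytes(fixed))  # 8388608: one byte per element of (4, 8, 512, 512)
```

For a mask as small as the one in the repro (16 elements) the copy is negligible; the cost grows with batch × heads × q_len × kv_len.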

abdulfatir commented on July 2, 2024

Got it, thanks!
