I have a question about the rotary positional encoding (RoPE) part of the code, specifically the two versions of `rotate_as_if_first` below.
def rotate_as_if_first(x, rotary_emb):
    """Rotate every token of ``x`` as though it were first in the sequence.

    ``x`` is laid out as [bs, num_attention_heads, seq_len, head_size].
    Rotary cos/sin tables are requested for ``seq_len`` positions. Then
    ``rotate_one`` receives an all-zero position-id matrix of shape
    [bs, seq_len], so each element is rotated with the position-0 angle.

    Returns:
        The result of ``rotate_one`` for the all-zero position ids.
    """
    batch_size, seq_len = x.shape[0], x.shape[-2]
    cos, sin = rotary_emb(x, seq_len)
    # Every row and column is position 0. The ids are placed on the device
    # of the rotary tables so that rotate_one can index them directly.
    first_position = torch.zeros(
        batch_size, seq_len, dtype=torch.long, device=cos.device
    )
    return rotate_one(x, cos, sin, first_position)
def rotate_as_if_first(x, rotary_emb, position_ids=None):
    """Apply rotary position embedding to ``x`` at the given positions.

    An earlier definition of this name in the file takes only
    ``(x, rotary_emb)``. Python keeps only this later definition, so
    ``position_ids`` is optional: leaving it as ``None`` reproduces the old
    behaviour and keeps two-argument calls working.

    Args:
        x: Tensor shaped [bs, num_attention_heads, seq_len, head_size].
        rotary_emb: Callable invoked as ``rotary_emb(x, n)``. It returns a
            ``(cos, sin)`` pair that is passed on to ``rotate_one``.
        position_ids: Integer position of each token, presumably shaped
            [bs, seq_len] (TODO confirm against ``rotate_one``). If
            ``None``, all positions are 0, i.e. every element is rotated
            as if it were first in the sequence.

    Returns:
        The result of ``rotate_one(x, cos, sin, position_ids)``.
    """
    seq_len = x.shape[-2]

    if position_ids is None:
        cos, sin = rotary_emb(x, seq_len)
        position_ids = torch.zeros(
            x.shape[0], seq_len, dtype=torch.long, device=cos.device
        )
        return rotate_one(x, cos, sin, position_ids)

    # NOTE(review): assumes rotary_emb(x, n) returns cos/sin tables for
    # positions 0..n-1 and that rotate_one gathers rows by position_ids.
    # Caller-supplied ids may exceed x.shape[-2] (for example, offset
    # positions), so request enough rows to cover the largest id. Otherwise
    # the gather could index past the table. int() forces a host sync.
    if position_ids.numel() > 0:
        seq_len = max(seq_len, int(position_ids.max()) + 1)
    cos, sin = rotary_emb(x, seq_len)
    return rotate_one(x, cos, sin, position_ids)