Dear Devansh,
Thanks so much for releasing the code. I have been trying to reproduce the machine translation results, but I ran into some issues. Could you provide a more detailed procedure for preprocessing the data? I assumed such a procedure existed somewhere in the repo. In addition, I hit the following error when running https://github.com/devansh20la/LPF-SGD/tree/master/codes#training-2:
```
Traceback (most recent call last):
  File "/mnt/sda/litao/LPF-SGD/codes/machine_translation/_adamtrain.py", line 174, in <module>
    inputs = translate(model, inputs, bpe_model, device)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/sda/litao/LPF-SGD/codes/machine_translation/utils/train_utils.py", line 117, in translate
    outputs = greedy_decode(model, inputs, inputs_mask, max_len=num_tokens + 10, device=device).flatten()
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/sda/litao/LPF-SGD/codes/machine_translation/utils/train_utils.py", line 91, in greedy_decode
    out = model.decode(ys, memory, tgt_mask)
  File "/mnt/sda/litao/LPF-SGD/codes/machine_translation/models/seq2seq.py", line 84, in decode
    return self.transformer.decoder(self.positional_encoding(
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 369, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 717, in forward
    x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 735, in _mha_block
    x = self.multihead_attn(x, mem, mem,
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1205, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 5279, in multi_head_attention_forward
    k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
RuntimeError: shape '[21, 8, 64]' is invalid for input of size 225792
```
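In case it helps narrow things down, here is the arithmetic I did on the numbers in the RuntimeError. My reading is that the `memory` tensor reaching cross-attention has an unexpected batch dimension, but that interpretation is only a guess:

```python
# The failing call is k.view(21, 8, 64), which expects
# seq_len * num_heads * head_dim elements in k:
expected = 21 * 8 * 64   # = 10752, with num_heads = 8 and head_dim = 64
assert expected == 10752

# But k actually holds 225792 elements, and 225792 = 21 * 21 * 512
# (embed_dim = num_heads * head_dim = 512). So the key/memory tensor
# seems to have shape (21, 21, 512) where (21, 1, 512) would match
# the inferred batch size of 1 (bsz * num_heads = 8 => bsz = 1).
actual = 225792
assert actual == 21 * 21 * 512

# The batch dimension is exactly 21x too large:
assert actual // expected == 21
```

So it looks as if `memory` is being passed to `greedy_decode` / `model.decode` with the sequence length duplicated into the batch dimension, which could be a tensor-layout (e.g. `batch_first`) difference between the torch version the repo was written for and the one I am running (2.x, judging by `_contextlib.py` in the trace). Does that sound plausible to you?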
Could you kindly help me with that?
Many thanks,
Tao