Comments (9)

DanFu09 commented on August 23, 2024

from safari.

DanFu09 commented on August 23, 2024

Zymrael commented on August 23, 2024

AFAIK, this is because MATLAB arrays are 1-indexed, which forced many communities working with MATLAB to adopt the fftshift + centered DFT convention. You don't need fftshift in PyTorch code for the DFT result to be correct.
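A quick way to check this, sketched under the assumption of a 1-D float64 signal: compare torch.fft.fft against a naive DFT written straight from the definition. The helper dft_reference is hypothetical, written only for this check. It also shows that fftshift is just a permutation of the output bins, not a correction.

```python
import math

import torch


def dft_reference(x: torch.Tensor) -> torch.Tensor:
    """Naive O(N^2) DFT from the definition X[k] = sum_n x[n] * exp(-2j*pi*k*n/N)."""
    n = x.shape[-1]
    idx = torch.arange(n, dtype=torch.float64)
    # Outer product k * n gives the exponent for every (frequency, sample) pair.
    angle = -2.0 * math.pi * torch.outer(idx, idx) / n
    dft_matrix = torch.polar(torch.ones_like(angle), angle)  # exp(1j * angle), complex128
    return dft_matrix @ x.to(torch.complex128)


x = torch.randn(13, dtype=torch.float64)

# No fftshift/ifftshift needed: torch.fft.fft already matches the definition, with bin 0 = DC.
unshifted_ok = torch.allclose(torch.fft.fft(x), dft_reference(x))

# fftshift only reorders the bins (DC moves to index N // 2); ifftshift puts them back.
shifted = torch.fft.fftshift(torch.fft.fft(x))
roundtrip_ok = torch.allclose(torch.fft.ifftshift(shifted), torch.fft.fft(x))

print(unshifted_ok, roundtrip_ok)
```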

Zymrael commented on August 23, 2024

Thanks for verifying! Could you elaborate on why that would be more desirable? If you don't take the first seqlen elements, your convolution is no longer causal. Padding is just an artifact to turn a circular convolution (for which the FFTConv method holds) into a linear convolution (which is what we want to compute); at the output, you need to select the first seqlen elements for the result to be correct.
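A minimal sketch of the padding argument, assuming 1-D float64 inputs of equal length L; fft_causal_conv, direct_causal_conv and circular_conv are hypothetical helpers, not functions from safari. It checks that the first L outputs of the padded FFT product match a direct causal sum, that the unpadded (circular) version does not, and that perturbing a later input leaves earlier outputs unchanged.

```python
import torch


def fft_causal_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal 1-D convolution via FFT: zero-pad to 2L, multiply spectra, keep the first L outputs."""
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen  # enough room that the circular wrap-around never lands on kept outputs
    u_f = torch.fft.rfft(u, n=fft_size)
    k_f = torch.fft.rfft(k, n=fft_size)
    y = torch.fft.irfft(u_f * k_f, n=fft_size)
    return y[..., :seqlen]


def direct_causal_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Reference: y[t] = sum over s <= t of k[s] * u[t - s]."""
    seqlen = u.shape[-1]
    y = torch.zeros_like(u)
    for t in range(seqlen):
        for s in range(t + 1):
            y[t] += k[s] * u[t - s]
    return y


def circular_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Without padding, the FFT product is a circular convolution (outputs wrap around)."""
    return torch.fft.irfft(torch.fft.rfft(u) * torch.fft.rfft(k), n=u.shape[-1])


torch.manual_seed(0)
L = 16
u = torch.randn(L, dtype=torch.float64)
k = torch.randn(L, dtype=torch.float64)

matches_direct = torch.allclose(fft_causal_conv(u, k), direct_causal_conv(u, k))
circular_differs = not torch.allclose(circular_conv(u, k), direct_causal_conv(u, k))

# Causality check: changing a later input must not change earlier outputs.
u2 = u.clone()
u2[10] += 1.0
y1, y2 = fft_causal_conv(u, k), fft_causal_conv(u2, k)
causal = torch.allclose(y1[:10], y2[:10]) and not torch.allclose(y1[10:], y2[10:])

print(matches_direct, circular_differs, causal)
```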

veritas9872 commented on August 23, 2024

I see that the desired result is to take only the first part of the output sequence, instead of the region with the maximum overlap. Thank you for the explanation!

veritas9872 commented on August 23, 2024

Thank you for the quick response!
I think that my question is slightly different.
The fftshift operation moves the zero-frequency (low-frequency) components to the center of the sequence, and ifftshift undoes this reordering.

Due to an implementation issue, the FFT and IFFT require center frequency shifting to accurately calculate the DFT.
While the shifts may cancel out, I was curious whether they might affect the result.

This discussion may also be helpful: pytorch/pytorch#51022
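To make the reordering concrete, here is a small sketch using torch.fft.fftfreq, which returns the bin frequencies in the order torch.fft.fft produces them. It shows where fftshift puts the DC bin, and why ifftshift is a separate function: for odd lengths, fftshift is not its own inverse.

```python
import torch

for n in (8, 7):
    # Bin frequencies in the order torch.fft.fft returns them: DC first, then positive, then negative.
    freqs = torch.fft.fftfreq(n)
    # fftshift rolls the DC bin to index n // 2, so frequencies appear in ascending order.
    centered = torch.fft.fftshift(freqs)
    restored = torch.fft.ifftshift(centered)
    print(n, freqs.tolist(), centered.tolist())
    assert centered[n // 2] == 0.0
    assert torch.equal(restored, freqs)  # ifftshift exactly undoes fftshift

# For odd lengths, applying fftshift twice does not restore the input.
x = torch.arange(7.0)
double_shift = torch.fft.fftshift(torch.fft.fftshift(x))
print(double_shift)
```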

veritas9872 commented on August 23, 2024

I have tested the function and I believe that this is indeed the issue.

The following code does indeed show that shifting is unnecessary for FFT in PyTorch.

Thank you for your help!

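from scipy import signal
import torch
import numpy as np


@torch.inference_mode()
def test1():
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    # Reference: direct linear convolution, length 2 * seq_len - 1.
    c = signal.convolve(a, b, mode='full', method='direct')
    # Zero-pad both inputs to 2 * seq_len so the circular convolution equals the linear one.
    d = torch.fft.rfft(torch.from_numpy(a), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.from_numpy(b), n=2 * seq_len)
    # norm='forward' skips the 1/n scaling on the inverse; the manual division above supplies it.
    # Drop the last sample, which is (numerically) zero since the full result has 2 * seq_len - 1 values.
    f = torch.fft.irfft(d * e, n=2 * seq_len, norm='forward').numpy()[:-1]
    print(np.allclose(c, f))  # True


@torch.inference_mode()
def test2():
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')
    # Same pipeline, but wrapped in the centered-DFT ifftshift/fftshift convention.
    d = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(a)), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(b)), n=2 * seq_len)
    f = torch.fft.fftshift(torch.fft.irfft(d * e, n=2 * seq_len, norm='forward')).numpy()[:-1]
    print(np.allclose(c, f))  # False


if __name__ == '__main__':
    test1()
    test2()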

veritas9872 commented on August 23, 2024

The PyTorch and NumPy functions produce identical results. The MATLAB implementation does seem to have been the issue.
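A cross-check of that statement, assuming float64 input: compare numpy.fft.rfft with torch.fft.rfft at the same padded length, and compare the two libraries' fftshift helpers. allclose is used for the spectra because the two backends need not be bit-identical.

```python
import numpy as np
import torch

rng = np.random.default_rng(0)
x = rng.standard_normal(13)
n = 2 * x.size  # same zero-padded length as in the tests above

np_spec = np.fft.rfft(x, n=n)
torch_spec = torch.fft.rfft(torch.from_numpy(x), n=n).numpy()
spectra_match = np.allclose(np_spec, torch_spec)

# The shift helpers follow the same convention in both libraries (a pure permutation).
shifts_match = np.array_equal(np.fft.fftshift(x), torch.fft.fftshift(torch.from_numpy(x)).numpy())

print(spectra_match, shifts_match)
```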

veritas9872 commented on August 23, 2024

Another question, though: is taking the front of the convolved output sequence the desired behavior? I believe that the middle part, corresponding to scipy.signal.convolve(..., mode='same'), may be more desirable.

The resulting code would be as follows.

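import torch


def fftconv_same(u, k):
    # Hypothetical wrapper name; u: (batch, d, seqlen) input, k: (d, seqlen) kernel.
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size, norm='forward')
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size, norm='backward')  # Explicit norm mode for better readability.
    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)
    # Keep the centered window along the time (last) axis, as in scipy.signal.convolve(..., mode='same'),
    # which starts at (seqlen - 1) // 2 for equal-length inputs.
    start = (seqlen - 1) // 2
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., start:start + seqlen]
    return y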
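A sketch comparing the two choices against scipy, assuming 1-D float64 inputs of equal length; padded_fft_conv is a hypothetical helper. For equal-length inputs, scipy's centered window starts at (seqlen - 1) // 2, which equals seqlen // 2 only when seqlen is odd, so both an odd and an even length are tested. The front slice is the causal output discussed in the thread; the centered slice also depends on future inputs, so it only fits settings where causality isn't required.

```python
import numpy as np
import torch
from scipy import signal


def padded_fft_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Linear convolution via FFTs zero-padded to 2L (output length 2L, last sample ~0)."""
    fft_size = 2 * u.shape[-1]
    u_f = torch.fft.rfft(u, n=fft_size)
    k_f = torch.fft.rfft(k, n=fft_size)
    return torch.fft.irfft(u_f * k_f, n=fft_size)


rng = np.random.default_rng(0)
results = {}
for seqlen in (13, 16):  # one odd and one even length
    u = rng.random(seqlen)
    k = rng.random(seqlen)
    y = padded_fft_conv(torch.from_numpy(u), torch.from_numpy(k)).numpy()

    # Front slice: causal, y[t] depends only on u[:t + 1].
    front = y[:seqlen]
    # Centered slice: the window scipy returns for mode='same' with equal-length inputs.
    start = (seqlen - 1) // 2
    centered = y[start:start + seqlen]

    full_ref = signal.convolve(u, k, mode='full', method='direct')
    same_ref = signal.convolve(u, k, mode='same', method='direct')
    results[seqlen] = (np.allclose(front, full_ref[:seqlen]), np.allclose(centered, same_ref))

print(results)
```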
