Comments (11)

fkodom commented on July 21, 2024

Currently, depth-wise convolution is not implemented.

I may be able to get to this in the near future. But feel free to open a PR if you'd like to work on it, too!

vaesl commented on July 21, 2024

Thanks for your kind reply; I will try it too. By the way, I found that the FFT conv equals the original conv only when the input size is even. When the input size is odd, the difference between them is not zero. Is there a problem, or should the padding be set differently? The code below is used for comparison.

import torch
import torch.nn.functional as f
from fft_conv_pytorch import FFTConv2d  # import path may differ across versions


def conv2d_pyt(input, weight):
    # "Same" padding for stride-1, odd-sized kernels.
    pad_y = (weight.size(2) - 1) // 2
    pad_x = (weight.size(3) - 1) // 2
    return f.conv2d(input, weight, bias=None, padding=(pad_y, pad_x))


if __name__ == '__main__':
    # calculate f*g
    input = torch.randn(2, 3, 12, 8)
    weight = torch.randn(64, 3, 3, 3)

    fcg_pyt = conv2d_pyt(input, weight)
    conv2d_fft = FFTConv2d(3, 64, 3, padding=1, bias=False)
    conv2d_fft.weight = torch.nn.Parameter(weight)
    fcg_fft = conv2d_fft(input)

    avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()
    print('Average difference:', avg_diff)

fkodom commented on July 21, 2024

Could you elaborate a bit? When I run your code sample, I see:
Average difference: 6.670038601441775e-07
which is about as accurate as I would expect. Is this (roughly) what you're seeing as well?

Thanks for the feedback!

vaesl commented on July 21, 2024

If you change the height or width of the input to an odd number, like input = torch.randn(2, 3, 12, 7), then the average difference is no longer small. Could you check?

fkodom commented on July 21, 2024

Ah, OK, I see it now. I was changing the input size along dimension 2, not the last dimension. Interestingly, changing to input = torch.randn(2, 3, 11, 8) does not affect the accuracy, but input = torch.randn(2, 3, 12, 7) does.

I think this is caused by torch.fft.rfftn, which computes a one-sided FFT by default: it keeps only n // 2 + 1 frequency samples along the final dimension, so the inverse transform cannot tell whether the original length n was even or odd unless the output size is passed explicitly. I'll have to look more closely into this. Will keep you updated.
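
A minimal sketch of the length ambiguity, using the failing shape from above and the standard torch.fft API:

import torch

x = torch.randn(2, 3, 12, 7)           # odd length on the last dimension
X = torch.fft.rfftn(x, dim=(-2, -1))   # one-sided: last dim becomes 7 // 2 + 1 = 4

# Without an explicit output size, irfftn assumes an even-length signal and
# returns shape (..., 12, 6). Passing s recovers the original odd length.
y = torch.fft.irfftn(X, s=x.shape[-2:], dim=(-2, -1))
print(y.shape)                          # torch.Size([2, 3, 12, 7])
print(torch.allclose(x, y, atol=1e-6))  # True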

vaesl commented on July 21, 2024

Yeah, you are correct. The one-sided FFT causes the inaccurate output. Let's give it a try.

fkodom commented on July 21, 2024

I believe I fixed it. My testing probably isn't the most thorough, but your example from above works now, and I tried similar things for the 1D and 3D cases. (These are now included in the benchmark.py script.)

Thanks again for pointing that out! I'll try to come back around to depth-wise convolution soon.

vaesl commented on July 21, 2024

OK, I will check it later. By the way, I have implemented depth-wise convolution with FFT: simply set the input-channel dimension of the weight to 1 (like weight = torch.randn(4, 1, 3, 3)) and replace the complex_matmul function with:

import torch
from torch import Tensor


def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
    """Multiplies two complex-valued tensors element-wise."""
    # Move the kernel's output channels into the channel axis so they
    # broadcast against the input's channels: (C, 1, H, W) -> (1, C, H, W).
    b = b.permute(1, 0, 2, 3)
    real = a.real * b.real - a.imag * b.imag
    imag = a.imag * b.real + a.real * b.imag
    return torch.complex(real, imag)

It works well, and you can give it a try.

fkodom commented on July 21, 2024

Glad to hear you have it working. But I believe depth-wise convolution is usually implemented by setting groups = in_channels:

conv = nn.Conv2d(64, 64, 3, padding=1, groups=64)

I'd like to stick to PyTorch conventions, so the question was whether this could be implemented efficiently for FFT using groups. It took a bit more time, but I managed to do it. The complex_matmul function is harder to understand now, but it matches the behavior of grouped convolutions.
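
For intuition, here is a minimal sketch of a grouped frequency-domain multiply (an illustration of the idea under assumed tensor layouts, not necessarily the exact code in the repo):

import torch
from torch import Tensor


def grouped_complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
    """Grouped multiply of an FFT'd signal a with an FFT'd kernel b.

    a: (batch, groups * in_per_group, *spatial)
    b: (groups * out_per_group, in_per_group, *spatial)
    """
    a = a.view(a.size(0), groups, -1, *a.shape[2:])
    b = b.view(groups, -1, *b.shape[1:])
    # Sum over each group's input channels; the spatial frequency bins
    # multiply element-wise via the broadcast ellipsis.
    c = torch.einsum("agc...,gbc...->agb...", a, b)
    return c.reshape(c.size(0), -1, *c.shape[3:])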

You can get a depth-wise convolution like this:

conv = FFTConv2d(64, 64, 3, padding=1, groups=64)

Similarly, you can use the convolution function directly:

y = fft_conv(
    signal=torch.randn(1, 64, 128, 128),
    kernel=torch.randn(64, 1, 3, 3),
    padding=1,
    groups=64,
)
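
As a quick sanity check (a sketch; the import path is assumed and may differ across versions), the grouped FFT convolution can be compared directly against nn.Conv2d:

import torch
from torch import nn
from fft_conv_pytorch import FFTConv2d  # assumed import path

ref = nn.Conv2d(64, 64, 3, padding=1, groups=64, bias=False)
fft = FFTConv2d(64, 64, 3, padding=1, groups=64, bias=False)
fft.weight = nn.Parameter(ref.weight.detach().clone())

x = torch.randn(1, 64, 32, 32)
max_diff = torch.max(torch.abs(ref(x) - fft(x))).item()
print('Max difference:', max_diff)  # expected to be small, ~1e-6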

vaesl commented on July 21, 2024

It looks great now! I only implemented a special case of depth-wise convolution, in which groups equals the number of input channels. Thanks for your great work; I will close the issue now.

fkodom commented on July 21, 2024

Appreciate all your feedback!
