Comments (11)
Currently, depth-wise convolution is not implemented.
I may be able to get to this in the near future. But feel free to open a PR if you'd like to work on it, too!
from fft-conv-pytorch.
Thanks for your kind reply, I will try it too. Btw, I found that the FFT conv matches the original conv only when the input size is even. When the input size is odd, the difference between them is not zero. Is there a problem, or should the padding be set differently? The code below is used for the comparison.
import torch
import torch.nn.functional as f
from fft_conv_pytorch import FFTConv2d


def conv2d_pyt(input, weight):
    pad_y = (weight.size(2) - 1) // 2
    pad_x = (weight.size(3) - 1) // 2
    fcg = f.conv2d(input, weight, bias=None, padding=(pad_y, pad_x))
    return fcg


if __name__ == '__main__':
    # calculate f*g
    input = torch.randn(2, 3, 12, 8)
    weight = torch.randn(64, 3, 3, 3)
    fcg_pyt = conv2d_pyt(input, weight)
    conv2d_fft = FFTConv2d(3, 64, 3, 1, bias=False)
    conv2d_fft.weight = torch.nn.Parameter(weight)
    fcg_fft = conv2d_fft(input)
    avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()
    print('Average difference:', avg_diff)
Could you elaborate a bit? When I run your code sample, I see:
Average difference: 6.670038601441775e-07
which is about as accurate as I would expect. Is this (roughly) what you're seeing as well?
Thanks for the feedback!
If you change the height or width of the input to an odd number, like input = torch.randn(2, 3, 12, 7), then the average difference is no longer accurate, right?
Ah, OK, I see it now. I was changing the input size along dimension 2, not the last dimension. Interestingly, changing to input = torch.randn(2, 3, 11, 8) does not affect the accuracy, but input = torch.randn(2, 3, 12, 7) does.
I think this is caused by torch.fft.rfftn, which computes a one-sided FFT by default: the final dimension of the transformed tensor has length n // 2 + 1, so even and odd input lengths can map to the same transformed size. I'll have to look more closely into this. Will keep you updated.
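The ambiguity can be seen directly in a round trip (a minimal sketch, not the library's code: torch.fft.irfftn cannot infer an odd final-dimension length unless the output size is pinned via s=):

```python
import torch

x = torch.randn(2, 3, 12, 7)  # odd final dimension
X = torch.fft.rfftn(x, dim=(-2, -1))
# rfftn is one-sided: width 7 -> 7 // 2 + 1 = 4 complex bins,
# exactly what width 6 would also produce, so the length is ambiguous.
y_guess = torch.fft.irfftn(X, dim=(-2, -1))  # assumes an even width (6)
y_exact = torch.fft.irfftn(X, dim=(-2, -1), s=x.shape[-2:])  # recovers 7
```

Passing s= is cheap insurance even for even sizes, since it makes the inverse transform's output shape explicit.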
Yeah, you are correct. The one-sided FFT causes the inaccurate output. Let's give it a try.
I believe I fixed it. My testing probably isn't the most thorough, but your example from above works now. I tried similar things for the 1D and 3D cases (now included in the benchmark.py script).
Thanks again for pointing that out! I'll try to come back around to depth-wise convolution soon.
OK, I will check it later. Btw, I have implemented depth-wise convolution in FFT, which simply sets the input-channel dimension of the weight to 1 (like weight = torch.randn(4, 1, 3, 3)) and replaces the function complex_matmul with:
import torch
from torch import Tensor


def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
    """Multiplies two complex-valued tensors (depth-wise case)."""
    # (out_channels, 1, h, w) -> (1, out_channels, h, w), so the weight
    # spectrum broadcasts against the input spectrum (batch, channels, h, w).
    b = b.permute(1, 0, 2, 3)
    real = a.real * b.real - a.imag * b.imag
    imag = a.imag * b.real + a.real * b.imag
    # torch.complex keeps the inputs' dtype and device, unlike a
    # hard-coded torch.zeros(..., dtype=torch.complex64) buffer.
    return torch.complex(real, imag)
It works well, so you can give it a try.
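The explicit real/imag arithmetic in that snippet agrees with PyTorch's native complex multiplication, which makes for a quick sanity check (a small sketch; the shapes are hypothetical ones matching the depth-wise case, with the weight spectrum already permuted):

```python
import torch

# Spectra with depth-wise-style shapes: input (batch, channels, h, w_f)
# and weight (1, channels, h, w_f) after the permute.
a = torch.randn(2, 4, 5, 5, dtype=torch.complex64)
b = torch.randn(1, 4, 5, 5, dtype=torch.complex64)

# Hand-rolled complex product, as in complex_matmul
real = a.real * b.real - a.imag * b.imag
imag = a.imag * b.real + a.real * b.imag
out = torch.complex(real, imag)

assert torch.allclose(out, a * b)  # matches native complex multiply
```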
Glad to hear you have it working. But I believe depth-wise convolution is usually implemented by setting groups = in_channels:
conv = nn.Conv2d(64, 64, 3, padding=1, groups=64)
I'd like to stick to PyTorch conventions. Is it possible to efficiently implement this for FFT using groups?
It took a bit more time, but I managed to implement it using groups. The complex_matmul function is harder to understand now, but it matches the behavior of grouped convolutions. You can get a depth-wise convolution like this:
conv = FFTConv2d(64, 64, 3, padding=1, groups=64)
Similarly, you can use the convolution function directly:
y = fft_conv(
signal=torch.randn(1, 64, 128, 128),
kernel=torch.randn(64, 1, 3, 3),
padding=1,
groups=64,
)
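For reference, the grouped semantics this matches can be checked against the native op: with groups equal to the channel count, each output channel is the convolution of the corresponding input channel alone (a small verification sketch using torch.nn.functional, with arbitrary example shapes):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 16, 16)
w = torch.randn(8, 1, 3, 3)  # depth-wise: one single-input-channel filter per channel

y = F.conv2d(x, w, padding=1, groups=8)

# Equivalent per-channel convolutions, concatenated along the channel dim:
y_ref = torch.cat(
    [F.conv2d(x[:, i : i + 1], w[i : i + 1], padding=1) for i in range(8)],
    dim=1,
)
assert torch.allclose(y, y_ref, atol=1e-5)
```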
It looks great now! I had only implemented a special case of depth-wise convolution, in which groups equals the number of input channels. Thanks for your great work, and I will close the issue now.
Appreciate all your feedback!