Comments (9)
I don't have specific numbers for you, but I do know that raw GFLOPs for an operation are not the only factor in overall computational speed. Things like code branches, non-sequential memory reads, etc. can greatly slow down a program running on a GPU even when the overall GFLOP count is lower.
from pyscatwave.
Also, you want to double-check that you are using the .cuda() variety of the scattering transform.
@lolz0r thanks for the reply. I am using .cuda() and also transferring all the tensors to the GPU before I start comparing the speeds of both networks.
The network is pretty simple, a feedforward network with 30+ convolutions.
I was expecting that by replacing the feature extractor (I'm using Inception v3) with the scattering network I'd see some speed-up; I was not expecting a 200x slowdown.
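For reference, this is roughly how I time the comparison (a pure-Python sketch with illustrative names; on GPU one would additionally call torch.cuda.synchronize() before reading the clock, since CUDA kernels launch asynchronously):

```python
import time

def bench(fn, warmup=2, repeats=5):
    """Average wall-clock time of fn(), discarding warm-up calls."""
    for _ in range(warmup):
        fn()  # first calls may pay one-time costs (kernel load, caching)
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats

# Stand-in workload; in the real comparison fn would run a network forward pass.
avg = bench(lambda: sum(i * i for i in range(100_000)))
```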
One thing I noticed is that there are multiple copies of the S transform in the scattering code. Is there a possibility that these copies are moving data between CPU and GPU?
Just for fun, what are your M, N and J values set to?
M=224, N=224, J=3
I've noticed that periodize and modulus take most of the time, since their first pass is not already cached. If I run the scattering net twice, the second run is 1.5x faster.
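A minimal sketch of that first-pass effect, simulating a one-time kernel compilation/caching cost with functools.lru_cache (the names, numbers, and mechanism here are purely illustrative, not pyscatwave's actual caching):

```python
import functools
import time

@functools.lru_cache(maxsize=None)
def compile_kernel(name):
    time.sleep(0.05)           # simulate a one-time compilation/load cost
    return name + "-binary"

def run_modulus():
    compile_kernel("modulus")  # cached after the first call
    # ... launch the kernel on the data ...

t0 = time.perf_counter(); run_modulus(); first = time.perf_counter() - t0
t0 = time.perf_counter(); run_modulus(); second = time.perf_counter() - t0
# `first` includes the 50 ms one-time cost; `second` does not
```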
Well, maybe you could do a proper timing, but I'm not shocked if you claim that the scattering takes about 1s for a batch of 256 large RGB images. The implementation is definitely not optimal, though still faster than CPU implementations; the copies are due to the use of buffers. We're open to ideas to speed up the software while keeping the memory usage reasonable. In your particular case, the padding size could be a bit large. Depending on the application, it is possible to obtain a speed-up w.r.t. your current pipeline.
I'm a bit surprised that the modulus kernel is slow, but not that periodize's is. I think any expert in CUDA could drastically optimize our software, and I'd be pretty happy and open to any suggestions 👍
@edouardoyallon Thanks for your reply.
Actually, my test uses only a single RGB input of size (1,3,224,224), and after a minor optimization it went from 2s to 0.5s. I'm using a GTX 1080.
Also, is there a reason why most fft and cdgmm calls are not inplace=True?
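To illustrate what I mean by in place (a plain-Python stand-in; the real fft/cdgmm calls operate on GPU tensors, where the extra allocation is the cost in question):

```python
def scale(xs, k):
    """Out-of-place: allocates and returns a new list."""
    return [x * k for x in xs]

def scale_(xs, k):
    """In-place: overwrites the existing buffer, no new allocation."""
    for i in range(len(xs)):
        xs[i] *= k
    return xs

a = [1.0, 2.0, 3.0]
b = scale(a, 2.0)   # a is untouched, b is a fresh list
scale_(a, 2.0)      # a itself is now [2.0, 4.0, 6.0]
```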
I am understanding the code better and will see if I can manage something; I'll submit a PR in that case.
So there is definitely an issue in your testing procedure, as a batch of size (256,3,224,224) should take about 1s on GPU (this timing sounds more like CPU?).
In this version of the code, we explicitly removed the buffers to let PyTorch decide the allocations (so that they are optimal). When the fft is not in place, it means that the result will be used later. IMHO, the code to optimize is here: https://github.com/edouardoyallon/pyscatwave/blob/master/scatwave/utils.py (each subroutine). Furthermore, if you don't use large batches, your timings could be a bit skewed.
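The batch-size point can be seen with a back-of-the-envelope model: each forward call pays a fixed launch/dispatch overhead that gets amortized over the batch (the numbers below are made up purely for illustration):

```python
# Hypothetical per-call overhead vs. per-image cost (illustrative values).
overhead_ms = 50.0   # fixed cost per forward call (launches, Python, etc.)
per_image_ms = 4.0   # marginal cost per image once the GPU is busy

def per_image_time(batch_size):
    return overhead_ms / batch_size + per_image_ms

# per_image_time(1)  -> 54.0 ms/image: overhead dominates
# per_image_time(64) -> ~4.8 ms/image: overhead amortized away
```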
True. I just ran a quick test with size (64,3,224,224) and scattering runs in 0.33 seconds on the second pass; the first call takes 1.2s, probably because the modulus kernel is not yet cached at that point. Still, this is slower than my whole original network, which has 22 GFLOPs, and yes, I'm running everything on the GPU.
Thanks for the help! :)