
pytorch-unflow's Introduction

pytorch-unflow

This is a personal reimplementation of UnFlow [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the licensing terms of the authors. Should you be making use of this particular implementation, please acknowledge it appropriately [2].


For the original TensorFlow version of this work, please see: https://github.com/simonmeister/UnFlow
Other optical flow implementations from me: pytorch-pwc, pytorch-spynet, pytorch-liteflownet

setup

The correlation layer is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using pip install cupy or alternatively using one of the provided binary packages as outlined in the CuPy repository.

usage

To run it on your own pair of images, use the following command. You can choose between two models; please see their paper or the code for more details.

python run.py --model css --one ./images/one.png --two ./images/two.png --out ./out.flo
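The --out argument writes the estimated flow to a .flo file. Assuming run.py uses the standard Middlebury layout (a float32 tag 202021.25, whose bytes read "PIEH", then int32 width and height, then interleaved float32 u/v values in row-major order, all little-endian), a minimal stdlib sketch for reading and writing such files could look like this. read_flo and write_flo are hypothetical helpers, not part of this repository.

```python
import struct
import sys
from array import array

FLO_TAG = 202021.25  # float32 tag; its little-endian bytes spell b"PIEH"


def write_flo(path, width, height, uv):
    """Write a flat row-major list [u0, v0, u1, v1, ...] as a .flo file."""
    if len(uv) != width * height * 2:
        raise ValueError("expected width * height * 2 values")
    data = array("f", uv)  # float32 on all common platforms
    if sys.byteorder != "little":
        data.byteswap()  # .flo payloads are little-endian
    with open(path, "wb") as f:
        f.write(struct.pack("<fii", FLO_TAG, width, height))
        f.write(data.tobytes())


def read_flo(path):
    """Return (width, height, data) where data holds interleaved u/v floats."""
    with open(path, "rb") as f:
        tag, width, height = struct.unpack("<fii", f.read(12))
        if tag != FLO_TAG:
            raise ValueError("not a .flo file")
        data = array("f")
        data.frombytes(f.read(width * height * 2 * data.itemsize))
    if len(data) != width * height * 2:
        raise ValueError("truncated .flo file")
    if sys.byteorder != "little":
        data.byteswap()
    return width, height, data
```

For example, after w, h, flow = read_flo('./out.flo'), the horizontal and vertical displacement of pixel (x, y) are flow[2 * (y * w + x)] and flow[2 * (y * w + x) + 1].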

I am afraid that I cannot guarantee that this reimplementation is correct. However, it produced results identical to the implementation of the original authors in the examples that I tried. Please feel free to contribute to this repository by submitting issues and pull requests.

comparison

[Figure: Comparison]

license

As stated in the licensing terms of the authors of the paper, the models are subject to an MIT license. Please make sure to consult their licensing terms as well.

references

[1]  @inproceedings{Meister_AAAI_2018,
         author = {Simon Meister and Junhwa Hur and Stefan Roth},
         title = {{UnFlow}: Unsupervised Learning of Optical Flow with a Bidirectional Census Loss},
         booktitle = {AAAI},
         year = {2018}
     }
[2]  @misc{pytorch-unflow,
         author = {Simon Niklaus},
         title = {A Reimplementation of {UnFlow} Using {PyTorch}},
         year = {2018},
         howpublished = {\url{https://github.com/sniklaus/pytorch-unflow}}
     }

pytorch-unflow's People

Contributors

sniklaus


pytorch-unflow's Issues

Different results between your inference and the TensorFlow version

Hi,

Many thanks for the code. I ran the original TensorFlow code on my test set, but I got very different results with your implementation, and your results show better performance. I have a few questions:
1) Are you using the CSS architecture checkpoint for inference rather than a fine-tuned version?
2) Have you tested on custom images (rather than images from KITTI or SYNTHIA)?

Cheers,

How to train the model

Hi, thanks for your work. However, what should I do if I want to train the model myself instead of just using the pre-trained model?

Model downloading too slow

I used the bash script to download your PyTorch model, but it is too slow for me to continue my work. Is there any way to get the model more quickly? Thanks!

loss function

Hi, I read through this code but found no loss function. Have you tried training the network with this project?
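For reference, the unsupervised idea behind UnFlow is to warp the second image toward the first using the estimated flow and penalize the remaining difference. The following minimal stdlib sketch assumes grayscale images given as lists of rows and flow given as rows of (u, v) tuples; bilinear_sample, warp and photometric_loss are hypothetical names, and a plain mean absolute difference stands in for the paper's census-based data term, smoothness term and occlusion handling.

```python
def bilinear_sample(img, x, y):
    """Sample a grayscale image at a real-valued position, clamping to the border."""
    h, w = len(img), len(img[0])
    x = min(max(x, 0.0), w - 1.0)
    y = min(max(y, 0.0), h - 1.0)
    x0, y0 = int(x), int(y)  # floor, since x and y are non-negative here
    x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)
    ax, ay = x - x0, y - y0
    top = img[y0][x0] * (1 - ax) + img[y0][x1] * ax
    bottom = img[y1][x0] * (1 - ax) + img[y1][x1] * ax
    return top * (1 - ay) + bottom * ay


def warp(img_two, flow):
    """Backward-warp the second image: out[y][x] = img_two sampled at (x + u, y + v)."""
    return [
        [bilinear_sample(img_two, x + u, y + v) for x, (u, v) in enumerate(row)]
        for y, row in enumerate(flow)
    ]


def photometric_loss(img_one, img_two, flow):
    """Mean absolute difference between the first image and the warped second image."""
    warped = warp(img_two, flow)
    diffs = [abs(a - b) for ra, rb in zip(img_one, warped) for a, b in zip(ra, rb)]
    return sum(diffs) / len(diffs)
```

In an actual training loop the same steps would run on batched tensors, so that the loss can be backpropagated into the network.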

Training Code

Could you please provide the training code, especially the census loss part? Thanks.
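The census transform replaces each pixel by a bit string recording which neighbors are darker than the center, and two images are then compared by the Hamming distance between these strings, which makes the comparison robust to brightness changes. The sketch below is the classical binary census with a 3x3 window in plain Python; the original UnFlow code uses a larger window and a soft, differentiable approximation, so census and census_distance are illustrative names only.

```python
def census(img, radius=1):
    """Binary census transform: one tuple of bits per pixel.

    A bit is 1 where the neighbor is darker than the center pixel. Neighbors
    outside the image are replaced by the nearest border pixel.
    """
    h, w = len(img), len(img[0])
    out = []
    for y in range(h):
        row = []
        for x in range(w):
            center = img[y][x]
            bits = []
            for dy in range(-radius, radius + 1):
                for dx in range(-radius, radius + 1):
                    if dx == 0 and dy == 0:
                        continue  # skip the center itself
                    ny = min(max(y + dy, 0), h - 1)
                    nx = min(max(x + dx, 0), w - 1)
                    bits.append(1 if img[ny][nx] < center else 0)
            row.append(tuple(bits))
        out.append(row)
    return out


def census_distance(img_one, img_two, radius=1):
    """Mean Hamming distance between the census signatures of two images."""
    c1, c2 = census(img_one, radius), census(img_two, radius)
    total = count = 0
    for r1, r2 in zip(c1, c2):
        for s1, s2 in zip(r1, r2):
            total += sum(b1 != b2 for b1, b2 in zip(s1, s2))
            count += 1
    return total / count
```

Because only the ordering of intensities matters, adding a constant to one image or scaling it by a positive factor leaves the distance at zero.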

How to compute the end-point error?

Hi,
Thanks for the code! I want to evaluate optical flow quantitatively using the average end-point error (AEPE) and the percentage of erroneous pixels (Fl). I did not find any reference code. Do you know how to compute them?
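The end-point error of a pixel is the Euclidean distance between the estimated and the ground-truth flow vector, and AEPE is its mean over valid pixels. For Fl, the KITTI 2015 benchmark counts a pixel as an outlier when its end-point error exceeds both 3 px and 5% of the ground-truth flow magnitude. A minimal stdlib sketch, assuming both flows are given as flat lists of (u, v) tuples with an optional validity mask; epe_metrics is a hypothetical helper name.

```python
import math


def epe_metrics(flow_est, flow_gt, valid=None, abs_thresh=3.0, rel_thresh=0.05):
    """Return (AEPE, Fl), with Fl as a percentage of valid pixels."""
    if valid is None:
        valid = [True] * len(flow_gt)
    total_epe = 0.0
    outliers = 0
    count = 0
    for (ue, ve), (ug, vg), ok in zip(flow_est, flow_gt, valid):
        if not ok:
            continue  # e.g. pixels without ground truth in sparse KITTI maps
        epe = math.hypot(ue - ug, ve - vg)
        mag = math.hypot(ug, vg)
        total_epe += epe
        # Outlier only if the error is large in absolute AND relative terms.
        if epe > abs_thresh and epe > rel_thresh * mag:
            outliers += 1
        count += 1
    if count == 0:
        raise ValueError("no valid pixels")
    return total_epe / count, 100.0 * outliers / count
```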

What is the purpose of "* 20.0" after the two upscale operations in the class "Upconv"?

class Upconv(torch.nn.Module):
    def forward(self, tenOne, tenTwo, objInput):
        objOutput = {}

        tenInput = objInput['conv6']
        objOutput['flow6'] = self.netSixOut(tenInput)
        tenInput = torch.cat([objInput['conv5'], self.netFivNext(tenInput), self.netSixUp(objOutput['flow6'])],
                             1)
        objOutput['flow5'] = self.netFivOut(tenInput)
        tenInput = torch.cat([objInput['conv4'], self.netFouNext(tenInput), self.netFivUp(objOutput['flow5'])],
                             1)
        objOutput['flow4'] = self.netFouOut(tenInput)
        tenInput = torch.cat([objInput['conv3'], self.netThrNext(tenInput), self.netFouUp(objOutput['flow4'])],
                             1)
        objOutput['flow3'] = self.netThrOut(tenInput)
        tenInput = torch.cat([objInput['conv2'], self.netTwoNext(tenInput), self.netThrUp(objOutput['flow3'])],
                             1)
        objOutput['flow2'] = self.netTwoOut(tenInput)

        return self.netUpscale(self.netUpscale(objOutput['flow2'])) * 20.0         # * 20.0 here

What is the purpose of this multiplication? My guess is that it enlarges the output of the network, because the network should output the absolute size of the offset. Is my guess right?
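The thread itself does not state where 20.0 comes from. As a general point about multi-scale flow networks, two factors can enter such a multiplication: if the raw output is the flow divided by a fixed training scale, it has to be multiplied back, and if the flow is predicted on a grid k times coarser in units of that grid's pixels, upsampling it to full resolution also requires multiplying the vectors by k, because interpolation only resamples the field without changing the vector values. The sketch below only illustrates this arithmetic; upsample_flow is a hypothetical helper that uses nearest-neighbour upsampling for brevity, whereas the network above uses netUpscale.

```python
def upsample_flow(flow, factor, value_scale):
    """Nearest-neighbour upsample a flow field (rows of (u, v)) by an integer factor.

    Every vector is multiplied by value_scale. Pass value_scale=factor when the
    input is expressed in pixels of the coarse grid, and fold in any extra output
    scaling the network was trained with.
    """
    out = []
    for row in flow:
        scaled = [(u * value_scale, v * value_scale) for u, v in row]
        wide = [vec for vec in scaled for _ in range(factor)]  # repeat columns
        for _ in range(factor):
            out.append(list(wide))  # repeat rows
    return out
```

For instance, a displacement of 0.5 coarse pixels on a grid four times smaller than the input becomes 2 full-resolution pixels with upsample_flow(flow, 4, 4.0).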

CuPy problem

Hi,

Thanks for your work. When I tried to run the evaluation, I got the following error:

/data/public/yiping/anaconda3/envs/ts/lib/python3.8/site-packages/cupy/cuda/compiler.py:461: UserWarning: cupy.cuda.compile_with_cache has been deprecated in CuPy v10, and will be removed in the future. Use cupy.RawModule or cupy.RawKernel instead.
  warnings.warn(
Traceback (most recent call last):
  File "run.py", line 352, in <module>
    tenOutput = estimate(tenOne, tenTwo)
  File "run.py", line 338, in estimate
    tenFlow = torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedOne, tenPreprocessedTwo), size=(intHeight, intWidth), mode='bilinear', align_corners=False)
  File "/data/public/yiping/anaconda3/envs/ts/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "run.py", line 302, in forward
    tenFlow = netFlownet(tenOne, tenTwo, tenFlow)
  File "/data/public/yiping/anaconda3/envs/ts/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "run.py", line 200, in forward
    tenCorr = self.netCorrelation(objOutput['conv3'], tenOther)
  File "/data/public/yiping/anaconda3/envs/ts/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "./correlation/correlation.py", line 394, in forward
    return _FunctionCorrelation.apply(tenOne, tenTwo)
  File "./correlation/correlation.py", line 310, in forward
    cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
  File "cupy/_util.pyx", line 67, in cupy._util.memoize.decorator.ret
  File "./correlation/correlation.py", line 274, in cupy_launch
    return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
  File "/data/public/yiping/anaconda3/envs/ts/lib/python3.8/site-packages/cupy/cuda/compiler.py", line 465, in compile_with_cache
    return _compile_module_with_cache(*args, **kwargs)
  File "/data/public/yiping/anaconda3/envs/ts/lib/python3.8/site-packages/cupy/cuda/compiler.py", line 493, in _compile_module_with_cache
    return _compile_with_cache_cuda(
  File "/data/public/yiping/anaconda3/envs/ts/lib/python3.8/site-packages/cupy/cuda/compiler.py", line 562, in _compile_with_cache_cuda
    mod.load(cubin)
  File "cupy/cuda/function.pyx", line 264, in cupy.cuda.function.Module.load
  File "cupy/cuda/function.pyx", line 266, in cupy.cuda.function.Module.load
  File "cupy_backends/cuda/api/driver.pyx", line 210, in cupy_backends.cuda.api.driver.moduleLoadData
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Exception ignored in: 'cupy.cuda.function.Module.dealloc'
Traceback (most recent call last):
  File "cupy_backends/cuda/api/driver.pyx", line 217, in cupy_backends.cuda.api.driver.moduleUnload
  File "cupy_backends/cuda/api/driver.pyx", line 60, in cupy_backends.cuda.api.driver.check_status
cupy_backends.cuda.api.driver.CUDADriverError: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered

My environment is:

Name Version Build Channel

_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
blas 1.0 mkl
brotlipy 0.7.0 py38h27cfd23_1003
bzip2 1.0.8 h7b6447c_0
ca-certificates 2022.12.7 ha878542_0 conda-forge
certifi 2022.12.7 pyhd8ed1ab_0 conda-forge
cffi 1.15.0 py38h7f8727e_0
charset-normalizer 2.0.4 pyhd3eb1b0_0
cryptography 38.0.1 py38h9ce1e76_0
cudatoolkit 11.3.1 h2bc3f7f_2
cupy 11.4.0 py38h405e1b6_0 conda-forge
fastrlock 0.8.1 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
flit-core 3.6.0 pyhd3eb1b0_0
freetype 2.12.1 h4a9f257_0
giflib 5.2.1 h7b6447c_0
gmp 6.2.1 h295c915_3
gnutls 3.6.15 he1e5248_0
idna 3.4 py38h06a4308_0
intel-openmp 2021.4.0 h06a4308_3561
jpeg 9e h7f8727e_0
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
lerc 3.0 h295c915_0
libdeflate 1.8 h7f8727e_5
libedit 3.1.20221030 h5eee18b_0
libffi 3.2.1 hf484d3e_1007
libgcc-ng 12.2.0 h65d4601_19 conda-forge
libiconv 1.16 h7f8727e_2
libidn2 2.3.2 h7f8727e_0
libpng 1.6.37 hbc83047_0
libstdcxx-ng 12.2.0 h46fd767_19 conda-forge
libtasn1 4.16.0 h27cfd23_0
libtiff 4.4.0 hecacb30_2
libunistring 0.9.10 h27cfd23_0
libwebp 1.2.4 h11a3e52_0
libwebp-base 1.2.4 h5eee18b_0
llvm-openmp 14.0.6 h9e868ea_0
lz4-c 1.9.4 h6a678d5_0
mkl 2021.4.0 h06a4308_640
mkl-service 2.4.0 py38h7f8727e_0
mkl_fft 1.3.1 py38hd3c417c_0
mkl_random 1.2.2 py38h51133e4_0
ncurses 6.3 h5eee18b_3
nettle 3.7.3 hbbd107a_1
numpy 1.23.5 py38h14f4228_0
numpy-base 1.23.5 py38h31eccc5_0
openh264 2.1.1 h4ff587b_0
openssl 1.1.1s h0b41bf4_1 conda-forge
pillow 9.3.0 py38hace64e9_1
pip 22.3.1 py38h06a4308_0
pycparser 2.21 pyhd3eb1b0_0
pyopenssl 22.0.0 pyhd3eb1b0_0
pysocks 1.7.1 py38h06a4308_0
python 3.8.0 h0371630_2
python_abi 3.8 2_cp38 conda-forge
pytorch 1.12.1 py3.8_cuda11.3_cudnn8.3.2_0 pytorch
pytorch-mutex 1.0 cuda pytorch
readline 7.0 h7b6447c_5
requests 2.28.1 py38h06a4308_0
setuptools 65.5.0 py38h06a4308_0
six 1.16.0 pyhd3eb1b0_1
sqlite 3.33.0 h62c20be_0
tk 8.6.12 h1ccaba5_0
torchaudio 0.12.1 py38_cu113 pytorch
torchvision 0.13.1 py38_cu113 pytorch
typing_extensions 4.4.0 py38h06a4308_0
urllib3 1.26.13 py38h06a4308_0
wheel 0.37.1 pyhd3eb1b0_0
xz 5.2.8 h5eee18b_0
zlib 1.2.13 h5eee18b_0
zstd 1.5.2 ha4553b6_0

Thank you very much in advance!
