torchprofile's Issues

No handlers found: "aten::pow"

No handlers found: "aten::unsqueeze"
No handlers found: "aten::pow".
No handlers found: "aten::sign"
No handlers found: "aten::abs"
No handlers found: "aten::fft_fft"
No handlers found: "aten::real"
No handlers found: "aten::imag"
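Unlike the other ops in this list, the cost of aten::fft_fft grows as N log N rather than once per element. A minimal sketch, assuming the nominal 5·N·log2(N) real-FLOP count for a complex transform of length N (the convention FFTW uses when reporting benchmark speeds). Real implementations, real-input transforms and non-power-of-two lengths will differ, and `fft_flops_estimate` is a hypothetical helper, not part of torchprofile:

```python
import math


def fft_flops_estimate(n: int, batch: int = 1) -> float:
    """Nominal real FLOPs for `batch` complex FFTs of length `n`.

    Uses the 5 * n * log2(n) convention (assumption). Note this is a
    FLOP count, not a MAC count.
    """
    if n < 1 or batch < 0:
        raise ValueError("n must be >= 1 and batch must be >= 0")
    if n == 1:
        return 0.0  # a length-1 FFT is the identity
    return batch * 5 * n * math.log2(n)


print(fft_flops_estimate(1024))           # 51200.0
print(fft_flops_estimate(1024, batch=4))  # 204800.0
```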

No handlers found: "aten::einsum"
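For a two-operand einsum, one common way to count MACs is to multiply the sizes of every distinct index label, since each combination of labels contributes one multiply-accumulate. The sketch below assumes explicit subscripts without ellipsis (`...`) and exactly two operands; with three or more operands the cost depends on the contraction order, which it does not model. `einsum_macs` is a hypothetical helper, not a torchprofile API:

```python
from math import prod
from typing import Dict, Tuple


def einsum_macs(equation: str, *shapes: Tuple[int, ...]) -> int:
    """MACs of a two-operand einsum: product of all distinct index sizes."""
    inputs, _, _ = equation.replace(" ", "").partition("->")
    operands = inputs.split(",")
    if len(operands) != 2 or len(shapes) != 2:
        raise ValueError("this sketch only handles two-operand einsum")
    sizes: Dict[str, int] = {}
    for subscripts, shape in zip(operands, shapes):
        if "." in subscripts or len(subscripts) != len(shape):
            raise ValueError(f"subscripts {subscripts!r} do not match shape {shape}")
        for label, dim in zip(subscripts, shape):
            # The same label must have the same size in every operand.
            if sizes.setdefault(label, dim) != dim:
                raise ValueError(f"inconsistent size for index {label!r}")
    return prod(sizes.values())


# Matrix multiply (2x3) @ (3x4): 2*4 outputs, each a 3-term dot product.
print(einsum_macs("ij,jk->ik", (2, 3), (3, 4)))  # 24
# Batched matrix multiply.
print(einsum_macs("bij,bjk->bik", (8, 2, 3), (8, 3, 4)))  # 192
```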

No handlers found: "aten::unfold"
No handlers found: "aten::split"
No handlers found: "aten::einsum".
No handlers found: "aten::permute".

I hope these ops can be supported, thanks!

No handlers found

Thank you for your great work!
I encountered this warning when using torchprofile. It seems that torchprofile cannot count the FLOPs of torch.nn.functional.grid_sample:
UserWarning: No handlers found: "aten::grid_sampler". Skipped.

Unused operations seem to be included in FLOPS calculation

Hello,

I'm working with the NsgaNetV2 code (https://github.com/mikelzc1990/nsganetv2). That codebase is for NAS: there is a supernetwork with shared weights, and subnetworks of it are evaluated. Since the supernetwork is shared, subnetworks are defined dynamically, and the FLOPs count needs to be calculated for each subnetwork.

I ran into inconsistent FLOPs counts, and I think I found the root cause: even when a layer is not used in the subnetwork, it is still counted towards the FLOPs count. This is problematic in itself, but it gets worse. To explain with an example, suppose I have three subnetwork configs:

config = {'ks': [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], 'e': [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], 'd': [2, 2, 2, 2, 2], 'r': 192}

config_big_long = {"ks": [7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7], "e": [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6], "d": [4, 4, 4, 4, 4], "r": 192}

config_small_long = {"ks": [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], "e": [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], "d": [4, 4, 4, 4, 4], "r": 192}

I want to count FLOPs for the subnetwork defined by config, which is shallower than both config_big_long and config_small_long. As a result, the parameters of the layers not used by config depend on which of the other two configs was activated before.

At least that's what I observe:

> self.engine.set_active_subnet(ks=config_big_long['ks'], e=config_big_long['e'], d=config_big_long['d'])
> self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
> int(profile_macs(copy.deepcopy(self.engine.get_active_subnet(True)), inputs.cuda()))
130925344
> self.engine.set_active_subnet(ks=config_small_long['ks'], e=config_small_long['e'], d=config_small_long['d'])
> self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
> int(profile_macs(copy.deepcopy(self.engine.get_active_subnet(True)), inputs.cuda()))
89421520

This behaviour is strange, because from the code of your package it seems that dynamic tracing is used, so I'm not sure why this happens.

Handlers not found!

Hi, I was profiling the pre-trained generator and style_encoder networks from the recently released StarGAN-V2 (https://github.com/clovaai/stargan-v2). However, certain operations used in these networks are not supported by torchprofile. The missing handlers are:

  • "aten::leaky_relu"
  • "aten::to"
  • "aten::detach"
  • "aten::floor"
  • "aten::upsample_nearest2d"
  • "aten::index"
  • "aten::stack"

Could you suggest where I can add these operations in profile.py to get accurate profiling results?
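Warnings like "No handlers found: ... Skipped." come from a lookup that maps op names to cost functions and skips anything unregistered. The sketch below is a self-contained, hypothetical registry illustrating that dispatch-with-fallback pattern; it does not use torchprofile's internal API. The cost rules are assumptions: one op per output element for leaky_relu and floor, and zero MACs for pure data movement (to, detach, nearest-neighbour upsampling, index, stack):

```python
import warnings
from math import prod
from typing import Callable, Dict, Iterable, Sequence, Tuple

Shape = Sequence[int]
Handler = Callable[[Shape], int]


def zero_cost(output_shape: Shape) -> int:
    """Data movement or dtype casts: counted as 0 MACs (assumption)."""
    return 0


def per_element(output_shape: Shape) -> int:
    """One operation per output element (assumption)."""
    return prod(output_shape)


# Hypothetical registry; keys mirror the op names in the warnings above.
HANDLERS: Dict[str, Handler] = {
    "aten::leaky_relu": per_element,
    "aten::floor": per_element,
    "aten::to": zero_cost,
    "aten::detach": zero_cost,
    "aten::upsample_nearest2d": zero_cost,  # copies existing values
    "aten::index": zero_cost,
    "aten::stack": zero_cost,
}


def count_macs(ops: Iterable[Tuple[str, Shape]]) -> int:
    """Sum handler costs; warn and skip ops without a registered handler."""
    total = 0
    for op, output_shape in ops:
        handler = HANDLERS.get(op)
        if handler is None:
            warnings.warn(f'No handlers found: "{op}". Skipped.')
            continue
        total += handler(output_shape)
    return total


trace = [
    ("aten::leaky_relu", (1, 64, 32, 32)),
    ("aten::upsample_nearest2d", (1, 64, 64, 64)),
    ("aten::gelu", (1, 10)),  # unregistered: triggers the warning
]
print(count_macs(trace))  # 65536
```

Registering a new op then amounts to adding one entry to the table; whether a given op should cost zero or one per element is a modelling choice rather than something the trace can decide.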

The MACs result is not accurate

I tested a simple model as follows:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    from torchsummary import summary
    from torchprofile import profile_macs
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5, padding=2, bias=False)
            self.bn1 = nn.BatchNorm2d(10)
            
            self.conv2 = nn.Conv2d(10, 10, kernel_size=3, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(10)
            
        def forward(self, x):
            residual = x
            x = F.relu(self.bn1(self.conv1(x)))
            x = F.relu(self.bn2(self.conv2(x)))
            x += residual
            return x
    
    if __name__=="__main__":
        model = Net().cuda()
        # count parameters
        summary(model, (1, 28, 28)) 
        # count FLOPs
        inputs = torch.randn(1, 1, 28, 28).cuda()
        flops = profile_macs(model, inputs)
        print(f"FLOPs : {flops}") 

The results are:

      ----------------------------------------------------------------
              Layer (type)               Output Shape         Param #
      ================================================================
                  Conv2d-1           [-1, 10, 28, 28]             250
             BatchNorm2d-2           [-1, 10, 28, 28]              20
                  Conv2d-3           [-1, 10, 28, 28]             900
             BatchNorm2d-4           [-1, 10, 28, 28]              20
                     Net-5           [-1, 10, 28, 28]               0
      ================================================================
      Total params: 1,190
      Trainable params: 1,190
      Non-trainable params: 0
      ----------------------------------------------------------------
      Input size (MB): 0.00
      Forward/backward pass size (MB): 0.30
      Params size (MB): 0.00
      Estimated Total Size (MB): 0.31
      ----------------------------------------------------------------
      
      FLOPs : 901600

Clearly, the result only counts the MACs of the convolution layers (10×28×28×5×5×1 + 10×28×28×3×3×10 = 901,600) and ignores the MACs of the batch normalization layers and the shortcut addition.
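The convolution-only figure can be reproduced by hand. A minimal sketch (the helper `conv2d_macs` is hypothetical, not part of torchprofile) that recomputes it, then adds rough BatchNorm and residual-add costs under an assumed per-element convention. Other tools may count these differently, for example by folding BatchNorm into the preceding convolution at inference time:

```python
def conv2d_macs(c_in: int, c_out: int, k_h: int, k_w: int,
                h_out: int, w_out: int, groups: int = 1) -> int:
    """One multiply-accumulate per kernel weight per output element."""
    return c_out * h_out * w_out * (c_in // groups) * k_h * k_w


conv1 = conv2d_macs(c_in=1, c_out=10, k_h=5, k_w=5, h_out=28, w_out=28)
conv2 = conv2d_macs(c_in=10, c_out=10, k_h=3, k_w=3, h_out=28, w_out=28)
print(conv1, conv2, conv1 + conv2)  # 196000 705600 901600

# Assumption: one MAC per element for each BatchNorm (y = a * x + b)
# and one op per element for the residual addition. ReLU is not counted.
elements = 10 * 28 * 28  # 7840 elements per activation map
bn_macs = 2 * elements   # two BatchNorm2d layers
add_ops = elements       # x += residual
print(conv1 + conv2 + bn_macs + add_ops)  # 925120
```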

Negative MACs

I'm getting an "overflow encountered" error, and the profiler returns a negative number of MACs. Is this a torchprofile issue or an issue with my model?
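One plausible cause, assuming some step stores the count in a fixed-width signed 32-bit integer (for example a NumPy array or tensor with an int32 dtype; "overflow encountered" resembles the wording of NumPy's overflow RuntimeWarning), is two's-complement wrap-around: any total above 2**31 - 1 comes back negative. A small sketch of the effect, with made-up dimensions:

```python
from math import prod


def wrap_int32(value: int) -> int:
    """Emulate storing `value` in a signed 32-bit integer (two's complement)."""
    return (value + 2**31) % 2**32 - 2**31


dims = [50_000, 50_000]        # hypothetical factors of a large MAC count
true_count = prod(dims)        # Python ints are arbitrary precision
print(true_count)              # 2500000000
print(wrap_int32(true_count))  # -1794967296, what a 32-bit counter would hold
```

If that is the cause, converting each factor with `int(...)` before multiplying keeps the arithmetic in Python's unbounded integers.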

[Feature Question]FLOPs computation support?

I was testing different tools for computing FLOPs and MACs for CNN models and came across this repo. It seems like the FLOPs computation is not supported yet.

I was wondering whether it is in the future development plan, and whether I could help with it. Thanks!
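`profile_macs` reports multiply-accumulate operations. Under the common convention that one MAC is one multiplication plus one addition, a FLOP estimate is twice the MAC count. This is a sketch under that assumption only; it ignores operations that are not multiply-accumulates (activations, normalization, and anything skipped with "No handlers found"), and `macs_to_flops` is a hypothetical helper:

```python
def macs_to_flops(macs: int, flops_per_mac: int = 2) -> int:
    """Convert a MAC count to FLOPs, assuming 1 MAC = 1 multiply + 1 add."""
    if macs < 0:
        raise ValueError("MAC count must be non-negative")
    return macs * flops_per_mac


# e.g. with a count from torchprofile: macs = profile_macs(model, inputs)
print(macs_to_flops(901_600))  # 1803200
```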

No handlers found: "torchvision::roi_align". Skipped

Thank you for your great work!
I encountered these warnings when using torchprofile. It seems that torchprofile cannot calculate the MACs of the following operations:

UserWarning: No handlers found: "torchvision::roi_align". Skipped
UserWarning: No handlers found: "aten::sqrt". Skipped
UserWarning: No handlers found: "aten::log2". Skipped
UserWarning: No handlers found: "aten::reciprocal". Skipped.
UserWarning: No handlers found: "aten::scalarimplicit". Skipped.
UserWarning: No handlers found: "aten::max". Skipped.

By the way, apart from MACs, is there any way to calculate the model's memory usage?
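For the static part of the memory footprint, PyTorch modules expose their tensors through `parameters()` and `buffers()`, and each tensor reports `numel()` and `element_size()`. A minimal sketch (the helper name `parameter_memory_bytes` is hypothetical) that sums them; activation memory created during the forward pass is not included:

```python
import torch.nn as nn


def parameter_memory_bytes(model: nn.Module) -> int:
    """Bytes held by a model's parameters and buffers (e.g. BatchNorm stats).

    Static model size only; activations from the forward pass are excluded.
    """
    params = sum(p.numel() * p.element_size() for p in model.parameters())
    buffers = sum(b.numel() * b.element_size() for b in model.buffers())
    return params + buffers


model = nn.Linear(10, 5)  # 10*5 weights + 5 biases = 55 float32 values
print(parameter_memory_bytes(model))  # 220
```

For peak memory on a GPU, calling `torch.cuda.reset_peak_memory_stats()` before a forward pass and `torch.cuda.max_memory_allocated()` after it gives a measured figure that includes activations.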

profile_macs error with PyTorch 1.4

Hi, thank you for the awesome tool!
profile_macs does not work with PyTorch 1.4.
The key error message is:

AttributeError: module 'torch.jit' has no attribute 'get_trace_graph'

Thank you in advance!

Here is my entire error log:

Traceback (most recent call last):
  File "gen_summary.py", line 54, in <module>
    macs = profile_macs(model, inputs)
  File "/home/shkim/.conda/envs/torch1.4/lib/python3.6/site-packages/torchprofile/profile.py", line 12, in profile_macs
    graph = trace(model, args, kwargs)
  File "/home/shkim/.conda/envs/torch1.4/lib/python3.6/site-packages/torchprofile/utils/trace.py", line 17, in trace
    trace, _ = torch.jit.get_trace_graph(Flatten(model), args, kwargs)
AttributeError: module 'torch.jit' has no attribute 'get_trace_graph'

My environment is as follows:

# packages in environment at **:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_tflow_select             2.1.0                       gpu  
absl-py                   0.8.1                    py36_0  
astor                     0.8.0                    py36_0  
atomicwrites              1.3.0                    py36_1  
attrs                     19.3.0                     py_0  
blas                      1.0                         mkl  
brevitas                  0.2.0a0                  pypi_0    pypi
c-ares                    1.15.0            h7b6447c_1001  
ca-certificates           2019.11.27                    0  
certifi                   2019.11.28               py36_0  
cudatoolkit               10.1.243             h6bb024c_0  
cudnn                     7.6.5                cuda10.1_0  
cupti                     10.1.168                      0  
docrep                    0.2.7                    pypi_0    pypi
freetype                  2.8                  hab7d2ae_1  
gast                      0.3.2                      py_0  
google-pasta              0.1.8                      py_0  
grpcio                    1.14.1           py36h9ba97e2_0  
h5py                      2.9.0            py36h7918eee_0  
hdf5                      1.10.4               hb1b8bf9_0  
importlib_metadata        1.3.0                    py36_0  
intel-openmp              2019.4                      243  
jpeg                      9b                   h024ee3a_2  
keras-applications        1.0.8                      py_0  
keras-preprocessing       1.1.0                      py_1  
libgcc-ng                 9.1.0                hdf63c60_0  
libgfortran-ng            7.3.0                hdf63c60_0  
libpng                    1.6.37               hbc83047_0  
libprotobuf               3.11.2               hd408876_0  
libstdcxx-ng              9.1.0                hdf63c60_0  
libtiff                   4.1.0                h2733197_0  
markdown                  3.1.1                    py36_0  
mkl                       2019.4                      243  
mkl-service               2.3.0            py36he904b0f_0  
mkl_fft                   1.0.15           py36ha843d7b_0  
mkl_random                1.1.0            py36hd6b4f25_0  
more-itertools            8.0.2                      py_0  
natsort                   7.0.0                    pypi_0    pypi
ninja                     1.9.0            py36hfd86e86_0  
numpy                     1.18.1           py36h4f9e942_0  
numpy-base                1.18.1           py36hde5b4d6_0  
olefile                   0.46                       py_0  
openssl                   1.0.2u               h7b6447c_0  
packaging                 20.0                       py_0  
pillow                    4.2.1            py36h9119f52_0  
pip                       19.3.1                   py36_0  
pluggy                    0.13.1                   py36_0  
protobuf                  3.11.2           py36he6710b0_0  
py                        1.8.1                      py_0  
pyparsing                 2.4.6                      py_0  
pytest                    5.0.1                    py36_0  
python                    3.6.0                         0  
pytorch                   1.4.0           py3.6_cuda10.1.243_cudnn7.6.3_0    pytorch
readline                  6.2                           2  
scipy                     1.3.2            py36h7c811a0_0  
setuptools                44.0.0                   py36_0  
six                       1.13.0                   py36_0  
sqlite                    3.13.0                        0  
tensorboard               1.14.0           py36hf484d3e_0  
tensorboardx              2.0                      pypi_0    pypi
tensorflow                1.14.0          gpu_py36h3fb9ad6_0  
tensorflow-base           1.14.0          gpu_py36he45bfe2_0  
tensorflow-estimator      1.14.0                     py_0  
tensorflow-gpu            1.14.0               h0d30ee6_0  
termcolor                 1.1.0                    py36_1  
tk                        8.5.18                        0  
torchprofile              0.0.1                    pypi_0    pypi
torchsummary              1.5.1                    pypi_0    pypi
torchvision               0.5.0                py36_cu101    pytorch
tqdm                      4.41.1                   pypi_0    pypi
wcwidth                   0.1.7                    py36_0  
werkzeug                  0.16.0                     py_0  
wheel                     0.33.6                   py36_0  
wrapt                     1.11.2           py36h7b6447c_0  
xz                        5.2.4                h14c3975_4  
zipp                      0.6.0                      py_0  
zlib                      1.2.11               h7b6447c_3  
zstd                      1.3.7                h0b5b093_0  
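The traceback shows torchprofile 0.0.1 calling `torch.jit.get_trace_graph`, which PyTorch 1.4 no longer provides under that public name (to my knowledge it became the private `torch.jit._get_trace_graph`). Upgrading torchprofile is probably the simplest fix. If patching locally, a lookup like the hypothetical helper below picks whichever name exists; it does not adapt to any change in the function's return value between versions, so check that separately:

```python
import torch


def resolve_trace_graph_fn():
    """Return whichever trace-graph function this PyTorch build exposes.

    Assumption: older releases expose torch.jit.get_trace_graph, while
    PyTorch 1.4 and later expose the private torch.jit._get_trace_graph.
    """
    for name in ("get_trace_graph", "_get_trace_graph"):
        fn = getattr(torch.jit, name, None)
        if fn is not None:
            return fn
    raise AttributeError(
        "torch.jit provides neither get_trace_graph nor _get_trace_graph"
    )


print(torch.__version__, resolve_trace_graph_fn().__name__)
```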

No handlers found!!

Thanks for your great work, but I found some problems:

No handlers found: "aten::group_norm". Skipped.
No handlers found: "aten::floor_divide". Skipped.
No handlers found: "aten::upsample_bilinear2d". Skipped.

No handlers found

UserWarning: No handlers found: "aten::gelu". Skipped.
UserWarning: No handlers found: "aten::reshape". Skipped.
UserWarning: No handlers found: "aten::permute". Skipped
