zhijian-liu / torchprofile
A general and accurate MACs / FLOPs profiler for PyTorch models
Home Page: https://pypi.org/project/torchprofile/
License: MIT License
The nn.Linear layer is used throughout Transformer-like models.
Should I add this part when I use torchprofile to compute FLOPs for a transformer model?
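For reference, a common convention counts one multiply-accumulate per weight for every position a linear layer is applied to. The sketch below follows that convention; linear_macs is a hypothetical helper written for illustration, not torchprofile's own handler, and the shapes in the usage lines are made-up examples.

```python
def linear_macs(batch, tokens, in_features, out_features):
    """MACs of a linear layer applied to a (batch, tokens, in_features) input.

    Assumes the common convention of one MAC per weight per output position;
    the bias addition is not counted.
    """
    positions = batch * tokens
    return positions * in_features * out_features


# Hypothetical shapes, chosen only to make the arithmetic easy to follow.
small = linear_macs(batch=2, tokens=4, in_features=8, out_features=16)
print(small)  # 2 * 4 * 8 * 16 = 1024
```

Comparing a hand count like this with the number the profiler reports for a single layer is one way to check whether that layer is already included.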
No handlers found: "aten::unsqueeze"
No handlers found: "aten::pow"
No handlers found: "aten::sign"
No handlers found: "aten::abs"
No handlers found: "aten::fft_fft"
No handlers found: "aten::real"
No handlers found: "aten::imag"
No handlers found: "aten::unfold"
No handlers found: "aten::split"
No handlers found: "aten::einsum"
No handlers found: "aten::permute"
I hope these ops can be supported, thanks!
Thank you for your great work!
I encountered this warning when using torchprofile. It seems that torchprofile cannot calculate the GFLOPs of torch.nn.functional.grid_sample:
UserWarning: No handlers found: "aten::grid_sampler". Skipped.
Hello,
I'm working with the NsgaNetV2 code (https://github.com/mikelzc1990/nsganetv2). That codebase is for NAS: there is a supernetwork with shared weights, and subnetworks of it are evaluated. Since the supernetwork is shared, subnetworks are defined dynamically, and the FLOPs count needs to be calculated for each subnetwork.
I ran into inconsistent FLOPs counts, and I think I found the root cause: even when a layer is not used in the subnetwork, it is still counted towards the FLOPs total. That alone is a problem, but it gets worse. To explain with an example, suppose I have 3 subnetwork configs:
config = {'ks': [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], 'e': [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], 'd': [2, 2, 2, 2, 2], 'r': 192}
config_big_long = {"ks": [7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7], "e": [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6], "d": [4, 4, 4, 4, 4], "r": 192}
config_small_long = {"ks": [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], "e": [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], "d": [4, 4, 4, 4, 4], "r": 192}
I want to count FLOPs for the subnetwork defined by config. That subnetwork is shallower than both config_big_long and config_small_long. This means that the parameters of the layers unused by config will depend on which of the other two configs was activated before.
At least that's what I observe:
> self.engine.set_active_subnet(ks=config_big_long['ks'], e=config_big_long['e'], d=config_big_long['d'])
> self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
> int(profile_macs(copy.deepcopy(self.engine.get_active_subnet(True)), inputs.cuda()))
130925344
> self.engine.set_active_subnet(ks=config_small_long['ks'], e=config_small_long['e'], d=config_small_long['d'])
> self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
> int(profile_macs(copy.deepcopy(self.engine.get_active_subnet(True)), inputs.cuda()))
89421520
This behaviour is strange, because from the code of your package it seems that dynamic tracing is used, so I'm not sure why this happens.
Hi, I was profiling the pre-trained generator and style_encoder networks of the recently released StarGAN-V2 ( https://github.com/clovaai/stargan-v2 ). However, certain operations used in these networks are missing from torchprofile. The following handlers are missing:
Could you suggest where I can add these operations in profile.py to get accurate profiling?
I hope the "prim::pythonop" and "prim::tupleunpack" operations can be supported.
Thanks and great tool!
I tested a simple model as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torchprofile import profile_macs


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 10, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(10)

    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x += residual
        return x


if __name__ == "__main__":
    model = Net().cuda()
    # count parameters
    summary(model, (1, 28, 28))
    # count FLOPs
    inputs = torch.randn(1, 1, 28, 28).cuda()
    flops = profile_macs(model, inputs)
    print(f"FLOPs : {flops}")
And the results are:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 10, 28, 28]             250
       BatchNorm2d-2           [-1, 10, 28, 28]              20
            Conv2d-3           [-1, 10, 28, 28]             900
       BatchNorm2d-4           [-1, 10, 28, 28]              20
               Net-5           [-1, 10, 28, 28]               0
================================================================
Total params: 1,190
Trainable params: 1,190
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.30
Params size (MB): 0.00
Estimated Total Size (MB): 0.31
----------------------------------------------------------------
FLOPs : 901600
Evidently, the result counts only the MACs of the convolution layers, 10x28x28x5x5x1 + 10x28x28x3x3x10 = 901600, and ignores the MACs of the batch normalization layers and the shortcut addition.
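The arithmetic behind that observation can be checked by hand. The sketch below recomputes the two convolution terms from the layer shapes in the model above; conv2d_macs is a hypothetical helper written for this check, not a torchprofile function, and it assumes one MAC per kernel weight per output element with no bias.

```python
def conv2d_macs(out_channels, out_h, out_w, kernel_h, kernel_w, in_channels):
    """MACs of a dense (groups=1) 2D convolution without bias, batch size 1."""
    return out_channels * out_h * out_w * kernel_h * kernel_w * in_channels


# conv1: 1 -> 10 channels, 5x5 kernel, 28x28 output (padding=2 keeps the size)
conv1 = conv2d_macs(10, 28, 28, 5, 5, 1)
# conv2: 10 -> 10 channels, 3x3 kernel, 28x28 output (padding=1 keeps the size)
conv2 = conv2d_macs(10, 28, 28, 3, 3, 10)
total = conv1 + conv2

# Each BatchNorm2d output and the shortcut addition touch this many elements;
# whether they should count as MACs depends on the counting convention.
elementwise = 10 * 28 * 28

print(conv1, conv2, total)  # 196000 705600 901600
print(elementwise)          # 7840
```

The convolution total matches the reported 901600 exactly, which is consistent with the reported number containing nothing beyond the two convolutions.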
Currently there are no handlers for aten::rsub and aten::squeeze, would be nice to have them!
I'm getting an "overflow encountered" error, and the profiler gives me a negative number of MACs. Is this a torchprofile issue or an issue with my model?
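One possible explanation, not confirmed anywhere in this thread, is that the count is accumulated in a fixed-width integer somewhere along the way. The sketch below only illustrates how a value that exceeds the signed 32-bit range wraps around to a negative number; wrap_int32 is a hypothetical helper, the MAC count is a made-up figure, and nothing here is taken from torchprofile's code.

```python
import ctypes


def wrap_int32(value):
    """Return value as it would appear after being stored in a signed 32-bit integer."""
    return ctypes.c_int32(value).value


macs = 3_000_000_000        # made-up count, larger than 2**31 - 1
wrapped = wrap_int32(macs)

print(macs > 2**31 - 1)     # True
print(wrapped)              # -1294967296
```

If the model is large enough for the true count to exceed about 2.1 billion, checking the integer types involved in the sum would be a reasonable first step.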
I was testing different tools for computing FLOPs and MACs for CNN models and came across this repo. It seems that FLOPs computation is not supported yet.
I was wondering whether it's in the future dev plan and whether I could help with it. Thanks!
hope to support "aten::zeros" and "aten::lstm" ops, thanks!
Thank you for your great work!
I encountered these warnings when using torchprofile. It seems that torchprofile cannot calculate the MACs of the following ops:
UserWarning: No handlers found: "torchvision::roi_align". Skipped
UserWarning: No handlers found: "aten::sqrt". Skipped
UserWarning: No handlers found: "aten::log2". Skipped
UserWarning: No handlers found: "aten::reciprocal". Skipped.
UserWarning: No handlers found: "aten::scalarimplicit". Skipped.
UserWarning: No handlers found: "aten::max". Skipped.
By the way, apart from the MACs, is there any way to calculate the model's memory use?
Hi, Thank you for the awesome tool!
profile_macs is not working with PyTorch 1.4.
The key error message is:
AttributeError: module 'torch.jit' has no attribute 'get_trace_graph'
Thank you in advance!
Here is my entire error log:
Traceback (most recent call last):
File "gen_summary.py", line 54, in <module>
macs = profile_macs(model, inputs)
File "/home/shkim/.conda/envs/torch1.4/lib/python3.6/site-packages/torchprofile/profile.py", line 12, in profile_macs
graph = trace(model, args, kwargs)
File "/home/shkim/.conda/envs/torch1.4/lib/python3.6/site-packages/torchprofile/utils/trace.py", line 17, in trace
trace, _ = torch.jit.get_trace_graph(Flatten(model), args, kwargs)
AttributeError: module 'torch.jit' has no attribute 'get_trace_graph'
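The traceback shows trace.py calling torch.jit.get_trace_graph, which this PyTorch 1.4 install does not expose. Below is a minimal sketch of a version-tolerant lookup. It assumes that newer PyTorch releases provide the same functionality under the private name torch.jit._get_trace_graph, which should be verified against the installed version since private APIs can change without notice. resolve_trace_fn is a hypothetical helper, not part of torchprofile, and the demo uses stand-in objects instead of a real torch.jit so that it runs without PyTorch.

```python
from types import SimpleNamespace


def resolve_trace_fn(jit_module):
    """Return the first available graph-tracing function on jit_module.

    Tries the public name used by older releases first, then the private
    name that is assumed to replace it in newer releases.
    """
    for name in ("get_trace_graph", "_get_trace_graph"):
        fn = getattr(jit_module, name, None)
        if fn is not None:
            return fn
    raise AttributeError("no trace-graph function found on the given module")


# Stand-ins for torch.jit in an older and a newer PyTorch (assumption-labelled).
old_jit = SimpleNamespace(get_trace_graph=lambda *args: "traced with public name")
new_jit = SimpleNamespace(_get_trace_graph=lambda *args: "traced with private name")

print(resolve_trace_fn(old_jit)())
print(resolve_trace_fn(new_jit)())
```

With PyTorch installed, the call would be resolve_trace_fn(torch.jit). Upgrading torchprofile past the 0.0.1 release listed in the environment below is worth trying before patching anything locally.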
My environment is as follows:
# packages in environment at **:
#
# Name Version Build Channel
_libgcc_mutex 0.1 main
_tflow_select 2.1.0 gpu
absl-py 0.8.1 py36_0
astor 0.8.0 py36_0
atomicwrites 1.3.0 py36_1
attrs 19.3.0 py_0
blas 1.0 mkl
brevitas 0.2.0a0 pypi_0 pypi
c-ares 1.15.0 h7b6447c_1001
ca-certificates 2019.11.27 0
certifi 2019.11.28 py36_0
cudatoolkit 10.1.243 h6bb024c_0
cudnn 7.6.5 cuda10.1_0
cupti 10.1.168 0
docrep 0.2.7 pypi_0 pypi
freetype 2.8 hab7d2ae_1
gast 0.3.2 py_0
google-pasta 0.1.8 py_0
grpcio 1.14.1 py36h9ba97e2_0
h5py 2.9.0 py36h7918eee_0
hdf5 1.10.4 hb1b8bf9_0
importlib_metadata 1.3.0 py36_0
intel-openmp 2019.4 243
jpeg 9b h024ee3a_2
keras-applications 1.0.8 py_0
keras-preprocessing 1.1.0 py_1
libgcc-ng 9.1.0 hdf63c60_0
libgfortran-ng 7.3.0 hdf63c60_0
libpng 1.6.37 hbc83047_0
libprotobuf 3.11.2 hd408876_0
libstdcxx-ng 9.1.0 hdf63c60_0
libtiff 4.1.0 h2733197_0
markdown 3.1.1 py36_0
mkl 2019.4 243
mkl-service 2.3.0 py36he904b0f_0
mkl_fft 1.0.15 py36ha843d7b_0
mkl_random 1.1.0 py36hd6b4f25_0
more-itertools 8.0.2 py_0
natsort 7.0.0 pypi_0 pypi
ninja 1.9.0 py36hfd86e86_0
numpy 1.18.1 py36h4f9e942_0
numpy-base 1.18.1 py36hde5b4d6_0
olefile 0.46 py_0
openssl 1.0.2u h7b6447c_0
packaging 20.0 py_0
pillow 4.2.1 py36h9119f52_0
pip 19.3.1 py36_0
pluggy 0.13.1 py36_0
protobuf 3.11.2 py36he6710b0_0
py 1.8.1 py_0
pyparsing 2.4.6 py_0
pytest 5.0.1 py36_0
python 3.6.0 0
pytorch 1.4.0 py3.6_cuda10.1.243_cudnn7.6.3_0 pytorch
readline 6.2 2
scipy 1.3.2 py36h7c811a0_0
setuptools 44.0.0 py36_0
six 1.13.0 py36_0
sqlite 3.13.0 0
tensorboard 1.14.0 py36hf484d3e_0
tensorboardx 2.0 pypi_0 pypi
tensorflow 1.14.0 gpu_py36h3fb9ad6_0
tensorflow-base 1.14.0 gpu_py36he45bfe2_0
tensorflow-estimator 1.14.0 py_0
tensorflow-gpu 1.14.0 h0d30ee6_0
termcolor 1.1.0 py36_1
tk 8.5.18 0
torchprofile 0.0.1 pypi_0 pypi
torchsummary 1.5.1 pypi_0 pypi
torchvision 0.5.0 py36_cu101 pytorch
tqdm 4.41.1 pypi_0 pypi
wcwidth 0.1.7 py36_0
werkzeug 0.16.0 py_0
wheel 0.33.6 py36_0
wrapt 1.11.2 py36h7b6447c_0
xz 5.2.4 h14c3975_4
zipp 0.6.0 py_0
zlib 1.2.11 h7b6447c_3
zstd 1.3.7 h0b5b093_0
Thanks for your good work, but I found some problems:
No handlers found: "aten::group_norm". Skipped.
No handlers found: "aten::floor_divide". Skipped.
No handlers found: "aten::upsample_bilinear2d". Skipped.
As in macs = profile_macs(model, inputs): is there a way to also obtain the model's output given inputs, without running the forward pass again?
def profile_macs(model, args=(), kwargs=None, reduction=sum):
    results = dict()
    graph = trace(model, args, kwargs)
UserWarning: No handlers found: "aten::gelu". Skipped.
UserWarning: No handlers found: "aten::reshape". Skipped.
UserWarning: No handlers found: "aten::permute". Skipped