Comments (4)
Thanks for your work!
I see that you implemented LRP for the VGG model. However, VGG is a simple model built from a single Sequential and has no residual connections. Could you help me implement LRP for a more complex model, such as ResNet?
Thank you so much!
I implemented the ResNet conversion as follows:
import torch
import torchvision
from lrp.conv import Conv2d
from lrp.linear import Linear
from lrp.sequential import Sequential, Bottleneck
# Lookup from a torch layer's class name to the LRP-aware class that
# replaces it during conversion (used by convert_resnet).
conversion_table = dict(
    Linear=Linear,
    Conv2d=Conv2d,
)
# # # # # Convert torch.models.resnetxx to lrp model
def convert_resnet(module, modules=None):
    """Convert a torchvision ResNet into an LRP-aware ``Sequential``.

    The user calls this with only ``module``; the function then recurses over
    the model's children, appending converted layers to ``modules``.

    Conversion rules:
        * ``torch.nn.Sequential`` containers are flattened into the result.
        * ``Linear`` / ``Conv2d`` are replaced by their LRP counterparts via
          ``conversion_table``.
        * ``ReLU`` is replaced by a fresh, non-inplace ``ReLU``.
        * torchvision ``Bottleneck`` blocks become LRP ``Bottleneck`` blocks.
        * Anything else (BatchNorm, pooling, ...) is appended unchanged.
          NOTE(review): torchvision ``BasicBlock`` (resnet18/34) also falls into
          this case, so its convolutions are not converted to LRP layers.

    Args:
        module: On the top-level call, the torchvision ResNet to convert. On
            recursive calls, the child module currently being converted.
        modules: Accumulator used by the recursion; leave as ``None``.

    Returns:
        The LRP ``Sequential`` on the top-level call; ``None`` on recursive
        calls, which only append to ``modules``.

    The source model is not modified: converted blocks get new containers
    instead of having their layers replaced in place.
    """
    # Top-level call: walk the model's direct children and wrap the result.
    if modules is None:
        modules = []
        for child in module.children():
            convert_resnet(child, modules=modules)
            # ResNet.forward flattens after the average pool with a functional
            # call, which is not a child module, so insert an explicit Flatten.
            if isinstance(child, torch.nn.AdaptiveAvgPool2d):
                modules.append(torch.nn.Flatten())
        return Sequential(*modules)

    # Recursive call: append the converted form of `module` to `modules`.
    if isinstance(module, torch.nn.Sequential):
        for child in module.children():
            convert_resnet(child, modules=modules)
    elif isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
        lrp_cls = conversion_table[module.__class__.__name__]
        modules.append(lrp_cls.from_torch(module))
    # maxpool is handled with gradient for the moment
    elif isinstance(module, torch.nn.ReLU):
        # Avoid inplace operations; they might ruin PatternNet pattern
        # computations.
        modules.append(torch.nn.ReLU())
    elif isinstance(module, torchvision.models.resnet.Bottleneck):
        block = Bottleneck()
        block.conv1 = Conv2d.from_torch(module.conv1)
        block.conv2 = Conv2d.from_torch(module.conv2)
        block.conv3 = Conv2d.from_torch(module.conv3)
        block.bn1 = module.bn1
        block.bn2 = module.bn2
        block.bn3 = module.bn3
        block.relu = torch.nn.ReLU()
        if module.downsample is not None:
            # Build a NEW container. Assigning into module.downsample (as was
            # done before) replaced the conv inside the caller's original model.
            block.downsample = torch.nn.Sequential(
                Conv2d.from_torch(module.downsample[0]),
                *module.downsample[1:],
            )
        modules.append(block)
    else:
        modules.append(module)
and edited sequential.py as follows:
import torch
from . import Linear, Conv2d
from .maxpool import MaxPool2d
from .functional.utils import normalize
def grad_decorator_fn(module):
    """Build a tensor gradient hook that passes each gradient through ``normalize``.

    ``Sequential.forward`` registers the returned hook on layer inputs and
    outputs whenever ``do_normalization`` returns True. ``module`` is accepted
    but not used by the hook.
    """
    def _normalize_grad(grad):
        return normalize(grad)

    return _normalize_grad
# Lower-cased 4-character prefixes of str(module) for which gradient
# normalization is skipped (e.g. "ReLU()" -> "relu", "MaxPool2d(...)" -> "maxp").
avoid_normalization_on = ['relu', 'maxp']


def do_normalization(rule, module):
    """Decide whether gradients around ``module`` are normalized under ``rule``.

    Only pattern-based rules (``"pattern"`` occurs in ``rule``,
    case-insensitively) normalize. Even then, modules whose ``str()`` starts
    with one of ``avoid_normalization_on`` are excluded.
    """
    if "pattern" in rule.lower():
        prefix = str(module)[:4].lower()
        return prefix not in avoid_normalization_on
    return False
def is_kernel_layer(module):
    """Tell whether ``Sequential.forward`` should call ``module`` in explain mode.

    These are the LRP ``Conv2d`` and ``Linear`` layers and the residual
    ``Bottleneck`` block. They receive ``explain=True``, the rule and (if
    given) the next pattern.
    """
    kernel_types = (Conv2d, Linear, Bottleneck)
    return isinstance(module, kernel_types)
def is_rule_specific_layer(module):
    """Tell whether ``module`` takes an LRP rule but no pattern (LRP ``MaxPool2d``)."""
    rule_only_types = (MaxPool2d,)
    return isinstance(module, rule_only_types)
class Sequential(torch.nn.Sequential):
    """``torch.nn.Sequential`` that can additionally run an LRP explain pass.

    With ``explain=False`` it behaves exactly like ``torch.nn.Sequential``.
    With ``explain=True`` each child is run according to its kind:

    * kernel layers (``is_kernel_layer``) get ``explain=True``, the rule and,
      if patterns were supplied, the next entry of ``pattern``;
    * rule-specific layers (``is_rule_specific_layer``) get the rule only;
    * every other layer is called normally, so its plain gradient is used.
    """

    def forward(self, input, explain=False, rule="epsilon", pattern=None):
        """Run the container, optionally in LRP explain mode.

        Args:
            input: Input tensor. For pattern rules, gradient hooks are
                registered on it; presumably it must require grad (TODO confirm).
            explain: If False, defer to ``torch.nn.Sequential.forward``.
            rule: Name of the LRP rule handed to the layers.
            pattern: Optional sequence of patterns consumed in order, one entry
                per kernel layer. A ``Bottleneck`` consumes one entry, which it
                indexes as a list of per-conv patterns. The caller's sequence is
                copied, not consumed.

        Returns:
            The output of the last layer.
        """
        if not explain:
            return super(Sequential, self).forward(input)

        # Copy references so the caller can reuse the same patterns later.
        remaining = list(pattern) if pattern is not None else None

        for module in self:
            if do_normalization(rule, module):
                input.register_hook(grad_decorator_fn(module))

            if is_kernel_layer(module):
                P = remaining.pop(0) if remaining is not None else None
                input = module.forward(input, explain=True, rule=rule, pattern=P)
            elif is_rule_specific_layer(module):
                input = module.forward(input, explain=True, rule=rule)
            else:
                # Use the gradient as default for the remaining layer types.
                input = module(input)

            if do_normalization(rule, module):
                input.register_hook(grad_decorator_fn(module))
        return input
class Bottleneck(torch.nn.Module):
    """LRP-aware counterpart of torchvision's ResNet ``Bottleneck`` block.

    Layers are not created here; they are attached after construction (see
    ``convert_resnet``):

    * ``conv1``, ``conv2``, ``conv3``: LRP convolutions, called with
      ``explain`` / ``rule`` / ``pattern`` keyword arguments;
    * ``bn1``, ``bn2``, ``bn3``: normalization layers;
    * ``relu``: activation (expected to be non-inplace);
    * ``downsample``: ``None``, or an indexable ``(lrp_conv, norm)`` pair that
      projects the shortcut.
    """

    def __init__(self):
        super(Bottleneck, self).__init__()
        self.downsample = None

    def forward(self, x, explain=True, rule="epsilon", pattern=None):
        """Run the block: conv/bn/relu x3 on the main branch plus the shortcut.

        Args:
            x: Input tensor.
            explain: Forwarded to every LRP convolution.
            rule: LRP rule forwarded to every LRP convolution.
            pattern: ``None``, or per-conv patterns indexed as
                ``[conv1, conv2, conv3, downsample_conv]``. The last entry is
                only read when ``downsample`` is set.

        Returns:
            ``relu(main_branch(x) + shortcut(x))``.

        NOTE(review): ``explain`` defaults to True, unlike ``Sequential.forward``,
        so a plain ``block(x)`` call still routes the convolutions through their
        explain path. Confirm this is intended before relying on it.
        """
        # Without patterns, every conv gets pattern=None (its no-pattern mode).
        patterns = [None] * 4 if pattern is None else pattern

        identity = x

        out = self.conv1(x, explain=explain, rule=rule, pattern=patterns[0])
        out = self.relu(self.bn1(out))
        out = self.conv2(out, explain=explain, rule=rule, pattern=patterns[1])
        out = self.relu(self.bn2(out))
        out = self.conv3(out, explain=explain, rule=rule, pattern=patterns[2])
        out = self.bn3(out)

        if self.downsample is not None:
            # Keyword arguments: do not depend on the conv's positional order.
            identity = self.downsample[0](
                x, explain=explain, rule=rule, pattern=patterns[3]
            )
            identity = self.downsample[1](identity)

        # Residual sum, out-of-place to avoid in-place ops in the explain pass.
        # It is handled by ordinary autograd; no dedicated LRP rule exists yet.
        return self.relu(out + identity)
For PatternNet, the _fit_pattern function in patterns.py also needs to be modified:
def _fit_pattern(model, train_loader, max_iter, device, mask_fn=lambda y: torch.ones_like(y)):
    """Estimate PatternNet patterns for every kernel layer of ``model``.

    Batches from ``train_loader`` are pushed through ``model`` layer by layer.
    For each ``Linear`` / ``Conv2d`` layer, and for each convolution inside an
    LRP ``Bottleneck``, running means of the ``_prod`` statistics (x, y, xy)
    are accumulated. Patterns are then computed from those means.

    NOTE(review): relies on module-level names not shown in this function:
    ``tqdm``, ``RunningMean``, ``_prod``, ``safe_divide`` and the LRP
    ``Bottleneck`` class (which must be imported into patterns.py). Confirm.

    Args:
        model: Iterable of layers, e.g. the LRP ``Sequential`` built by
            ``convert_resnet``.
        train_loader: Iterable of ``(x, _)`` batches; the second item is ignored.
        max_iter: Maximum number of batches to use, or ``None`` for all.
        device: Device for inputs, masks and running means.
        mask_fn: Maps a layer output to a mask that is passed to ``_prod``.
            Presumably it selects which output positions count; see ``_prod``.
            The default keeps all positions.

    Returns:
        One entry per kernel layer, in forward order. A plain ``Linear`` /
        ``Conv2d`` layer gets a single pattern. A ``Bottleneck`` gets the list
        ``[conv1, conv2, conv3]``, plus the downsample conv's pattern when the
        block has a downsample.
    """
    stats_x, stats_y, stats_xy, weights = [], [], [], []
    first = True  # running means are created lazily during the first batch

    def record(sx, sy, sxy, ws, k, prod_out):
        """Add one layer's ``_prod`` output to slot ``k`` of the given stats lists."""
        cnt_, cnt_all_, x_, y_, xy_, w, w_fn = prod_out
        if first:
            sx.append(RunningMean(x_.shape, device))
            sy.append(RunningMean(y_.shape, device))
            sxy.append(RunningMean(xy_.shape, device))
            ws.append((w, w_fn))
        sx[k].update(x_, cnt_)
        sy[k].update(y_.sum(0), cnt_all_)  # Use all y
        sxy[k].update(xy_, cnt_)

    def conv_prod(conv, x_in, y_out):
        """``_prod`` of a bottleneck conv, pairing input ``x_in`` with output ``y_out``."""
        mask = mask_fn(y_out).float().to(device)
        if conv.bias is not None:
            y_wo_bias = y_out - conv.bias.view(-1, 1, 1)
        else:
            y_wo_bias = y_out.clone()
        return _prod(conv, x_in, y_wo_bias, mask)

    for b, (x, _) in enumerate(tqdm(train_loader)):
        x = x.to(device)
        i = 0  # index of the current kernel layer / bottleneck
        for m in model:
            if isinstance(m, Bottleneck):
                if first:
                    for stats in (stats_x, stats_y, stats_xy, weights):
                        stats.append([])
                slots = (stats_x[i], stats_y[i], stats_xy[i], weights[i])

                y = m.conv1(x)
                record(*slots, 0, conv_prod(m.conv1, x, y))
                x1 = m.relu(m.bn1(y.clone()))

                y = m.conv2(x1)
                record(*slots, 1, conv_prod(m.conv2, x1, y))
                x2 = m.relu(m.bn2(y.clone()))

                y = m.conv3(x2)
                record(*slots, 2, conv_prod(m.conv3, x2, y))
                y = m.bn3(y)

                # Shortcut branch: without a downsample it is x itself. It was
                # previously never initialised, so such blocks got no residual
                # or a stale one from an earlier block.
                identity = x
                if m.downsample is not None:
                    identity = m.downsample[0](x)
                    # Pair the downsample conv's input with ITS OWN output.
                    # Using the main-branch output y here was a bug.
                    record(*slots, 3, conv_prod(m.downsample[0], x, identity))
                    identity = m.downsample[1](identity)

                x = m.relu(y + identity)
                i += 1
                continue

            y = m(x)  # Note, this includes bias.
            if not isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
                x = y.clone()
                continue
            mask = mask_fn(y).float().to(device)
            if m.bias is None:
                y_wo_bias = y
            elif isinstance(m, torch.nn.Conv2d):
                y_wo_bias = y - m.bias.view(-1, 1, 1)
            else:
                y_wo_bias = y - m.bias.clone()
            record(stats_x, stats_y, stats_xy, weights, i, _prod(m, x, y_wo_bias, mask))
            x = y.clone()
            i += 1
        first = False
        if max_iter is not None and b + 1 == max_iter:
            break

    def pattern(x_mean, y_mean, xy_mean, W2d):
        """Compute one layer's pattern from its running means.

        ``cov_xy = E[xy] - E[x]E[y]`` is divided column-wise by
        ``diag(W @ cov_xy)``. ``w_fn`` then reshapes the result back to the
        layer's kernel shape.
        """
        x_ = x_mean.value
        y_ = y_mean.value
        xy_ = xy_mean.value
        W, w_fn = W2d
        ExEy = x_ * y_
        cov_xy = xy_ - ExEy  # [in, out]
        w_cov_xy = torch.diag(W @ cov_xy)  # [out,]
        A = safe_divide(cov_xy, w_cov_xy[None, :])
        A = w_fn(A)  # Reshape to original kernel size
        return A

    patterns = []
    for layer_stats in zip(stats_x, stats_y, stats_xy, weights):
        if isinstance(layer_stats[0], RunningMean):
            patterns.append(pattern(*layer_stats))
        else:
            # Bottleneck: one pattern per convolution inside the block.
            patterns.append([pattern(*conv_stats) for conv_stats in zip(*layer_stats)])
    return patterns
An LRP rule for the residual addition (the `+` inside each block) has not been implemented yet; this will probably need to be added.
from torchlrp.
Thank you for your posting @sdw95927 .
I used these lines of code for generating ResNet heatmaps.
The problem is the heatmaps for ResNet are not so meaningful.
I am attaching the heatmaps generated for pretrained VGG and ResNet below. Any ideas?
from torchlrp.
I can see the patterns from ResNet too, just not as clearly as with VGG. I think this is mainly due to ResNet's more complex structure, such as its residual connections, whereas VGG is simple and straightforward.
from torchlrp.
@miladsikaroudi Can you please share what you did to make this work with ResNet?
from torchlrp.
Related Issues (7)
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from torchlrp.