Comments (4)

sdw95927 commented on May 20, 2024

Thanks for your work!
I see that you implemented LRP for the VGG model. However, VGG is a simple model that consists of a single Sequential and has no residual connections. Could you help me implement LRP for a more complex model, such as ResNet?
Thank you so much!

I implemented the ResNet conversion as follows:

```python
import torch
import torchvision
from lrp.conv       import Conv2d
from lrp.linear     import Linear
from lrp.sequential import Sequential, Bottleneck

conversion_table = {
    'Linear': Linear,
    'Conv2d': Conv2d,
}

# # # # # Convert torchvision.models.resnetXX to an lrp model
def convert_resnet(module, modules=None):
    # First (top-level) call: flatten the model into a list and wrap it
    if modules is None:
        modules = []
        for m in module.children():
            convert_resnet(m, modules=modules)

            # ResNet's forward() calls torch.flatten() after the average pool,
            # which is not a child module, so this loop doesn't pick it up.
            # This is a hack to make things work.
            if isinstance(m, torch.nn.AdaptiveAvgPool2d):
                modules.append(torch.nn.Flatten())

        return Sequential(*modules)

    # Recursion
    if isinstance(module, torch.nn.Sequential):
        for m in module.children():
            convert_resnet(m, modules=modules)

    elif isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
        class_name = module.__class__.__name__
        lrp_module = conversion_table[class_name].from_torch(module)
        modules.append(lrp_module)

    elif isinstance(module, torch.nn.ReLU):
        # Avoid in-place operations. They might ruin PatternNet pattern
        # computations.
        modules.append(torch.nn.ReLU())

    elif isinstance(module, torchvision.models.resnet.Bottleneck):
        # torchvision Bottleneck (ResNet-50/101/152). BasicBlock (ResNet-18/34)
        # is not converted; it falls through to the else branch (gradient only).
        bottleneck = Bottleneck()
        bottleneck.conv1 = Conv2d.from_torch(module.conv1)
        bottleneck.conv2 = Conv2d.from_torch(module.conv2)
        bottleneck.conv3 = Conv2d.from_torch(module.conv3)
        bottleneck.bn1 = module.bn1
        bottleneck.bn2 = module.bn2
        bottleneck.bn3 = module.bn3
        bottleneck.relu = torch.nn.ReLU()
        if module.downsample is not None:
            # Build a new Sequential instead of assigning into
            # module.downsample, which would modify the original model.
            bottleneck.downsample = torch.nn.Sequential(
                Conv2d.from_torch(module.downsample[0]),
                module.downsample[1],
            )
        modules.append(bottleneck)

    else:
        # MaxPool2d, BatchNorm2d, AdaptiveAvgPool2d, ... are kept as-is;
        # maxpool is handled with the gradient for the moment.
        modules.append(module)
```

and edited sequential.py as follows:

```python
import torch

from . import Linear, Conv2d
from .maxpool import MaxPool2d
from .functional.utils import normalize

def grad_decorator_fn(module):
    """
        Currently not used but can be used for debugging purposes.
    """
    def fn(x):
        return normalize(x)
    return fn

avoid_normalization_on = ['relu', 'maxp']
def do_normalization(rule, module):
    if "pattern" not in rule.lower(): return False
    return not str(module)[:4].lower() in avoid_normalization_on

def is_kernel_layer(module):
    # Bottleneck is defined further down; the name is resolved at call time.
    return isinstance(module, (Conv2d, Linear, Bottleneck))

def is_rule_specific_layer(module):
    return isinstance(module, MaxPool2d)

class Sequential(torch.nn.Sequential):
    def forward(self, input, explain=False, rule="epsilon", pattern=None):
        if not explain: return super(Sequential, self).forward(input)

        # copy references for user to be able to reuse patterns
        if pattern is not None: pattern = list(pattern)

        for module in self:
            if do_normalization(rule, module):
                input.register_hook(grad_decorator_fn(module))

            if is_kernel_layer(module):
                # A Bottleneck consumes one entry: the list of its sub-patterns
                P = None
                if pattern is not None:
                    P = pattern.pop(0)
                input = module.forward(input, explain=True, rule=rule, pattern=P)

            elif is_rule_specific_layer(module):
                input = module.forward(input, explain=True, rule=rule)

            else: # Use gradient as default for remaining layer types
                input = module(input)

        if do_normalization(rule, module):
            input.register_hook(grad_decorator_fn(module))

        return input

class Bottleneck(torch.nn.Module):
    def __init__(self):
        super(Bottleneck, self).__init__()
        self.downsample = None

    # explain defaults to False so that a plain forward pass (Sequential with
    # explain=False calls module(input)) does not go through the LRP path.
    def forward(self, x, explain=False, rule="epsilon", pattern=None):
        # pattern, if given, holds the [conv1, conv2, conv3(, downsample conv)] patterns
        p = pattern if pattern is not None else [None] * 4
        identity = x

        out = self.conv1(x, explain=explain, rule=rule, pattern=p[0])
        out = self.relu(self.bn1(out))

        out = self.conv2(out, explain=explain, rule=rule, pattern=p[1])
        out = self.relu(self.bn2(out))

        out = self.conv3(out, explain=explain, rule=rule, pattern=p[2])
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample[0](x, explain=explain, rule=rule, pattern=p[3])
            identity = self.downsample[1](identity)

        # The residual addition only uses the plain gradient (no LRP rule yet)
        out = out + identity
        return self.relu(out)
```
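The kernel layers above are explained with rule="epsilon" by default. For reference, here is a minimal NumPy sketch of the textbook LRP-ε rule for a single dense layer z = aW + b, where R_i = a_i Σ_j w_ij R_j / (z_j + ε·sign(z_j)). It is a generic illustration under that assumption, not torchlrp's actual implementation, and the name lrp_epsilon_linear is hypothetical.

```python
import numpy as np

def lrp_epsilon_linear(a, W, b, R_out, eps=1e-6):
    """Redistribute relevance R_out of a dense layer z = a @ W + b onto its input a.

    a: (in,) input activations, W: (in, out) weights, b: (out,) bias,
    R_out: (out,) relevance of the outputs. Returns R_in with shape (in,).
    """
    z = a @ W + b                                   # pre-activations, (out,)
    # Push the denominator away from zero while keeping its sign.
    z_stab = z + eps * np.where(z >= 0, 1.0, -1.0)
    s = R_out / z_stab                              # (out,)
    return a * (W @ s)                              # R_i = a_i * sum_j w_ij * s_j

# Usage: start from the output itself, as LRP does at the top layer.
a = np.array([1.0, 2.0])
W = np.array([[1.0, -1.0], [0.5, 2.0]])
b = np.zeros(2)
R_in = lrp_epsilon_linear(a, W, b, a @ W + b)
```

With zero bias and a small ε, relevance is (almost) conserved: the input relevances sum to the total output relevance. A nonzero bias absorbs part of it.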

For PatternNet, the _fit_pattern function in patterns.py also needs to be modified:

```python
# Bottleneck must be importable in patterns.py, e.g. via `from .sequential import Bottleneck`.
def _fit_pattern(model, train_loader, max_iter, device, mask_fn = lambda y: torch.ones_like(y)):
    stats_x     = []
    stats_y     = []
    stats_xy    = []
    weights     = []

    def update_stats(conv, inp, out, i, k, first):
        # Accumulate pattern statistics for kernel k (0-3) of Bottleneck i
        mask = mask_fn(out).float().to(device)
        if conv.bias is not None:
            out_wo_bias = out - conv.bias.view(-1, 1, 1)
        else:
            out_wo_bias = out.clone()
        cnt_, cnt_all_, x_, y_, xy_, w, w_fn = _prod(conv, inp, out_wo_bias, mask)
        if first:
            stats_x[i].append(RunningMean(x_.shape, device))
            stats_y[i].append(RunningMean(y_.shape, device)) # Use all y
            stats_xy[i].append(RunningMean(xy_.shape, device))
            weights[i].append((w, w_fn))
        stats_x[i][k].update(x_, cnt_)
        stats_y[i][k].update(y_.sum(0), cnt_all_)
        stats_xy[i][k].update(xy_, cnt_)

    first = True
    for b, (x, _) in enumerate(tqdm(train_loader)):
        x = x.to(device)

        i = 0
        for m in model:
            # For Bottleneck: one nested list of statistics per block
            if isinstance(m, Bottleneck):
                if first:
                    for stats in (stats_x, stats_y, stats_xy, weights):
                        stats.append([])

                y = m.conv1(x)
                update_stats(m.conv1, x, y, i, 0, first)
                x1 = m.relu(m.bn1(y))

                y = m.conv2(x1)
                update_stats(m.conv2, x1, y, i, 1, first)
                x2 = m.relu(m.bn2(y))

                y = m.conv3(x2)
                update_stats(m.conv3, x2, y, i, 2, first)
                y = m.bn3(y)

                # Without downsampling, the shortcut is the block input itself
                # (otherwise a stale identity from an earlier block would be added).
                identity = x
                if m.downsample is not None:
                    identity = m.downsample[0](x)
                    # Statistics must use the downsample conv's own output, not y
                    update_stats(m.downsample[0], x, identity, i, 3, first)
                    identity = m.downsample[1](identity)

                x = m.relu(y + identity)
                i += 1
                continue

            y = m(x) # Note, this includes bias.

            if not isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
                x = y.clone()
                continue

            mask = mask_fn(y).float().to(device)
            if m.bias is not None:
                if isinstance(m, torch.nn.Conv2d):
                    y_wo_bias = y - m.bias.view(-1, 1, 1)
                else:
                    y_wo_bias = y - m.bias.clone()
            else:
                y_wo_bias = y

            cnt_, cnt_all_, x_, y_, xy_, w, w_fn = _prod(m, x, y_wo_bias, mask)

            if first:
                stats_x.append(RunningMean(x_.shape, device))
                stats_y.append(RunningMean(y_.shape, device)) # Use all y
                stats_xy.append(RunningMean(xy_.shape, device))
                weights.append((w, w_fn))

            stats_x[i].update(x_, cnt_)
            stats_y[i].update(y_.sum(0), cnt_all_)
            stats_xy[i].update(xy_, cnt_)

            x = y.clone()
            i += 1

        first = False

        if max_iter is not None and b+1 == max_iter: break

    def pattern(x_mean, y_mean, xy_mean, W2d):
        x_  = x_mean.value
        y_  = y_mean.value
        xy_ = xy_mean.value

        W, w_fn = W2d
        ExEy = x_ * y_
        cov_xy = xy_ - ExEy # [in, out]

        w_cov_xy = torch.diag(W @ cov_xy) # [out,]

        A = safe_divide(cov_xy, w_cov_xy[None, :])
        A = w_fn(A) # Reshape to original kernel size

        return A

    # Plain kernel layers get one pattern; each Bottleneck gets a list of sub-patterns.
    patterns = []
    for stats in zip(stats_x, stats_y, stats_xy, weights):
        if isinstance(stats[0], RunningMean):
            patterns.append(pattern(*stats))
        else:
            patterns.append([pattern(*sub) for sub in zip(*stats)])
    return patterns
```
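The nested pattern() function computes, for each output j, the pattern a_j = cov(x, y_j) / (w_jᵀ cov(x, y_j)). A minimal NumPy sketch for a single bias-free linear layer follows, assuming all samples are used (no mask_fn) and plain division instead of safe_divide; linear_pattern is a hypothetical name, not part of torchlrp.

```python
import numpy as np

def linear_pattern(X, W):
    """PatternNet-style linear pattern for y = X @ W.T, estimated from samples.

    X: (n, in) inputs, W: (out, in) weights.
    Returns A with shape (in, out); column j is cov(x, y_j) / (w_j^T cov(x, y_j)).
    """
    Y = X @ W.T                                              # (n, out), bias-free outputs
    n = X.shape[0]
    cov_xy = X.T @ Y / n - np.outer(X.mean(0), Y.mean(0))   # E[xy] - E[x]E[y], (in, out)
    w_cov_xy = np.diag(W @ cov_xy)                          # (out,), equals Var(y_j)
    return cov_xy / w_cov_xy[None, :]

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
W = rng.normal(size=(3, 4))
A = linear_pattern(X, W)
```

Because w_jᵀ cov(x, y_j) = Var(y_j), the denominator is positive for non-degenerate data, and every pattern is normalized so that w_jᵀ a_j = 1.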

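LRP for the residual addition in Bottleneck.forward is not implemented yet: like the other non-kernel layers, the addition is currently handled by the ordinary gradient. This probably still needs to be implemented.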
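One straightforward option, assumed here rather than taken from torchlrp, is to treat the sum z = a + b as a layer with unit weights and apply the ε-rule, which splits the relevance of z between the two branches in proportion to their contributions. In the Bottleneck, a would be the shortcut (identity) and b the bn3 output. A minimal NumPy sketch (lrp_add is a hypothetical name):

```python
import numpy as np

def lrp_add(a, b, R_out, eps=1e-6):
    """Split the relevance of z = a + b between both summands (epsilon rule).

    Each branch receives relevance proportional to its share of z, so
    R_a + R_b ~= R_out (exactly when eps == 0 and z != 0).
    """
    z = a + b
    z_stab = z + eps * np.where(z >= 0, 1.0, -1.0)   # avoid division by zero
    s = R_out / z_stab
    return a * s, b * s

a = np.array([1.0, -2.0, 3.0])   # e.g. shortcut activations
b = np.array([1.0, 4.0, 0.0])    # e.g. residual-branch activations
R_a, R_b = lrp_add(a, b, np.array([0.5, 1.0, 2.0]))
```

In PyTorch, this split could be wired in as the backward pass of a custom autograd function for the addition; that integration is not shown here.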

miladsikaroudi commented on May 20, 2024

Thank you for posting this, @sdw95927.
I used this code to generate ResNet heatmaps.
The problem is that the ResNet heatmaps are not very meaningful.
Below are the heatmaps I generated for pretrained VGG and ResNet models. Any ideas?

VGG: [heatmap image not preserved]

ResNet: [heatmap image not preserved]

sdw95927 commented on May 20, 2024

I can see the patterns in the ResNet heatmaps too, just not as clearly as for VGG. I think this is mainly due to ResNet's more complex structure, such as its residual connections, whereas VGG is simple and straightforward.

zah-tane commented on May 20, 2024

@miladsikaroudi Can you please share what you did to make this work with ResNet?
