Git Product home page Git Product logo

deephook's Introduction

The purpose of this script is to provide a convenient and efficient way to hook into PyTorch model layers during the forward pass(backward is coming), allowing you to capture or modify the inputs and outputs of specified layers.

Inspired by the nethook.py here, but DeepHook is much simpler and easier to understand.

how does this work?

Sure, let's provide a more detailed description of the script:

  1. Trace Class: This class acts as a hook into a specific layer of a PyTorch model during a forward pass. Upon initialization, it takes in a model, the layer name, and optional settings (whether to retain input and output, and an optional function to modify the output).

    The class registers a forward hook on the specified layer. When the layer is called during a forward pass, the hook captures the input and output of the layer. If specified, it also modifies the output using the provided edit_output function.

    If retain_input is True, the hook stores the input to the layer in self.input. If retain_output is True, it stores the output (potentially modified by edit_output) in self.output.

    The Trace class is a context manager, meaning it can be used in a with statement. When the with block is entered, the __enter__ method is called, which simply returns self. When the with block is exited, the __exit__ method is called, which removes the registered hook, ensuring no leftover hooks remain attached to the model.

  2. TraceMultiple Class: This class is a context manager for hooking into multiple layers of a PyTorch model simultaneously. It accepts a model and a dictionary mapping layer names to a tuple of settings (whether to retain output and input, an optional function to modify the output).

    The TraceMultiple class creates a Trace object for each layer and manages them using an ExitStack from the contextlib module. This ensures that all hooks are properly removed when the with block is exited, even if an error occurs during the forward pass.

    Like Trace, TraceMultiple is a context manager, so it can be used in a with statement. The __enter__ method enters the ExitStack context and also enters the context of each Trace object (i.e., registers all the hooks). The __exit__ method ensures all Trace contexts are exited (i.e., all hooks are removed), and then exits the ExitStack context.

Here's a usage example:

def edit_fn(output):
    return output + 1  # a simple function that adds 1 to the output

layer_settings = {
    'transformer.wpe': (True, False, None),  # retain output, don't retain input, no edit function
    'transformer.h.0': (True, True, edit_fn),  # retain output and input, use edit_fn to edit output
    # Add more layers as needed...
}

with TraceMultiple(model, layer_settings) as tm:
    _ = model(**encoded_input)

# Access the input and output of each hooked layer
wpe_output = tm['transformer.wpe'].output
h0_input = tm['transformer.h.0'].input
h0_output = tm['transformer.h.0'].output

deephook's People

Contributors

kir486680 avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.