This script provides a convenient and efficient way to hook into PyTorch model layers during the forward pass (backward-pass support is planned), so that you can capture or modify the inputs and outputs of specified layers.
It is inspired by nethook.py, but DeepHook is much simpler and easier to understand.
In more detail, the script provides two classes:
- `Trace` class: This class acts as a hook into a specific layer of a PyTorch model during a forward pass. Upon initialization, it takes a model, the layer name, and optional settings (whether to retain the input and output, and an optional function to modify the output).

  The class registers a forward hook on the specified layer. When the layer is called during a forward pass, the hook captures the input and output of the layer. If specified, it also modifies the output using the provided `edit_output` function.

  If `retain_input` is `True`, the hook stores the input to the layer in `self.input`. If `retain_output` is `True`, it stores the output (potentially modified by `edit_output`) in `self.output`.

  The `Trace` class is a context manager, meaning it can be used in a `with` statement. When the `with` block is entered, the `__enter__` method is called, which simply returns `self`. When the `with` block is exited, the `__exit__` method is called, which removes the registered hook, ensuring no leftover hooks remain attached to the model.

- `TraceMultiple` class: This class is a context manager for hooking into multiple layers of a PyTorch model simultaneously. It accepts a model and a dictionary mapping each layer name to a tuple of settings (whether to retain the output, whether to retain the input, and an optional function to modify the output).

  The `TraceMultiple` class creates a `Trace` object for each layer and manages them using an `ExitStack` from the `contextlib` module. This ensures that all hooks are properly removed when the `with` block is exited, even if an error occurs during the forward pass.

  Like `Trace`, `TraceMultiple` is a context manager, so it can be used in a `with` statement. The `__enter__` method enters the `ExitStack` context and also enters the context of each `Trace` object (i.e., registers all the hooks). The `__exit__` method ensures all `Trace` contexts are exited (i.e., all hooks are removed), and then exits the `ExitStack` context.
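
To make the description above concrete, here is a minimal sketch of how two such classes could be written. It is not the actual DeepHook source: the constructor argument names, the `_handle` and `_stack` attributes, and the `__getitem__` lookup are assumptions chosen to match the description and the usage example. The sketch relies only on `named_modules()` and `register_forward_hook()`, both of which `torch.nn.Module` provides.

```python
from contextlib import ExitStack


class Trace:
    """Hook one named layer; optionally keep its input/output and edit the output."""

    def __init__(self, model, layer_name, retain_output=True,
                 retain_input=False, edit_output=None):
        self.input = None
        self.output = None
        # Look the layer up by its dotted name, e.g. "transformer.h.0".
        modules = dict(model.named_modules())
        if layer_name not in modules:
            raise KeyError(f"no layer named {layer_name!r}")

        def hook(module, inputs, output):
            if retain_input:
                # Forward hooks receive the positional inputs as a tuple.
                self.input = inputs[0] if len(inputs) == 1 else inputs
            if edit_output is not None:
                output = edit_output(output)
            if retain_output:
                self.output = output
            # In PyTorch, a non-None return value replaces the layer's output.
            return output

        # The returned handle is what lets us detach the hook later.
        self._handle = modules[layer_name].register_forward_hook(hook)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._handle.remove()


class TraceMultiple:
    """Hook several layers at once; settings are (retain_output, retain_input, edit_output)."""

    def __init__(self, model, layer_settings):
        self.model = model
        self.layer_settings = layer_settings
        self.traces = {}
        self._stack = ExitStack()

    def __enter__(self):
        try:
            for name, settings in self.layer_settings.items():
                retain_output, retain_input, edit_output = settings
                trace = Trace(self.model, name, retain_output=retain_output,
                              retain_input=retain_input, edit_output=edit_output)
                self.traces[name] = self._stack.enter_context(trace)
        except BaseException:
            # Remove any hooks that were registered before the failure.
            self._stack.close()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Exits every Trace (removing its hook), even if the block raised.
        self._stack.close()

    def __getitem__(self, name):
        return self.traces[name]
```

In this sketch the hook is registered in `Trace.__init__`, so `__enter__` has nothing left to do except return `self`. Wrapping each `Trace` in the `ExitStack` is what guarantees cleanup: if the forward pass raises, `__exit__` still closes the stack and every hook is removed.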
Here's a usage example:
# Assumes `model` and `encoded_input` already exist
# (for example, a loaded model and its tokenized input).
def edit_fn(output):
    # Adds 1 to the output; this only works if the layer returns a single tensor.
    return output + 1

layer_settings = {
    'transformer.wpe': (True, False, None),    # retain output, don't retain input, no edit function
    'transformer.h.0': (True, True, edit_fn),  # retain output and input, use edit_fn to edit output
    # Add more layers as needed...
}

with TraceMultiple(model, layer_settings) as tm:
    _ = model(**encoded_input)

    # Access the input and output of each hooked layer
    wpe_output = tm['transformer.wpe'].output
    h0_input = tm['transformer.h.0'].input
    h0_output = tm['transformer.h.0'].output
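
One caveat about the edit function: `output + 1` assumes the hooked layer returns a single tensor. Some layers return a tuple instead (for instance, hidden states together with extra values), and adding an integer to a tuple raises a `TypeError`. Whether `transformer.h.0` returns a tuple depends on the model implementation, which is not stated here. The sketch below uses a hypothetical helper, `edit_first`, which is not part of DeepHook, to wrap an edit function so that it applies to the first element of a tuple and leaves the rest untouched.

```python
def edit_first(fn):
    """Wrap fn so it edits only the first element when the output is a tuple."""
    def wrapper(output):
        if isinstance(output, tuple):
            # Edit the first element and keep the remaining elements as they are.
            return (fn(output[0]),) + output[1:]
        return fn(output)
    return wrapper


def add_one(x):
    # The same simple edit as in the usage example.
    return x + 1


# Works for both single outputs and tuple outputs.
edit_fn = edit_first(add_one)
```

The wrapped `edit_fn` can then be passed in `layer_settings` exactly like the original one.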