# circuit-sparsity / hook_utils.py
"""Self-contained subset of :mod:`circuit_sparsity.hook_utils` for inference builds.
The full module has no exotic dependencies, but mirroring the definitions here
keeps the trimmed :mod:`circuit_sparsity.inference.gpt` module hermetic and easy to vendor. The
implementations below are copied with minor tweaks for readability so that code
written against :func:`hook_recorder`, :func:`hook_namespace`, and
:func:`torch_recompute_preserving_hook_context` behaves identically in both the
training and inference configurations.
"""

from __future__ import annotations

import re
from contextlib import contextmanager
from functools import partial

import torch
import torch.utils.checkpoint


class HookContext:
    """State container used by the hook helpers."""

    def __init__(self) -> None:
        self._reset()
        self.curintervtransformer = lambda x: x

    def _reset(self) -> None:
        self.curcontext = None
        self.curname = ""
        self.curregex = None
        self.curinterventions = None
        self.save_grads = None

    def _get_interventions(self):
        return self.curintervtransformer(
            self.curinterventions if self.curinterventions is not None else {}
        )

    @contextmanager
    def hook_recorder(self, regex: str = ".*", interventions=None, save_grads: bool = False):
        """Record tensors that pass through hooks matching ``regex``."""
        assert self.curcontext is None, "reentrancy not allowed!"
        try:
            self.curcontext = {}
            self.curregex = re.compile(regex)
            self.curname = ""
            self.curinterventions = interventions
            self.save_grads = save_grads
            yield self.curcontext
        finally:
            # Reset both this context and the globally active one; the two can
            # differ while :func:`torch_recompute_preserving_hook_context` has
            # temporarily swapped contexts.
            self._reset()
            get_context()._reset()
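
    # A minimal usage sketch (``model``, ``tokens``, and the key below are
    # hypothetical; ``hook_recorder`` here is the module-level wrapper):
    #
    #     with hook_recorder(regex=r"blocks\.0\..*") as recorded:
    #         model(tokens)
    #     acts = recorded["blocks.0.attn_out"]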

    @contextmanager
    def hook_intervention_transform(self, intervention_transformer):
        """Temporarily wrap the active interventions dict with a transformer."""
        oldintervention_transformer = self.curintervtransformer

        def compose(f, g):
            return lambda x: f(g(x))

        self.curintervtransformer = compose(
            intervention_transformer,
            self.curintervtransformer,
        )
        try:
            yield
        finally:
            self.curintervtransformer = oldintervention_transformer
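
    # Sketch of stacking transformers (``f`` and ``g`` are hypothetical):
    # entering ``g`` first and then ``f`` makes ``hook_save`` see
    # ``f(g(interventions))``, since each new transformer wraps the old one.
    #
    #     with hook_intervention_transform(g), hook_intervention_transform(f):
    #         ...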

    @contextmanager
    def hook_namespace(self, name: str):
        """Temporarily push ``name`` onto the hook namespace stack."""
        oldname = self.curname
        self.curname = self.curname + name + "."
        try:
            yield
        finally:
            self.curname = oldname
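
    # Sketch: nested namespaces build dotted keys (names are hypothetical).
    #
    #     with hook_namespace("blocks"), hook_namespace("0"):
    #         hook_save("mlp_out", h)   # recorded under "blocks.0.mlp_out"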

    def hook_save(self, name: str, tensor: torch.Tensor) -> torch.Tensor:
        """Optionally record ``tensor`` using the current namespace."""
        key = self.curname + name
        curinterventions = self._get_interventions()
        if curinterventions is not None and key in curinterventions:
            tensor = curinterventions[key](tensor)
        if self.curcontext is not None and self.curregex.match(key):
            self.curcontext[key] = tensor
        if self.curcontext is not None and self.save_grads and tensor.requires_grad:
            # Capture the key and the recording dict now: by the time
            # ``backward`` runs, ``self.curname`` and ``self.curcontext`` may
            # have been reset or may point at a different namespace.
            grad_key = key + ".grad"
            grad_context = self.curcontext

            class _Grad(torch.autograd.Function):
                @staticmethod
                def forward(ctx, input_tensor):
                    return input_tensor

                @staticmethod
                def backward(ctx, grad_output):
                    grad_context[grad_key] = grad_output
                    return grad_output

            if self.curregex.match(grad_key):
                tensor = _Grad.apply(tensor)
        return tensor
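

# A minimal runnable sketch of gradient recording (``_example_record_grads``
# is illustrative and not part of the original module; names and keys are
# hypothetical). With ``save_grads=True``, a key matching ``<name>.grad`` is
# filled in when ``backward()`` runs.
def _example_record_grads():
    with hook_recorder(regex=r"h0\..*", save_grads=True) as recorded:
        x = torch.randn(4, requires_grad=True)
        with hook_namespace("h0"):
            y = hook_save("pre", x * 2.0)
        y.sum().backward()
    # ``recorded`` now holds "h0.pre" and, after backward, "h0.pre.grad".
    return recorded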


def set_context(new_context: HookContext) -> None:
    global context
    context = new_context


def get_context() -> HookContext:
    return context


def torch_recompute_preserving_hook_context(f, *xs, use_reentrant=None):
    """Wrapper around :func:`torch.utils.checkpoint` that propagates hooks."""
    oldcontext = get_context()

    # Give the checkpointed region its own context seeded from the active one,
    # so recording and interventions keep working inside ``f``.
    curcontext = HookContext()
    curcontext.curcontext = (
        dict(oldcontext.curcontext) if oldcontext.curcontext is not None else None
    )
    curcontext.curregex = oldcontext.curregex
    curcontext.curname = oldcontext.curname
    curcontext.curinterventions = (
        dict(oldcontext.curinterventions) if oldcontext.curinterventions is not None else None
    )
    curcontext.save_grads = oldcontext.save_grads
    is_recompute = False

    def _f(curcontext: HookContext, *xs):
        initcontext = get_context()
        nonlocal is_recompute
        set_context(curcontext)
        try:
            res = f(*xs)
            # Checkpointing runs ``f`` a second time during backward; only the
            # first (true forward) pass merges its recordings into the outer
            # context.
            if not is_recompute and oldcontext.curcontext is not None:
                oldcontext.curcontext |= curcontext.curcontext
        finally:
            set_context(initcontext)
            is_recompute = True
        return res

    res = torch.utils.checkpoint.checkpoint(
        partial(_f, curcontext), *xs, use_reentrant=use_reentrant
    )
    return res
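

# Illustrative sketch (``_example_recompute`` is not part of the original
# module): a checkpointed function still sees the active recorder, and tensors
# it saves on the first (non-recompute) pass are merged into the outer context.
def _example_recompute():
    def block(x):
        return hook_save("mid", torch.relu(x)).sum()

    with hook_recorder(regex=r"mid") as recorded:
        x = torch.randn(8, requires_grad=True)
        out = torch_recompute_preserving_hook_context(block, x, use_reentrant=False)
        out.backward()
    return recorded  # holds "mid" from the initial forward pass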


# Module-level default context and thin wrappers that delegate to it.
context = HookContext()


def hook_recorder(*a, **k):
    return get_context().hook_recorder(*a, **k)


def hook_namespace(*a, **k):
    return get_context().hook_namespace(*a, **k)


def hook_save(*a, **k):
    return get_context().hook_save(*a, **k)


def hook_intervention_transform(*a, **k):
    return get_context().hook_intervention_transform(*a, **k)
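

# Illustrative smoke test, not part of the original module; key names are
# hypothetical. Running ``python hook_utils.py`` exercises the helpers
# together: the intervention doubles the saved tensor before it is recorded.
if __name__ == "__main__":
    double = {"block.act": lambda t: t * 2}
    with hook_recorder(regex=r"block\..*", interventions=double) as recorded:
        with hook_namespace("block"):
            out = hook_save("act", torch.ones(3))
    assert torch.equal(out, torch.full((3,), 2.0))
    assert recorded["block.act"] is out
    print("recorded keys:", sorted(recorded))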