본문 바로가기
Python

[DL] PyTorch-Hook

by ds31x 2025. 4. 10.
728x90
반응형

PyTorch의 hook

  • Neural Network 내부의 계산 과정을 관찰하거나,
  • 특정 시점에서 개입할 수 있도록 해주는 기능 (사실은 function 또는 instance method임).

을 제공함.

 

이를 통해 forward 중간 출력 및 backward 에서의 gradient, 또는 입력 값 자체를 가로채어 확인하거나 수정할 수 있음.

Hook이 제대로 동작하기 위해선
forward를 직접 호출해선 안됨.

 

관련 내용은 다음 URL에서 __call__ 과 forward 관련부분 읽어볼것:
2024.04.12 - [Python] - [PyTorch] Custom Model 과 torch.nn.Module의 메서드들.

 

[PyTorch] Custom Model 과 torch.nn.Module의 메서드들.

Custom Model 만들기0. nn.Module torch.nn.Module은 PyTorch에서 모든 신경망 모델과 계층의 기반이 되는 클래스임.Custom Model (사용자 정의 모델)부터 Built-in Layer(nn.Linear, nn.Conv2d, etc.)까지 전부 nn.Module을 상속

ds31x.tistory.com


1. Hook의 종류와 호출 위치 요약

Hook 종류 호출 시점 호출 위치 관련 메서드
Forward Hook forward 실행 후 __call__()
내부
register_forward_hook (hook)
Forward Pre-Hook forward 실행 전 __call__()
내부
register_forward_pre_hook (hook)
(Full) Backward Hook (Module) backward 중 autograd
엔진 내부
register_full_backward_hook (hook)
Backward Hook (Tensor) gradient 계산 시점 autograd
엔진 내부
tensor.register_hook (hook)

 

같은 종류의 여러 hook을 등록가능하며,
이 경우 등록된 순서대로 호출됨.


2. 각 hook의 형태 및 역할

2-1. Forward Hook

def forward_hook(module, input, output):
    print(f"[Forward Hook] {module.__class__.__name__} output:\n{output}")
  • module: hook이 붙은 nn.Module 객체
  • input: module의 입력값 (tuple)
  • output: module의 출력값 (Tensor 또는 tuple)
  • 호출 시점: __call__()내 에서 forward() 호출 이후
  • 반환값이 없으며, 이는 tensor 객체의 수정용이 아님을 의미함.

2-2. Forward Pre-Hook

def pre_hook(module, input):
    print(f"[Pre-Hook] {module.__class__.__name__} input:\n{input}")
    return input  # 수정도 가능
  • 호출 시점: __call__()내 에서 forward() 호출 이전.
  • input 을 수정하여 반환 가능함.

2-3. Full Backward Hook (Module)

def full_backward_hook(module, grad_input, grad_output):
    print(f"[Backward Hook] {module.__class__.__name__}")
    print(f"  grad_input: {grad_input}")
    print(f"  grad_output: {grad_output}")
    # grad_input 처리
    return grad_input #수정 가능.
  • 호출 시점: loss.backward() 중 autograd 엔진 내부
    • Autograd 엔진이 module.backward() 호출 후, gradient 계산이 끝난 뒤 hook이 실행
  • grad_input: 이 모듈에 대한 입력값들에 대한 gradient (backward propagation에서 이전 layer로 넘길 값)
  • grad_output: 이 모듈의 출력값들에 대한 gradient (backward propagation에서 다음 layer로부터 전달받은 값)
  • grad_input 만 수정가능함 (backward 과정에선 출력이므로..)

2-4. Backward Hook (Tensor)

def tensor_grad_hook(grad):
    print(f"[Tensor Backward Hook] grad: {grad}")
    return grad  # 수정도 가능
  • 호출 시점: 텐서에 대한 gradient 계산 시
  • retain_grad = True 시 해당 텐서가 가질 값을 parameter grad로 넘겨 받음.

이는 뒤쪽 노드로부터 backpropagaton 과정에서 해당 tensor로 "전달된 gradient"에 local gradient가 곱해진 값임에 유의할 것.

이는 backward를 호출한 텐서에 해당하는 값을 hook가 등록된 텐서에 대해 구한 gradient임.

 

 


3. 실제 동작 예제 (nn.Linear 기반)

https://gist.github.com/dsaint31x/79bca8c5a40d329a84be858c0af53d49

 

dl_pytorch_hook.ipynb

dl_pytorch_hook.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com

code는 대략 다음과 같음.

import torch
import torch.nn as nn

# 1. 모델 정의
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()

# 2. Forward Pre-Hook 등록
def pre_hook(module, input):
    print(f"[Pre-Hook] 입력: {input}")
    return input

model.linear.register_forward_pre_hook(pre_hook)

# 3. Forward Hook 등록
def forward_hook(module, input, output):
    print(f"[Forward Hook] 출력: {output}")

model.linear.register_forward_hook(forward_hook)

# 4. Backward Hook 등록 (Module)
def backward_hook(module, grad_input, grad_output):
    print(f"[Backward Hook] grad_input: {grad_input}, grad_output: {grad_output}")

handle = model.linear.register_full_backward_hook(backward_hook)

# 5. Tensor Gradient Hook 등록
def tensor_hook(grad):
    print(f"[Tensor Hook] grad: {grad}")
    return grad

# 6. 입력 생성 및 forward/backward 실행
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], requires_grad=True)
y = model(x)
y.register_hook(tensor_hook)  # tensor에 직접 hook

loss = y.sum()
loss.backward()

handle.remove() #등록된 hook을 제거.

 

위 코드의 실행 흐름은 다음과 같음.

  1. model(x) 호출
    • __call__() 실행
      • forward_pre_hook 호출됨 (입력 관찰/수정 가능)
      • forward() 실행
      • forward_hook 호출됨 (출력 관찰 가능)
  2. loss.backward() 호출
    • autograd 엔진이 작동
      • tensor hook
      • module backward hook 순으로 호출

3-1. hook을 등록 해제하기.

register_xxx 메서드를 호출하여 등록할 때, 반환값으로 등록을 해제할 수 있는 handle 객체를 반환함.

해당 객체에서 remove() 메서드를 호출하면 등록 해제됨.


3-2. instance method 로 구현한 경우

instance method 에 대한 참고 자료는 다음과 같음 (일반적으로 method라고 하면 instance method임).

2023.08.20 - [Python] - [Python] instance methods, class methods, and static methods

 

[Python] instance methods, class methods, and static methods

Instance Methodsinstance를 통해 접근(=호출)되는 methods를 가르킴.일반적인 methods가 바로 instance methods임. method와 function의 차이점 중 하나로 애기되는"정의될 때 첫번째 parameter가 self이면 method "라는 것

ds31x.tistory.com

 

다음의 코드 참고:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.linear.register_forward_hook(self.my_forward_hook)

    def my_forward_hook(self, module, input, output):
        print(f"[HOOK] {module.__class__.__name__} output: {output}")

    def forward(self, x):
        return self.linear(x)

# 실행
model = MyModel()
x = torch.randn(1, 4)
_ = model(x)
  • self.my_forward_hook는 인스턴스 메서드이므로 self는 자동으로 바인딩됨.
  • hook 호출 시 인자는 (module, input, output) : 메서드에서는 self까지 총 4개의 인자 처리.
  • class method는 잘 사용하지 않으며, 주로 instance method가 이용됨.

전체 흐름 그림

model(x)
│
├── __call__()
│   ├── forward_pre_hook(module, input)
│   ├── forward()
│   └── forward_hook(module, input, output)
│
loss.backward()
│
├── tensor_hook(grad)
└── full_backward_hook(module, grad_input, grad_output)

 


요약

다음 표에서 반환값은 hook들의 반환값임:

  • register를 위해 사용되는 method들은 등록 해제를 위한 객체를 반환하는 점에 유의할 것.
  • 반환값이 없으면 사실상 forward, backward propagation 과정 중에 텐서의 수정을 못함을 의미함. 
Hook 종류 Hook 함수 arguments 반환값 호출 위치 목적
forward_hook (module, input, output) 무시됨
(없음)
__call__()
내부
(forward 후)
출력 관찰
forward_pre_hook (module, input) 수정된 입력
(optional)
__call__()
내부
(forward 전)
입력 수정
full_backward_hook (module, grad_in, grad_out) 수정된 grad_in
(optional)
autograd
엔진 내부
gradient
추적/수정
tensor hook (grad) 수정된 grad
(optional)
autograd
엔진 내부
Tensor 수준