PyTorch의 hook은
- Neural Network 내부의 계산 과정을 관찰하거나,
- 특정 시점에서 개입할 수 있도록 해주는 기능 (사실은 function 또는 instance method임).
을 제공함.
이를 통해 forward 중간 출력 및 backward 에서의 gradient, 또는 입력 값 자체를 가로채어 확인하거나 수정할 수 있음.
Hook이 제대로 동작하기 위해선
forward를 직접 호출해선 안됨.
관련 내용은 다음 URL에서 __call__ 과 forward 관련부분 읽어볼것:
2024.04.12 - [Python] - [PyTorch] Custom Model 과 torch.nn.Module의 메서드들.
[PyTorch] Custom Model 과 torch.nn.Module의 메서드들.
Custom Model 만들기0. nn.Module torch.nn.Module은 PyTorch에서 모든 신경망 모델과 계층의 기반이 되는 클래스임.Custom Model (사용자 정의 모델)부터 Built-in Layer(nn.Linear, nn.Conv2d, etc.)까지 전부 nn.Module을 상속
ds31x.tistory.com
1. Hook의 종류와 호출 위치 요약
Hook 종류 | 호출 시점 | 호출 위치 | 관련 메서드 |
Forward Hook | forward 실행 후 | __call__() 내부 |
register_forward_hook (hook) |
Forward Pre-Hook | forward 실행 전 | __call__() 내부 |
register_forward_pre_hook (hook) |
(Full) Backward Hook (Module) | backward 중 | autograd 엔진 내부 |
register_full_backward_hook (hook) |
Backward Hook (Tensor) | gradient 계산 시점 | autograd 엔진 내부 |
tensor.register_hook (hook) |
같은 종류의 여러 hook을 등록가능하며,
이 경우 등록된 순서대로 호출됨.

2. 각 hook의 형태 및 역할
2-1. Forward Hook
def forward_hook(module, input, output):
print(f"[Forward Hook] {module.__class__.__name__} output:\n{output}")
module
: hook이 붙은 nn.Module 객체input
:module
의 입력값 (tuple
)output
:module
의 출력값 (Tensor
또는tuple
)- 호출 시점:
__call__()
내 에서forward()
호출 이후 - 반환값이 없으며, 이는 tensor 객체의 수정용이 아님을 의미함.
2-2. Forward Pre-Hook
def pre_hook(module, input):
print(f"[Pre-Hook] {module.__class__.__name__} input:\n{input}")
return input # 수정도 가능
- 호출 시점:
__call__()
내 에서forward()
호출 이전. input
을 수정하여 반환 가능함.
2-3. Full Backward Hook (Module)
def full_backward_hook(module, grad_input, grad_output):
print(f"[Backward Hook] {module.__class__.__name__}")
print(f" grad_input: {grad_input}")
print(f" grad_output: {grad_output}")
# grad_input 처리
return grad_input #수정 가능.
- 호출 시점:
loss.backward()
중 autograd 엔진 내부- Autograd 엔진이
module.backward()
호출 후, gradient 계산이 끝난 뒤 hook이 실행
- Autograd 엔진이
grad_input
: 이 모듈에 대한 입력값들에 대한 gradient (backward propagation에서 이전 layer로 넘길 값)grad_output
: 이 모듈의 출력값들에 대한 gradient (backward propagation에서 다음 layer로부터 전달받은 값)grad_input
만 수정가능함 (backward 과정에선 출력이므로..)
2-4. Backward Hook (Tensor)
def tensor_grad_hook(grad):
print(f"[Tensor Backward Hook] grad: {grad}")
return grad # 수정도 가능
- 호출 시점: 텐서에 대한 gradient 계산 시
retain_grad
= True 시 해당 텐서가 가질 값을 parametergrad
로 넘겨 받음.
이는 뒤쪽 노드로부터 backpropagaton 과정에서 해당 tensor로 "전달된 gradient"에 local gradient가 곱해진 값임에 유의할 것.
이는 backward를 호출한 텐서에 해당하는 값을 hook가 등록된 텐서에 대해 구한 gradient임.
3. 실제 동작 예제 (nn.Linear
기반)
https://gist.github.com/dsaint31x/79bca8c5a40d329a84be858c0af53d49
dl_pytorch_hook.ipynb
dl_pytorch_hook.ipynb. GitHub Gist: instantly share code, notes, and snippets.
gist.github.com
code는 대략 다음과 같음.
import torch
import torch.nn as nn
# 1. 모델 정의
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(4, 2)
def forward(self, x):
return self.linear(x)
model = SimpleModel()
# 2. Forward Pre-Hook 등록
def pre_hook(module, input):
print(f"[Pre-Hook] 입력: {input}")
return input
model.linear.register_forward_pre_hook(pre_hook)
# 3. Forward Hook 등록
def forward_hook(module, input, output):
print(f"[Forward Hook] 출력: {output}")
model.linear.register_forward_hook(forward_hook)
# 4. Backward Hook 등록 (Module)
def backward_hook(module, grad_input, grad_output):
print(f"[Backward Hook] grad_input: {grad_input}, grad_output: {grad_output}")
handle = model.linear.register_full_backward_hook(backward_hook)
# 5. Tensor Gradient Hook 등록
def tensor_hook(grad):
print(f"[Tensor Hook] grad: {grad}")
return grad
# 6. 입력 생성 및 forward/backward 실행
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], requires_grad=True)
y = model(x)
y.register_hook(tensor_hook) # tensor에 직접 hook
loss = y.sum()
loss.backward()
handle.remove() #등록된 hook을 제거.
위 코드의 실행 흐름은 다음과 같음.
model(x)
호출__call__()
실행forward_pre_hook
호출됨 (입력 관찰/수정 가능)forward()
실행forward_hook
호출됨 (출력 관찰 가능)
loss.backward()
호출- autograd 엔진이 작동
- tensor hook
- module backward hook 순으로 호출
- autograd 엔진이 작동
3-1. hook을 등록 해제하기.
register_xxx 메서드를 호출하여 등록할 때, 반환값으로 등록을 해제할 수 있는 handle 객체를 반환함.
해당 객체에서 remove() 메서드를 호출하면 등록 해제됨.
3-2. instance method 로 구현한 경우
instance method 에 대한 참고 자료는 다음과 같음 (일반적으로 method라고 하면 instance method임).
2023.08.20 - [Python] - [Python] instance methods, class methods, and static methods
[Python] instance methods, class methods, and static methods
Instance Methodsinstance를 통해 접근(=호출)되는 methods를 가르킴.일반적인 methods가 바로 instance methods임. method와 function의 차이점 중 하나로 애기되는"정의될 때 첫번째 parameter가 self이면 method "라는 것
ds31x.tistory.com
다음의 코드 참고:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(4, 2)
self.linear.register_forward_hook(self.my_forward_hook)
def my_forward_hook(self, module, input, output):
print(f"[HOOK] {module.__class__.__name__} output: {output}")
def forward(self, x):
return self.linear(x)
# 실행
model = MyModel()
x = torch.randn(1, 4)
_ = model(x)
self.my_forward_hook
는 인스턴스 메서드이므로self
는 자동으로 바인딩됨.- hook 호출 시 인자는
(module, input, output)
: 메서드에서는self
까지 총 4개의 인자 처리. - class method는 잘 사용하지 않으며, 주로 instance method가 이용됨.
전체 흐름 그림
model(x)
│
├── __call__()
│ ├── forward_pre_hook(module, input)
│ ├── forward()
│ └── forward_hook(module, input, output)
│
loss.backward()
│
├── tensor_hook(grad)
└── full_backward_hook(module, grad_input, grad_output)
요약
다음 표에서 반환값은 hook들의 반환값임:
- register를 위해 사용되는 method들은 등록 해제를 위한 객체를 반환하는 점에 유의할 것.
- 반환값이 없으면 사실상 forward, backward propagation 과정 중에 텐서의 수정을 못함을 의미함.
Hook 종류 | Hook 함수 arguments | 반환값 | 호출 위치 | 목적 |
forward_hook |
(module, input, output) |
무시됨 (없음) |
__call__() 내부 (forward 후) |
출력 관찰 |
forward_pre_hook |
(module, input) |
수정된 입력 (optional) |
__call__() 내부 (forward 전) |
입력 수정 |
full_backward_hook |
(module, grad_in, grad_out) |
수정된 grad_in (optional) |
autograd 엔진 내부 |
gradient |
tensor hook |
(grad) |
수정된 grad (optional) |
autograd 엔진 내부 |
Tensor 수준 |
'Python' 카테고리의 다른 글
[Programming] Control Flow 와 Control Structure (1) | 2025.04.23 |
---|---|
[Py] import 의 종류. (0) | 2025.04.18 |
[DL] torch.nn.Linear 에 대하여 (1) | 2025.04.10 |
[PyTorch] torch.save 와 torch.load - tensor 위주 (0) | 2025.04.08 |
[Py] 연습문제-carriage return + time.sleep (0) | 2025.04.07 |