
1. PyTorch Autograd 메커니즘 이해
1-1. Tensor의 grad_fn과 연산 그래프 추적 방법
- PyTorch tensor의 grad_fn은 마지막 operation만 표시하는 attribute.
- 전체 computation graph 확인을 위한 next_functions attribute의 활용.
- 재귀적 접근을 통한 전체 operation history의 추적.
def print_grad_graph(grad_fn, level=0):
print(' ' * level, grad_fn)
if hasattr(grad_fn, 'next_functions'):
for next_func in grad_fn.next_functions:
if next_func[0] is not None:
print_grad_graph(next_func[0], level + 1)
# 사용 예시
import torch
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)
z = x * y
w = z.sum()
print_grad_graph(w.grad_fn)
# 출력 결과:
# <SumBackward0 object at 0x7f8a1d5c7b80>
# <MulBackward0 object at 0x7f8a1d5c7c10>
# <AccumulateGrad object at 0x7f8a1d5c7ca0>
# <AccumulateGrad object at 0x7f8a1d5c7d30>
Operation type까지 포함한 상세 정보 확인 방법은 다음과 같음:
def print_grad_graph_detailed(grad_fn, level=0):
print(' ' * level, type(grad_fn).__name__, grad_fn)
if hasattr(grad_fn, 'next_functions'):
for next_func in grad_fn.next_functions:
if next_func[0] is not None:
print_grad_graph_detailed(next_func[0], level + 1)
# 사용 예시
w = z.sum() # 위의 예제에서 계속
print_grad_graph_detailed(w.grad_fn)
# 출력 결과:
# SumBackward0 <SumBackward0 object at 0x7f8a1d5c7b80>
# MulBackward0 <MulBackward0 object at 0x7f8a1d5c7c10>
# AccumulateGrad <AccumulateGrad object at 0x7f8a1d5c7ca0>
# AccumulateGrad <AccumulateGrad object at 0x7f8a1d5c7d30>
2. Custom Autograd Function 만들기
- 모든 differentiable operation은
torch.autograd.Function
class의 상속. - Custom autograd function 정의 시 필수 구현 (static) method:
forward()
: forward computation 수행.backward()
: gradient computation 구현.
import torch
from torch.autograd import Function
class CustomFunction(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
output = input.clone()
# 원하는 operation 수행
return output
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
# 원하는 gradient 계산
return grad_input
2-1. Custom Function 사용 방법
2-1-1. apply
정적 method를 통한 호출:
class MyCustomFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return x * y
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
return grad_output * y, grad_output * x
# 사용 예시
custom_op = MyCustomFunction.apply
a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0], requires_grad=True)
c = custom_op(a, b)
print("c.grad_fn:", c.grad_fn)
print("c.grad_fn.next_functions:", c.grad_fn.next_functions)
# 출력 결과:
# c.grad_fn: <MyCustomFunctionBackward object at 0x7f8a1d5c7e50>
# c.grad_fn.next_functions: ((<AccumulateGrad object at 0x7f8a1d5c7ee0>, 0), (<AccumulateGrad object at 0x7f8a1d5c7f70>, 0))
apply
method의 역할:forward
method 호출과 순방향 계산 수행.- 계산 graph 구성 및
backward
method와의 연결. - 결과 tensor의 반환과 grad_fn 설정.
2-1-2. Callable 객체를 통한 보다 직관적인 사용법
- Function class를 직접 호출 가능한 형태로 래핑하는 방법:
# 방법 1: callable class 활용
class CustomOp:
def __call__(self, x, y):
return MyCustomFunction.apply(x, y)
# 인스턴스 생성
custom_op = CustomOp()
result = custom_op(a, b) # 일반 함수처럼 호출
# 방법 2: 함수 래퍼 활용 (보다 간결한 접근)
def custom_op_func(x, y):
return MyCustomFunction.apply(x, y)
result = custom_op_func(a, b) # 일반 함수처럼 호출
- Callable 래퍼 사용의 장점:
.apply
method 직접 호출 불필요.- 더 직관적인 함수 호출 구문 제공.
- 함수형 프로그래밍 패턴과의 일관성.
3. next_functions의 자동 상속 메커니즘
torch.autograd.Function
상속 시 next_functions attribute의 자동 포함.- User가 명시적으로 next_functions 구현 불필요.
- PyTorch의 내부 process:
- Operation 결과에 해당 operation type의 grad_fn 연결.
- grad_fn의 next_functions에 input tensor들의 grad_fn 자동 저장.
3-1. next_functions attribute의 구조
- Type: Tuple of tuples -
((Function, int), (Function, int), ...)
- 각 tuple의 첫 번째 요소는 input tensor의 grad_fn, 두 번째 요소는 input index.
- Input index: 이전 operation이 여러 output을 생성하는 경우, 어떤 output이 현재 operation의 input으로 사용되었는지 지정.
- 대부분의 operation은 하나의 output만 생성하므로 index가 0인 경우가 일반적.
- 순서: Input tensor들이 forward 함수에 전달된 순서와 동일한 순서로 저장.
- None 값: requires_grad=False인 tensor의 경우 (None, 0) 형태로 저장.
3-2. 다중 output을 가진 operation의 예시
import torch
# 하나의 tensor를 두 부분으로 분할
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
a, b = torch.chunk(x, 2) # 두 개의 output 생성
c = a.sum() + b.sum()
print("c.grad_fn:", c.grad_fn)
print("c.grad_fn.next_functions:", c.grad_fn.next_functions)
# 출력 결과:
# c.grad_fn: <AddBackward0 object at 0x7f8a1d5c7ee0>
# c.grad_fn.next_functions: ((<SumBackward0 object at 0x7f8a1d5c7f70>, 0), (<SumBackward0 object at 0x7f8a1d5c7fd0>, 0))
# 각 SumBackward0의 next_functions 확인
for i, (grad_fn, idx) in enumerate(c.grad_fn.next_functions):
print(f"grad_fn[{i}].next_functions:", grad_fn.next_functions)
# 출력 결과:
# grad_fn[0].next_functions: ((<ChunkBackward0 object at 0x7f8a1d5c7b20>, 0),)
# grad_fn[1].next_functions: ((<ChunkBackward0 object at 0x7f8a1d5c7b20>, 1),)
- 위 예제에서
torch.chunk
는 하나의 tensor를 여러 부분으로 나누는 operation. torch.chunk(tensor, chunks, dim=0)
: 지정된 dimension을 따라 tensor를 chunks 개수만큼 균등하게 분할.- 분할된 각 chunk는 원본 tensor의 크기를 chunks로 나눈 크기를 가지며, 나누어 떨어지지 않는 경우 마지막 chunk가 더 작을 수 있음.
- 기본적으로 dim=0(첫 번째 dimension)을 기준으로 분할하나, dim parameter를 통해 다른 dimension 기준 분할 가능.
- 각 chunk에 대한 gradient 계산 시 ChunkBackward의 서로 다른 output(index 0과 1)을 참조.
4. 결론
- 이러한 mechanism을 통한 backpropagation 시 전체 computation graph의 추적 가능.
- Autograd system의 효율적인 gradient computation의 기반이 되어짐.
'Python' 카테고리의 다른 글
[Py] collections 모듈 (summary) - 작성중 (0) | 2025.04.04 |
---|---|
[Py] print 함수 (0) | 2025.04.02 |
[Py] Bitwise Operator (0) | 2025.03.26 |
[Py] 객체(object)에 대한 정보 확인하기 (0) | 2025.03.19 |
[Ex] Numeric Datatype 다음의 연산들의 결과를 구해보고 그 이유를 설명해보라. (0) | 2025.03.17 |