본문 바로가기
Python

[PyTorch] autograd 심화: grad_fn 과 custom operation 만들기

by ds31x 2025. 3. 28.

1. PyTorch Autograd 메커니즘 이해

1-1. Tensor의 grad_fn과 연산 그래프 추적 방법

  • PyTorch tensor의 grad_fn은 마지막 operation만 표시하는 attribute.
  • 전체 computation graph 확인을 위한 next_functions attribute의 활용.
  • 재귀적 접근을 통한 전체 operation history의 추적.
def print_grad_graph(grad_fn, level=0):
    print(' ' * level, grad_fn)
    if hasattr(grad_fn, 'next_functions'):
        for next_func in grad_fn.next_functions:
            if next_func[0] is not None:
                print_grad_graph(next_func[0], level + 1)

# 사용 예시
import torch
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)
z = x * y
w = z.sum()

print_grad_graph(w.grad_fn)
# 출력 결과:
#  <SumBackward0 object at 0x7f8a1d5c7b80>
#   <MulBackward0 object at 0x7f8a1d5c7c10>
#    <AccumulateGrad object at 0x7f8a1d5c7ca0>
#    <AccumulateGrad object at 0x7f8a1d5c7d30>

 

Operation type까지 포함한 상세 정보 확인 방법은 다음과 같음:

def print_grad_graph_detailed(grad_fn, level=0):
    print(' ' * level, type(grad_fn).__name__, grad_fn)
    if hasattr(grad_fn, 'next_functions'):
        for next_func in grad_fn.next_functions:
            if next_func[0] is not None:
                print_grad_graph_detailed(next_func[0], level + 1)

# 사용 예시
w = z.sum()  # 위의 예제에서 계속
print_grad_graph_detailed(w.grad_fn)
# 출력 결과:
#  SumBackward0 <SumBackward0 object at 0x7f8a1d5c7b80>
#   MulBackward0 <MulBackward0 object at 0x7f8a1d5c7c10>
#    AccumulateGrad <AccumulateGrad object at 0x7f8a1d5c7ca0>
#    AccumulateGrad <AccumulateGrad object at 0x7f8a1d5c7d30>

2. Custom Autograd Function 만들기

  • 모든 differentiable operation은 torch.autograd.Function class의 상속.
  • Custom autograd function 정의 시 필수 구현 (static) method:
    1. forward(): forward computation 수행.
    2. backward(): gradient computation 구현.
import torch
from torch.autograd import Function

class CustomFunction(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = input.clone()
        # 원하는 operation 수행
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # 원하는 gradient 계산
        return grad_input

2-1. Custom Function 사용 방법

2-1-1. apply 정적 method를 통한 호출:

class MyCustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * y
        
    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * y, grad_output * x

# 사용 예시
custom_op = MyCustomFunction.apply
a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0], requires_grad=True)
c = custom_op(a, b)

print("c.grad_fn:", c.grad_fn)
print("c.grad_fn.next_functions:", c.grad_fn.next_functions)
# 출력 결과:
# c.grad_fn: <MyCustomFunctionBackward object at 0x7f8a1d5c7e50>
# c.grad_fn.next_functions: ((<AccumulateGrad object at 0x7f8a1d5c7ee0>, 0), (<AccumulateGrad object at 0x7f8a1d5c7f70>, 0))
  • apply method의 역할:
    1. forward method 호출과 순방향 계산 수행.
    2. 계산 graph 구성 및 backward method와의 연결.
    3. 결과 tensor의 반환과 grad_fn 설정.

2-1-2. Callable 객체를 통한 보다 직관적인 사용법

  • Function class를 직접 호출 가능한 형태로 래핑하는 방법:
# 방법 1: callable class 활용
class CustomOp:
    def __call__(self, x, y):
        return MyCustomFunction.apply(x, y)

# 인스턴스 생성
custom_op = CustomOp()
result = custom_op(a, b)  # 일반 함수처럼 호출
# 방법 2: 함수 래퍼 활용 (보다 간결한 접근)
def custom_op_func(x, y):
    return MyCustomFunction.apply(x, y)

result = custom_op_func(a, b)  # 일반 함수처럼 호출
  • Callable 래퍼 사용의 장점:
    1. .apply method 직접 호출 불필요.
    2. 더 직관적인 함수 호출 구문 제공.
    3. 함수형 프로그래밍 패턴과의 일관성.

3. next_functions의 자동 상속 메커니즘

  • torch.autograd.Function 상속 시 next_functions attribute의 자동 포함.
  • User가 명시적으로 next_functions 구현 불필요.
  • PyTorch의 내부 process:
    1. Operation 결과에 해당 operation type의 grad_fn 연결.
    2. grad_fnnext_functions에 input tensor들의 grad_fn 자동 저장.

3-1. next_functions attribute의 구조

  • Type: Tuple of tuples - ((Function, int), (Function, int), ...)
  • 각 tuple의 첫 번째 요소는 input tensor의 grad_fn, 두 번째 요소는 input index.
    • Input index: 이전 operation이 여러 output을 생성하는 경우, 어떤 output이 현재 operation의 input으로 사용되었는지 지정.
    • 대부분의 operation은 하나의 output만 생성하므로 index가 0인 경우가 일반적.
  • 순서: Input tensor들이 forward 함수에 전달된 순서와 동일한 순서로 저장.
  • None 값: requires_grad=False인 tensor의 경우 (None, 0) 형태로 저장.

3-2. 다중 output을 가진 operation의 예시

import torch

# 하나의 tensor를 두 부분으로 분할
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
a, b = torch.chunk(x, 2)  # 두 개의 output 생성
c = a.sum() + b.sum()

print("c.grad_fn:", c.grad_fn)
print("c.grad_fn.next_functions:", c.grad_fn.next_functions)
# 출력 결과:
# c.grad_fn: <AddBackward0 object at 0x7f8a1d5c7ee0>
# c.grad_fn.next_functions: ((<SumBackward0 object at 0x7f8a1d5c7f70>, 0), (<SumBackward0 object at 0x7f8a1d5c7fd0>, 0))

# 각 SumBackward0의 next_functions 확인
for i, (grad_fn, idx) in enumerate(c.grad_fn.next_functions):
    print(f"grad_fn[{i}].next_functions:", grad_fn.next_functions)
# 출력 결과:
# grad_fn[0].next_functions: ((<ChunkBackward0 object at 0x7f8a1d5c7b20>, 0),)
# grad_fn[1].next_functions: ((<ChunkBackward0 object at 0x7f8a1d5c7b20>, 1),)
  • 위 예제에서 torch.chunk는 하나의 tensor를 여러 부분으로 나누는 operation.
  • torch.chunk(tensor, chunks, dim=0): 지정된 dimension을 따라 tensor를 chunks 개수만큼 균등하게 분할.
  • 분할된 각 chunk는 원본 tensor의 크기를 chunks로 나눈 크기를 가지며, 나누어 떨어지지 않는 경우 마지막 chunk가 더 작을 수 있음.
  • 기본적으로 dim=0(첫 번째 dimension)을 기준으로 분할하나, dim parameter를 통해 다른 dimension 기준 분할 가능.
  • 각 chunk에 대한 gradient 계산 시 ChunkBackward의 서로 다른 output(index 0과 1)을 참조.

4. 결론

  • 이러한 mechanism을 통한 backpropagation 시 전체 computation graph의 추적 가능.
  • Autograd system의 효율적인 gradient computation의 기반이 되어짐.