본문 바로가기
카테고리 없음

[DL] PyTorch: state_dict()

by ds31x 2024. 5. 16.

PyTorch: state_dict()

torch.nn.Module 객체의 state_dict() 메서드는

  • 모델의 학습 가능한 매개변수(가중치와 바이어스)의 상태와
  • 버퍼(예: BatchNorm의 running mean과 variance 등)의 상태를 저장하는
  • collections.OrderedDict 객체를 반환.

반환된 객체는 모델의 현재 상태를 나타내며, 저장 및 로드가 가능함.


주요 특징

  1. OrderedDict 형태:
    • state_dict()
    • attribute 이름을 키로 하고,
    • 그에 대응하는 torch.Tensor를 값으로 갖는
    • collections.OrderedDict 객체를 반환.
  2. 학습 가능한 매개변수:
    • state_dict()
    • torch.nn.Parameter 객체로 정의된 모든 학습 가능한 attributes를 포함.
  3. 버퍼:
    • 모델에 포함된 모든 버퍼(예: BatchNorm 계층의 running mean과 variance)도 포함.

예제 코드

다음은 state_dict()가 반환하는 객체에 대한 예제임.

import torch
import torch.nn as nn
from collections import OrderedDict

# 간단한 모델 정의
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 5)
        self.bn = nn.BatchNorm1d(5)
        # self.sub = nn.Sequential(
        #     OrderedDict({
        #         'ds00_layer': nn.Linear(5,5),
        #         'ds00_bn':nn.BatchNorm1d(5),
        #         'ds00_act': nn.ReLU(),
        #         'ds01_layer': nn.Linear(5,5),
        #     })
        # )

    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        return x

# 모델 인스턴스 생성
model = MyModel()

# 모델의 state_dict 가져오기
state_dict = model.state_dict()

# state_dict 내용 출력
for key, value in state_dict.items():
    print(f"{key}: {value.shape}")

https://gist.github.com/dsaint31x/3e0f42c7ec25380dd947776f589522ac

 

dl_torch_state_dict_keep_vars.ipynb

dl_torch_state_dict_keep_vars.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com

 

메서드 parameters: state_dict(desitnation=None, prefix='', keep_vars=False)

  • keep_vars 는 기본값이 False로 buffers와 parameters 의 값만을 추출할지를 결정
    • keep_vars=True 인 경우, 값 대신 tensor객체로 데이터 버퍼를 가지고 있는 dictionary가 반환됨.
      • value 가 파라메터인 경우엔 Parameter 로 얻어지고,
      • value 가 버퍼인 경우엔 Tensor 로 얻어짐.
    • keep_vars=True 인 경우, 메모리 사용량이 커지고, 매우 느리고 복잡한 동작이 이루어지지만. 다음의 장점을 가짐.
      • 모델 디버깅: 모델 상태를 조사하고 특정 매개 변수나 버퍼의 값을 변경해야 하는 경우 유용
      • 모델 커스터마이징: 모델을 불러온 후 특정 매개 변수나 버퍼의 값을 변경해야 하는 경우 유용
      • 모델 저장 및 불러오기 확장: 모델 저장 및 불러오기 프로세스를 확장하고 추가적인 정보를 저장해야 하는 경우 유용
    • 하지만, PyTorch의 버전이 정확히 맞아야만 동작할 수 있는 등의 제한점을 가짐.
    • 저장의 용도로는 keep_vars=False를 사용하는 게 좋음.

출력 예시

linear.weight: torch.Size([5, 10])
linear.bias: torch.Size([5])
bn.weight: torch.Size([5])
bn.bias: torch.Size([5])
bn.running_mean: torch.Size([5])
bn.running_var: torch.Size([5])
bn.num_batches_tracked: torch.Size([])

설명

  • linear.weight: nn.Linear 계층의 가중치 파라미터
  • linear.bias: nn.Linear 계층의 바이어스 파라미터
  • bn.weight: nn.BatchNorm1d 계층의 가중치 파라미터
  • bn.bias: nn.BatchNorm1d 계층의 바이어스 파라미터
  • bn.running_mean: nn.BatchNorm1d 계층의 running mean 버퍼
  • bn.running_var: nn.BatchNorm1d 계층의 running variance 버퍼
  • bn.num_batches_tracked: nn.BatchNorm1d 계층에서 배치의 수를 추적하는 버퍼

활용 방법

모델 저장:

state_dict를 파일에 저장하여 나중에 모델을 복원할 수 있음.

torch.save(model.state_dict(), 'model_state.pth')

 

모델 로드:

저장된 state_dict를 로드하여 모델을 복원.

model = MyModel() 
model.load_state_dict(torch.load('model_state.pth')) 
model.eval() # 평가 모드로 전환 (선택 사항)`

 

파라미터 업데이트:

state_dict를 사용하여 모델의 특정 파라미터를 업데이트할 수 있음.

state_dict['linear.weight'] = torch.ones_like(state_dict['linear.weight']) 
model.load_state_dict(state_dict)

결론

state_dict()는 PyTorch에서 모델의 학습 가능한 매개변수와 버퍼의 상태를 관리하는 데 중요한 역할을 당담함.
이를 통해 모델을 쉽게 저장하고 로드할 수 있으며, 모델 파라미터의 직접적인 접근 및 수정도 가능합니다.

728x90