본문 바로가기
목차
카테고리 없음

[DL] PyTorch: state_dict()

by ds31x 2024. 5. 16.
728x90
반응형

PyTorch: state_dict()

torch.nn.Module 객체의 state_dict() 메서드는

  • 모델의 학습 가능한 매개변수(가중치와 바이어스)의 상태와
  • 버퍼(예: BatchNorm의 running mean과 variance 등)의 상태를 저장하는
  • collections.OrderedDict 객체를 반환.

반환된 객체는 모델의 현재 상태를 나타내며, 저장 및 로드가 가능함.

2025.04.04 - [Python] - [Py] collection.OrderedDict

 

[Py] collection.OrderedDict

OrderedDict는 삽입된 순서를 보존하는 기능을 추가한 일종의 dict임.collections 모듈에서 제공.Python 3.7부터 built-in dict도 삽입 순서를 보존하게 되었음.하지만 OrderedDict는 그 외에 몇 가지 중요한 추가

ds31x.tistory.com


주요 특징

  1. OrderedDict 형태:
    • state_dict()
    • attribute 이름을 키로 하고,
    • 그에 대응하는 torch.Tensor를 값으로 갖는
    • collections.OrderedDict 객체를 반환.
  2. 학습 가능한 매개변수:
    • state_dict()
    • torch.nn.Parameter 객체로 정의된 모든 학습 가능한 attributes를 포함.
  3. 버퍼:
    • 모델에 포함된 모든 버퍼(예: BatchNorm 계층의 running mean과 variance)도 포함.

예제 코드

다음은 state_dict()가 반환하는 객체에 대한 예제임.

import torch
import torch.nn as nn
from collections import OrderedDict

# 간단한 신경망 모델 클래스 정의
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 10차원 입력을 5차원으로 변환하는 선형 레이어 정의
        self.linear = nn.Linear(10, 5)
        # 5차원 특성에 대한 배치 정규화 레이어 정의
        self.bn = nn.BatchNorm1d(5)
        
    def forward(self, x):
        # 선형 변환 적용
        x = self.linear(x)
        # 배치 정규화 적용
        x = self.bn(x)
        return x

# 모델 인스턴스 생성
model = MyModel()

# 모델의 state_dict 획득
state_dict = model.state_dict()

# 파라미터와 버퍼 구분을 위한 출력 섹션

print("=== 파라미터 (학습 가능한 가중치) ===")
# named_parameters() 메서드를 통한 모든 파라미터 출력
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}, requires_grad={param.requires_grad}")

print("\n=== 버퍼 (학습 불가능한 상태 값) ===")
# named_buffers() 메서드를 통한 모든 버퍼 출력
for name, buf in model.named_buffers():
    print(f"{name}: {buf.shape}, requires_grad={buf.requires_grad}")

print("\n=== 전체 state_dict 내용 ===")
# 파라미터와 버퍼 키 목록 생성
param_keys = [name for name, _ in model.named_parameters()]
buffer_keys = [name for name, _ in model.named_buffers()]

# state_dict의 각 항목에 대한 유형 구분
for key, value in state_dict.items():
    if key in param_keys:
        print(f"{key}: {value.shape} (파라미터)")
    elif key in buffer_keys:
        print(f"{key}: {value.shape} (버퍼)")
    else:
        print(f"{key}: {value.shape} (기타)")

# 예시 출력 설명:
# - linear.weight, linear.bias: 선형 레이어의 학습 가능한 파라미터
# - bn.weight, bn.bias: 배치 정규화의 학습 가능한 파라미터
# - bn.running_mean, bn.running_var: 배치 정규화의 통계적 버퍼 값
# - bn.num_batches_tracked: 배치 정규화의 추적용 버퍼 값

https://gist.github.com/dsaint31x/3e0f42c7ec25380dd947776f589522ac

 

dl_torch_state_dict_keep_vars.ipynb

dl_torch_state_dict_keep_vars.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com


메서드 parameters:

state_dict(destination=None, prefix='', keep_vars=False)->OrderedDict

  • keep_vars 는 기본값이 False로 buffers와 parameters 의 값만을 추출할지를 결정
    • keep_vars=True 인 경우, 값 대신 tensor객체로 데이터 버퍼를 가지고 있는 dictionary가 반환됨.
      • value 가 파라메터인 경우엔 parameter 객체로 얻어지고,
      • value 가 버퍼인 경우엔 tensor 로 얻어짐.
    • keep_vars=True 인 경우, 메모리 사용량이 커지고, 매우 느리고 복잡한 동작이 이루어지지만. 다음의 장점을 가짐.
      • 모델 디버깅: 모델 상태를 조사하고 특정 parameterbuffer의 값을 변경을 가능케 함.
      • 모델 커스터마이징: 모델을 불러온 후 특정 parameter buffer의 값을 변경해야 하는 경우 유용
      • 모델 저장 및 불러오기 확장: 모델 저장 및 불러오기 프로세스를 확장하고 추가적인 정보를 저장해야 하는 경우 사용.
    • 하지만, PyTorch의 버전이 정확히 맞아야만 동작할 수 있는 등의 제한점을 가짐.
    • 저장의 용도로는 keep_vars=False를 사용하는 게 좋음.
  • destination
    • 기본값: None
    • 역할: 상태 딕셔너리를 저장할 대상 OrderedDict
    • 설명:
      • 이 매개변수가 지정되면, 메서드는 새 OrderedDict 객체를 생성하지 않고 이 딕셔너리에 상태를 추가함.
      • 일반적으로 None으로 두면 자동으로 새 OrderedDict 객체를 생성.
    • 활용 예: 여러 모델의 상태를 하나의 딕셔너리에 합칠 때 유용.
  • prefix
    • 기본값: '' (빈 문자열)
    • 역할: 모든 파라미터 키의 접두사
    • 설명: 상태 딕셔너리의 모든 키 앞에 붙는 문자열로, 일반적으로 모듈 계층 구조를 나타낼 때 사용됨.
    • 활용 예: 서브모듈의 상태를 구분하기 위해 'encoder.', 'decoder.' 등의 접두사를 사용할 수 있음.
더보기

destination과 prefix를 사용한 예제코드

import torch
import torch.nn as nn
from collections import OrderedDict

# 간단한 모델 정의
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 5)
        self.bn = nn.BatchNorm1d(5)
        
    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        return x

# 더 복잡한 모델 정의 (서브모듈 포함)
class ComplexModel(nn.Module):
    def __init__(self):
        super(ComplexModel, self).__init__()
        self.encoder = SimpleModel()
        self.decoder = nn.Linear(5, 2)
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# 모델 인스턴스 생성
simple_model = SimpleModel()
complex_model = ComplexModel()

# 예시 1: 기본 사용
print("=== 예시 1: 기본 사용 ===")
state_dict = simple_model.state_dict()
print("기본 state_dict 타입:", type(state_dict))
print("기본 state_dict 내용:")
for key, value in state_dict.items():
    print(f"  {key}: {value.shape}")

# 예시 2: prefix 매개변수 사용
print("\n=== 예시 2: prefix 매개변수 사용 ===")
prefixed_dict = simple_model.state_dict(prefix='model1.')
print("prefix 적용된 state_dict 내용:")
for key, value in prefixed_dict.items():
    print(f"  {key}: {value.shape}")

# 예시 3: destination 매개변수 사용
print("\n=== 예시 3: destination 매개변수 사용 ===")
combined_dict = OrderedDict()
simple_model.state_dict(destination=combined_dict, prefix='model1.')
print("첫 번째 모델 추가 후 destination 내용:")
for key, value in combined_dict.items():
    print(f"  {key}: {value.shape}")

# 두 번째 모델도 같은 destination에 추가
simple_model.state_dict(destination=combined_dict, prefix='model2.')
print("\n두 번째 모델 추가 후 destination 내용:")
for key, value in combined_dict.items():
    print(f"  {key}: {value.shape}")

# 예시 4: keep_vars 매개변수 사용
print("\n=== 예시 4: keep_vars 매개변수 사용 ===")
vars_dict = simple_model.state_dict(keep_vars=True)
print("keep_vars=True 적용된 state_dict 내용:")
for key, tensor in vars_dict.items():
    print(f"  {key}: 형태={tensor.shape}, 텐서 객체={type(tensor)}, requires_grad={tensor.requires_grad}")

# 예시 5: 복잡한 모델 계층 구조
print("\n=== 예시 5: 복잡한 모델 계층 구조 ===")
complex_dict = complex_model.state_dict()
print("복잡한 모델의 state_dict 내용:")
for key, value in complex_dict.items():
    print(f"  {key}: {value.shape}")

 

실행결과

=== 예시 1: 기본 사용 ===
기본 state_dict 타입: <class 'collections.OrderedDict'>
기본 state_dict 내용:
  linear.weight: torch.Size([5, 10])
  linear.bias: torch.Size([5])
  bn.weight: torch.Size([5])
  bn.bias: torch.Size([5])
  bn.running_mean: torch.Size([5])
  bn.running_var: torch.Size([5])
  bn.num_batches_tracked: torch.Size([])

=== 예시 2: prefix 매개변수 사용 ===
prefix 적용된 state_dict 내용:
  model1.linear.weight: torch.Size([5, 10])
  model1.linear.bias: torch.Size([5])
  model1.bn.weight: torch.Size([5])
  model1.bn.bias: torch.Size([5])
  model1.bn.running_mean: torch.Size([5])
  model1.bn.running_var: torch.Size([5])
  model1.bn.num_batches_tracked: torch.Size([])

=== 예시 3: destination 매개변수 사용 ===
첫 번째 모델 추가 후 destination 내용:
  model1.linear.weight: torch.Size([5, 10])
  model1.linear.bias: torch.Size([5])
  model1.bn.weight: torch.Size([5])
  model1.bn.bias: torch.Size([5])
  model1.bn.running_mean: torch.Size([5])
  model1.bn.running_var: torch.Size([5])
  model1.bn.num_batches_tracked: torch.Size([])

두 번째 모델 추가 후 destination 내용:
  model1.linear.weight: torch.Size([5, 10])
  model1.linear.bias: torch.Size([5])
  model1.bn.weight: torch.Size([5])
  model1.bn.bias: torch.Size([5])
  model1.bn.running_mean: torch.Size([5])
  model1.bn.running_var: torch.Size([5])
  model1.bn.num_batches_tracked: torch.Size([])
  model2.linear.weight: torch.Size([5, 10])
  model2.linear.bias: torch.Size([5])
  model2.bn.weight: torch.Size([5])
  model2.bn.bias: torch.Size([5])
  model2.bn.running_mean: torch.Size([5])
  model2.bn.running_var: torch.Size([5])
  model2.bn.num_batches_tracked: torch.Size([])

=== 예시 4: keep_vars 매개변수 사용 ===
keep_vars=True 적용된 state_dict 내용:
  linear.weight: 형태=torch.Size([5, 10]), 텐서 객체=<class 'torch.nn.parameter.Parameter'>, requires_grad=True
  linear.bias: 형태=torch.Size([5]), 텐서 객체=<class 'torch.nn.parameter.Parameter'>, requires_grad=True
  bn.weight: 형태=torch.Size([5]), 텐서 객체=<class 'torch.nn.parameter.Parameter'>, requires_grad=True
  bn.bias: 형태=torch.Size([5]), 텐서 객체=<class 'torch.nn.parameter.Parameter'>, requires_grad=True
  bn.running_mean: 형태=torch.Size([5]), 텐서 객체=<class 'torch.Tensor'>, requires_grad=False
  bn.running_var: 형태=torch.Size([5]), 텐서 객체=<class 'torch.Tensor'>, requires_grad=False
  bn.num_batches_tracked: 형태=torch.Size([]), 텐서 객체=<class 'torch.Tensor'>, requires_grad=False

=== 예시 5: 복잡한 모델 계층 구조 ===
복잡한 모델의 state_dict 내용:
  encoder.linear.weight: torch.Size([5, 10])
  encoder.linear.bias: torch.Size([5])
  encoder.bn.weight: torch.Size([5])
  encoder.bn.bias: torch.Size([5])
  encoder.bn.running_mean: torch.Size([5])
  encoder.bn.running_var: torch.Size([5])
  encoder.bn.num_batches_tracked: torch.Size([])
  decoder.weight: torch.Size([2, 5])
  decoder.bias: torch.Size([2])

 

  • 기본 사용: state_dict()는 모든 파라미터와 버퍼를 OrderedDict 형태로 반환.
  • prefix 사용: 모든 키 앞에 지정된 접두사가 붙습니다. 이는 여러 모델의 상태를 구분할 때 유용함.
  • destination 사용:
    • 같은 OrderedDict에 여러 모델의 상태를 합칠 수 있음.
    • 이는 앙상블 모델이나 복합 모델 저장 시 유용함.
  • keep_vars 사용:
    • True로 설정하면 값 대신 실제 텐서 객체가 저장됨.
    • 파라미터(nn.Parameter)와 버퍼(torch.Tensor)의 구분이 명확하게 보임.
  • 계층적 구조: 복잡한 모델에서는 서브모듈의 이름이 자동으로 키의 접두사로 사용됨.

 


출력 예시

=== 파라미터 (학습 가능한 가중치) ===
linear.weight: torch.Size([5, 10]), requires_grad=True
linear.bias: torch.Size([5]), requires_grad=True
bn.weight: torch.Size([5]), requires_grad=True
bn.bias: torch.Size([5]), requires_grad=True

=== 버퍼 (학습 불가능한 상태 값) ===
bn.running_mean: torch.Size([5]), requires_grad=False
bn.running_var: torch.Size([5]), requires_grad=False
bn.num_batches_tracked: torch.Size([]), requires_grad=False

=== 전체 state_dict 내용 ===
linear.weight: torch.Size([5, 10]) (파라미터)
linear.bias: torch.Size([5]) (파라미터)
bn.weight: torch.Size([5]) (파라미터)
bn.bias: torch.Size([5]) (파라미터)
bn.running_mean: torch.Size([5]) (버퍼)
bn.running_var: torch.Size([5]) (버퍼)
bn.num_batches_tracked: torch.Size([]) (버퍼)

설명

  • linear.weight: nn.Linear 계층의 가중치 파라미터
  • linear.bias: nn.Linear 계층의 바이어스 파라미터
  • bn.weight: nn.BatchNorm1d 계층의 가중치 파라미터
  • bn.bias: nn.BatchNorm1d 계층의 바이어스 파라미터
  • bn.running_mean: nn.BatchNorm1d 계층의 running mean 버퍼
  • bn.running_var: nn.BatchNorm1d 계층의 running variance 버퍼
  • bn.num_batches_tracked: nn.BatchNorm1d 계층에서 배치의 수를 추적하는 버퍼

기타?

일반적으로 PyTorch 모델의 state_dict에 포함된 텐서는 파라미터(named_parameters())와 버퍼(named_buffers())의 두 가지 카테고리로 구분되며, 이 두 가지에 속하지 않는 "기타" 항목은 실제로는 거의 발생하지 않음.

 

위의 코드에서 "기타" 카테고리를 포함한 이유는 다음과 같은 특별한 경우의 값들을 다루기 위해서임.

  1. 커스텀 등록된 상태 값: register_state_dict_hook을 사용하여 사용자가 임의로 state_dict에 추가한 값.
  2. 미래 버전 호환성: PyTorch의 향후 버전에서 새로운 유형의 상태 값이 추가될 가능성에 대비한 값.
  3. 외부 모듈 또는 확장 기능: 일부 PyTorch 확장 또는 외부 라이브러리가 표준 파라미터나 버퍼가 아닌 다른 형태의 상태를 추가할 수도 있음.
  4. 모델 양자화 관련 정보: 양자화된 모델의 경우 양자화 스케일이나 제로 포인트와 같은 추가 정보가 저장될 수 있음.

일반적인 PyTorch 모델(예: 제공된 MyModel 클래스)에서는 모든 상태가 parameter나 buffer로 분류되므로 "기타" 카테고리에 속하는 항목은 출력되지 않음.

그러나 확장성을 위해 이 둘에 속하지 않는 값이 있을 수 있음.

더보기

기타 항목 추가하기

  • 아래의 코드는 state_dict 메서드를 override하여 커스텀 항목 'custom_item'을 추가함.
  • 이 항목은 named_parameters()named_buffers()에 포함되지 않으므로 "기타" 카테고리로 분류됨.
import torch
import torch.nn as nn
from collections import OrderedDict

# 간단한 신경망 모델 클래스 정의
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 10차원 입력을 5차원으로 변환하는 선형 레이어 정의
        self.linear = nn.Linear(10, 5)
        # 5차원 특성에 대한 배치 정규화 레이어 정의
        self.bn = nn.BatchNorm1d(5)
        
        # 일반적인 속성으로 텐서 추가 (state_dict에 자동으로 포함되지 않음)
        self.custom_tensor = torch.ones(3, 4)
        
    def forward(self, x):
        # 선형 변환 적용
        x = self.linear(x)
        # 배치 정규화 적용
        x = self.bn(x)
        return x
    
    # state_dict 메서드 오버라이드
    def state_dict(self, *args, **kwargs):
        # 기본 state_dict 획득
        state_dict_orig = super().state_dict(*args, **kwargs)
        
        # 커스텀 state_dict 생성 및 원본 복사
        state_dict_new = OrderedDict(state_dict_orig)
        
        # 커스텀 항목 추가
        state_dict_new['custom_item'] = self.custom_tensor
        
        return state_dict_new

# 모델 인스턴스 생성
model = MyModel()

# 모델의 state_dict 획득
state_dict = model.state_dict()

# 파라미터와 버퍼 구분을 위한 출력 섹션
print("=== 파라미터 (학습 가능한 가중치) ===")
# named_parameters() 메서드를 통한 모든 파라미터 출력
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}, requires_grad={param.requires_grad}")

print("\n=== 버퍼 (학습 불가능한 상태 값) ===")
# named_buffers() 메서드를 통한 모든 버퍼 출력
for name, buf in model.named_buffers():
    print(f"{name}: {buf.shape}, requires_grad={buf.requires_grad}")

print("\n=== 전체 state_dict 내용 ===")
# 파라미터와 버퍼 키 목록 생성
param_keys = [name for name, _ in model.named_parameters()]
buffer_keys = [name for name, _ in model.named_buffers()]

# state_dict의 각 항목에 대한 유형 구분
for key, value in state_dict.items():
    if key in param_keys:
        print(f"{key}: {value.shape} (파라미터)")
    elif key in buffer_keys:
        print(f"{key}: {value.shape} (버퍼)")
    else:
        print(f"{key}: {value.shape} (기타)")

 

결과는 다음과 같음.

=== 파라미터 (학습 가능한 가중치) ===
linear.weight: torch.Size([5, 10]), requires_grad=True
linear.bias: torch.Size([5]), requires_grad=True
bn.weight: torch.Size([5]), requires_grad=True
bn.bias: torch.Size([5]), requires_grad=True

=== 버퍼 (학습 불가능한 상태 값) ===
bn.running_mean: torch.Size([5]), requires_grad=False
bn.running_var: torch.Size([5]), requires_grad=False
bn.num_batches_tracked: torch.Size([]), requires_grad=False

=== 전체 state_dict 내용 ===
linear.weight: torch.Size([5, 10]) (파라미터)
linear.bias: torch.Size([5]) (파라미터)
bn.weight: torch.Size([5]) (파라미터)
bn.bias: torch.Size([5]) (파라미터)
bn.running_mean: torch.Size([5]) (버퍼)
bn.running_var: torch.Size([5]) (버퍼)
bn.num_batches_tracked: torch.Size([]) (버퍼)
custom_item: torch.Size([3, 4]) (기타)

활용 방법

모델 저장:

state_dict를 파일에 저장하여 나중에 모델을 복원할 수 있음.

torch.save(model.state_dict(), 'model_state.pth')

 

모델 로드:

저장된 state_dict를 로드하여 모델을 복원.

model = MyModel() 
model.load_state_dict(torch.load('model_state.pth')) 
model.eval() # 평가 모드로 전환 (선택 사항)`

 

파라미터 업데이트:

state_dict를 사용하여 모델의 특정 파라미터를 업데이트할 수 있음.

state_dict['linear.weight'] = torch.ones_like(state_dict['linear.weight']) 
model.load_state_dict(state_dict)

결론

state_dict()는 PyTorch에서 모델의 학습 가능한 매개변수와 버퍼의 상태를 관리하는 데 중요한 역할을 당담함.
이를 통해 모델을 쉽게 저장하고 로드할 수 있으며, 모델 파라미터의 직접적인 접근 및 수정도 가능합니다.

728x90