PyTorch: state_dict()
torch.nn.Module
객체의 state_dict()
메서드는
- 모델의 학습 가능한 매개변수(가중치와 바이어스)의 상태와
- 버퍼(예: BatchNorm의 running mean과 variance 등)의 상태를 저장하는
collections.OrderedDict
객체를 반환.
반환된 객체는 모델의 현재 상태를 나타내며, 저장 및 로드가 가능함.
주요 특징
- OrderedDict 형태:
state_dict()
는- attribute 이름을 키로 하고,
- 그에 대응하는
torch.Tensor
를 값으로 갖는 collections.OrderedDict
객체를 반환.
- 학습 가능한 매개변수:
state_dict()
는torch.nn.Parameter
객체로 정의된 모든 학습 가능한 attributes를 포함.
- 버퍼:
- 모델에 포함된 모든 버퍼(예:
BatchNorm
계층의 running mean과 variance)도 포함.
- 모델에 포함된 모든 버퍼(예:
예제 코드
다음은 state_dict()
가 반환하는 객체에 대한 예제임.
import torch
import torch.nn as nn
from collections import OrderedDict
# 간단한 모델 정의
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = nn.Linear(10, 5)
self.bn = nn.BatchNorm1d(5)
# self.sub = nn.Sequential(
# OrderedDict({
# 'ds00_layer': nn.Linear(5,5),
# 'ds00_bn':nn.BatchNorm1d(5),
# 'ds00_act': nn.ReLU(),
# 'ds01_layer': nn.Linear(5,5),
# })
# )
def forward(self, x):
x = self.linear(x)
x = self.bn(x)
return x
# 모델 인스턴스 생성
model = MyModel()
# 모델의 state_dict 가져오기
state_dict = model.state_dict()
# state_dict 내용 출력
for key, value in state_dict.items():
print(f"{key}: {value.shape}")
https://gist.github.com/dsaint31x/3e0f42c7ec25380dd947776f589522ac
메서드 parameters: state_dict(desitnation=None, prefix='', keep_vars=False)
- keep_vars 는 기본값이 False로 buffers와 parameters 의 값만을 추출할지를 결정
- keep_vars=True 인 경우, 값 대신 tensor객체로 데이터 버퍼를 가지고 있는 dictionary가 반환됨.
- value 가 파라메터인 경우엔 Parameter 로 얻어지고,
- value 가 버퍼인 경우엔 Tensor 로 얻어짐.
- keep_vars=True 인 경우, 메모리 사용량이 커지고, 매우 느리고 복잡한 동작이 이루어지지만. 다음의 장점을 가짐.
- 모델 디버깅: 모델 상태를 조사하고 특정 매개 변수나 버퍼의 값을 변경해야 하는 경우 유용
- 모델 커스터마이징: 모델을 불러온 후 특정 매개 변수나 버퍼의 값을 변경해야 하는 경우 유용
- 모델 저장 및 불러오기 확장: 모델 저장 및 불러오기 프로세스를 확장하고 추가적인 정보를 저장해야 하는 경우 유용
- 하지만, PyTorch의 버전이 정확히 맞아야만 동작할 수 있는 등의 제한점을 가짐.
- 저장의 용도로는 keep_vars=False를 사용하는 게 좋음.
- keep_vars=True 인 경우, 값 대신 tensor객체로 데이터 버퍼를 가지고 있는 dictionary가 반환됨.
출력 예시
linear.weight: torch.Size([5, 10])
linear.bias: torch.Size([5])
bn.weight: torch.Size([5])
bn.bias: torch.Size([5])
bn.running_mean: torch.Size([5])
bn.running_var: torch.Size([5])
bn.num_batches_tracked: torch.Size([])
설명
linear.weight
:nn.Linear
계층의 가중치 파라미터linear.bias
:nn.Linear
계층의 바이어스 파라미터bn.weight
:nn.BatchNorm1d
계층의 가중치 파라미터bn.bias
:nn.BatchNorm1d
계층의 바이어스 파라미터bn.running_mean
:nn.BatchNorm1d
계층의 running mean 버퍼bn.running_var
:nn.BatchNorm1d
계층의 running variance 버퍼bn.num_batches_tracked
:nn.BatchNorm1d
계층에서 배치의 수를 추적하는 버퍼
활용 방법
모델 저장:
state_dict
를 파일에 저장하여 나중에 모델을 복원할 수 있음.
torch.save(model.state_dict(), 'model_state.pth')
모델 로드:
저장된 state_dict
를 로드하여 모델을 복원.
model = MyModel()
model.load_state_dict(torch.load('model_state.pth'))
model.eval() # 평가 모드로 전환 (선택 사항)`
파라미터 업데이트:
state_dict
를 사용하여 모델의 특정 파라미터를 업데이트할 수 있음.
state_dict['linear.weight'] = torch.ones_like(state_dict['linear.weight'])
model.load_state_dict(state_dict)
결론
state_dict()
는 PyTorch에서 모델의 학습 가능한 매개변수와 버퍼의 상태를 관리하는 데 중요한 역할을 당담함.
이를 통해 모델을 쉽게 저장하고 로드할 수 있으며, 모델 파라미터의 직접적인 접근 및 수정도 가능합니다.