PyTorch: state_dict()
torch.nn.Module 객체의 state_dict() 메서드는
- 모델의 학습 가능한 매개변수(가중치와 바이어스)의 상태와
- 버퍼(예: BatchNorm의 running mean과 variance 등)의 상태를 저장하는
collections.OrderedDict객체를 반환.
반환된 객체는 모델의 현재 상태를 나타내며, 저장 및 로드가 가능함.
2025.04.04 - [Python] - [Py] collection.OrderedDict
[Py] collection.OrderedDict
OrderedDict는 삽입된 순서를 보존하는 기능을 추가한 일종의 dict임.collections 모듈에서 제공.Python 3.7부터 built-in dict도 삽입 순서를 보존하게 되었음.하지만 OrderedDict는 그 외에 몇 가지 중요한 추가
ds31x.tistory.com
주요 특징
- OrderedDict 형태:
state_dict()는- attribute 이름을 키로 하고,
- 그에 대응하는
torch.Tensor를 값으로 갖는 collections.OrderedDict객체를 반환.
- 학습 가능한 매개변수:
state_dict()는torch.nn.Parameter객체로 정의된 모든 학습 가능한 attributes를 포함.
- 버퍼:
- 모델에 포함된 모든 버퍼(예:
BatchNorm계층의 running mean과 variance)도 포함.
- 모델에 포함된 모든 버퍼(예:
예제 코드
다음은 state_dict()가 반환하는 객체에 대한 예제임.
import torch
import torch.nn as nn
from collections import OrderedDict
# 간단한 신경망 모델 클래스 정의
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 10차원 입력을 5차원으로 변환하는 선형 레이어 정의
self.linear = nn.Linear(10, 5)
# 5차원 특성에 대한 배치 정규화 레이어 정의
self.bn = nn.BatchNorm1d(5)
def forward(self, x):
# 선형 변환 적용
x = self.linear(x)
# 배치 정규화 적용
x = self.bn(x)
return x
# 모델 인스턴스 생성
model = MyModel()
# 모델의 state_dict 획득
state_dict = model.state_dict()
# 파라미터와 버퍼 구분을 위한 출력 섹션
print("=== 파라미터 (학습 가능한 가중치) ===")
# named_parameters() 메서드를 통한 모든 파라미터 출력
for name, param in model.named_parameters():
print(f"{name}: {param.shape}, requires_grad={param.requires_grad}")
print("\n=== 버퍼 (학습 불가능한 상태 값) ===")
# named_buffers() 메서드를 통한 모든 버퍼 출력
for name, buf in model.named_buffers():
print(f"{name}: {buf.shape}, requires_grad={buf.requires_grad}")
print("\n=== 전체 state_dict 내용 ===")
# 파라미터와 버퍼 키 목록 생성
param_keys = [name for name, _ in model.named_parameters()]
buffer_keys = [name for name, _ in model.named_buffers()]
# state_dict의 각 항목에 대한 유형 구분
for key, value in state_dict.items():
if key in param_keys:
print(f"{key}: {value.shape} (파라미터)")
elif key in buffer_keys:
print(f"{key}: {value.shape} (버퍼)")
else:
print(f"{key}: {value.shape} (기타)")
# 예시 출력 설명:
# - linear.weight, linear.bias: 선형 레이어의 학습 가능한 파라미터
# - bn.weight, bn.bias: 배치 정규화의 학습 가능한 파라미터
# - bn.running_mean, bn.running_var: 배치 정규화의 통계적 버퍼 값
# - bn.num_batches_tracked: 배치 정규화의 추적용 버퍼 값
https://gist.github.com/dsaint31x/3e0f42c7ec25380dd947776f589522ac
dl_torch_state_dict_keep_vars.ipynb
dl_torch_state_dict_keep_vars.ipynb. GitHub Gist: instantly share code, notes, and snippets.
gist.github.com
메서드 parameters:
state_dict(destination=None, prefix='', keep_vars=False)->OrderedDict
- keep_vars 는 기본값이 False로 buffers와 parameters 의 값만을 추출할지를 결정
- keep_vars=True 인 경우, 값 대신 tensor객체로 데이터 버퍼를 가지고 있는 dictionary가 반환됨.
- value 가 파라메터인 경우엔 parameter 객체로 얻어지고,
- value 가 버퍼인 경우엔 tensor 로 얻어짐.
- keep_vars=True 인 경우, 메모리 사용량이 커지고, 매우 느리고 복잡한 동작이 이루어지지만. 다음의 장점을 가짐.
- 모델 디버깅: 모델 상태를 조사하고 특정 parameter나 buffer의 값을 변경을 가능케 함.
- 모델 커스터마이징: 모델을 불러온 후 특정 parameter나 buffer의 값을 변경해야 하는 경우 유용
- 모델 저장 및 불러오기 확장: 모델 저장 및 불러오기 프로세스를 확장하고 추가적인 정보를 저장해야 하는 경우 사용.
- 하지만, PyTorch의 버전이 정확히 맞아야만 동작할 수 있는 등의 제한점을 가짐.
- 저장의 용도로는 keep_vars=False를 사용하는 게 좋음.
- keep_vars=True 인 경우, 값 대신 tensor객체로 데이터 버퍼를 가지고 있는 dictionary가 반환됨.
- destination
- 기본값: None
- 역할: 상태 딕셔너리를 저장할 대상 OrderedDict
- 설명:
- 이 매개변수가 지정되면, 메서드는 새 OrderedDict 객체를 생성하지 않고 이 딕셔너리에 상태를 추가함.
- 일반적으로 None으로 두면 자동으로 새 OrderedDict 객체를 생성.
- 활용 예: 여러 모델의 상태를 하나의 딕셔너리에 합칠 때 유용.
- prefix
- 기본값: '' (빈 문자열)
- 역할: 모든 파라미터 키의 접두사
- 설명: 상태 딕셔너리의 모든 키 앞에 붙는 문자열로, 일반적으로 모듈 계층 구조를 나타낼 때 사용됨.
- 활용 예: 서브모듈의 상태를 구분하기 위해 'encoder.', 'decoder.' 등의 접두사를 사용할 수 있음.
destination과 prefix를 사용한 예제코드
import torch
import torch.nn as nn
from collections import OrderedDict
# 간단한 모델 정의
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 5)
self.bn = nn.BatchNorm1d(5)
def forward(self, x):
x = self.linear(x)
x = self.bn(x)
return x
# 더 복잡한 모델 정의 (서브모듈 포함)
class ComplexModel(nn.Module):
def __init__(self):
super(ComplexModel, self).__init__()
self.encoder = SimpleModel()
self.decoder = nn.Linear(5, 2)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
# 모델 인스턴스 생성
simple_model = SimpleModel()
complex_model = ComplexModel()
# 예시 1: 기본 사용
print("=== 예시 1: 기본 사용 ===")
state_dict = simple_model.state_dict()
print("기본 state_dict 타입:", type(state_dict))
print("기본 state_dict 내용:")
for key, value in state_dict.items():
print(f" {key}: {value.shape}")
# 예시 2: prefix 매개변수 사용
print("\n=== 예시 2: prefix 매개변수 사용 ===")
prefixed_dict = simple_model.state_dict(prefix='model1.')
print("prefix 적용된 state_dict 내용:")
for key, value in prefixed_dict.items():
print(f" {key}: {value.shape}")
# 예시 3: destination 매개변수 사용
print("\n=== 예시 3: destination 매개변수 사용 ===")
combined_dict = OrderedDict()
simple_model.state_dict(destination=combined_dict, prefix='model1.')
print("첫 번째 모델 추가 후 destination 내용:")
for key, value in combined_dict.items():
print(f" {key}: {value.shape}")
# 두 번째 모델도 같은 destination에 추가
simple_model.state_dict(destination=combined_dict, prefix='model2.')
print("\n두 번째 모델 추가 후 destination 내용:")
for key, value in combined_dict.items():
print(f" {key}: {value.shape}")
# 예시 4: keep_vars 매개변수 사용
print("\n=== 예시 4: keep_vars 매개변수 사용 ===")
vars_dict = simple_model.state_dict(keep_vars=True)
print("keep_vars=True 적용된 state_dict 내용:")
for key, tensor in vars_dict.items():
print(f" {key}: 형태={tensor.shape}, 텐서 객체={type(tensor)}, requires_grad={tensor.requires_grad}")
# 예시 5: 복잡한 모델 계층 구조
print("\n=== 예시 5: 복잡한 모델 계층 구조 ===")
complex_dict = complex_model.state_dict()
print("복잡한 모델의 state_dict 내용:")
for key, value in complex_dict.items():
print(f" {key}: {value.shape}")
실행결과
=== 예시 1: 기본 사용 ===
기본 state_dict 타입: <class 'collections.OrderedDict'>
기본 state_dict 내용:
linear.weight: torch.Size([5, 10])
linear.bias: torch.Size([5])
bn.weight: torch.Size([5])
bn.bias: torch.Size([5])
bn.running_mean: torch.Size([5])
bn.running_var: torch.Size([5])
bn.num_batches_tracked: torch.Size([])
=== 예시 2: prefix 매개변수 사용 ===
prefix 적용된 state_dict 내용:
model1.linear.weight: torch.Size([5, 10])
model1.linear.bias: torch.Size([5])
model1.bn.weight: torch.Size([5])
model1.bn.bias: torch.Size([5])
model1.bn.running_mean: torch.Size([5])
model1.bn.running_var: torch.Size([5])
model1.bn.num_batches_tracked: torch.Size([])
=== 예시 3: destination 매개변수 사용 ===
첫 번째 모델 추가 후 destination 내용:
model1.linear.weight: torch.Size([5, 10])
model1.linear.bias: torch.Size([5])
model1.bn.weight: torch.Size([5])
model1.bn.bias: torch.Size([5])
model1.bn.running_mean: torch.Size([5])
model1.bn.running_var: torch.Size([5])
model1.bn.num_batches_tracked: torch.Size([])
두 번째 모델 추가 후 destination 내용:
model1.linear.weight: torch.Size([5, 10])
model1.linear.bias: torch.Size([5])
model1.bn.weight: torch.Size([5])
model1.bn.bias: torch.Size([5])
model1.bn.running_mean: torch.Size([5])
model1.bn.running_var: torch.Size([5])
model1.bn.num_batches_tracked: torch.Size([])
model2.linear.weight: torch.Size([5, 10])
model2.linear.bias: torch.Size([5])
model2.bn.weight: torch.Size([5])
model2.bn.bias: torch.Size([5])
model2.bn.running_mean: torch.Size([5])
model2.bn.running_var: torch.Size([5])
model2.bn.num_batches_tracked: torch.Size([])
=== 예시 4: keep_vars 매개변수 사용 ===
keep_vars=True 적용된 state_dict 내용:
linear.weight: 형태=torch.Size([5, 10]), 텐서 객체=<class 'torch.nn.parameter.Parameter'>, requires_grad=True
linear.bias: 형태=torch.Size([5]), 텐서 객체=<class 'torch.nn.parameter.Parameter'>, requires_grad=True
bn.weight: 형태=torch.Size([5]), 텐서 객체=<class 'torch.nn.parameter.Parameter'>, requires_grad=True
bn.bias: 형태=torch.Size([5]), 텐서 객체=<class 'torch.nn.parameter.Parameter'>, requires_grad=True
bn.running_mean: 형태=torch.Size([5]), 텐서 객체=<class 'torch.Tensor'>, requires_grad=False
bn.running_var: 형태=torch.Size([5]), 텐서 객체=<class 'torch.Tensor'>, requires_grad=False
bn.num_batches_tracked: 형태=torch.Size([]), 텐서 객체=<class 'torch.Tensor'>, requires_grad=False
=== 예시 5: 복잡한 모델 계층 구조 ===
복잡한 모델의 state_dict 내용:
encoder.linear.weight: torch.Size([5, 10])
encoder.linear.bias: torch.Size([5])
encoder.bn.weight: torch.Size([5])
encoder.bn.bias: torch.Size([5])
encoder.bn.running_mean: torch.Size([5])
encoder.bn.running_var: torch.Size([5])
encoder.bn.num_batches_tracked: torch.Size([])
decoder.weight: torch.Size([2, 5])
decoder.bias: torch.Size([2])
- 기본 사용: state_dict()는 모든 파라미터와 버퍼를 OrderedDict 형태로 반환.
- prefix 사용: 모든 키 앞에 지정된 접두사가 붙습니다. 이는 여러 모델의 상태를 구분할 때 유용함.
- destination 사용:
- 같은 OrderedDict에 여러 모델의 상태를 합칠 수 있음.
- 이는 앙상블 모델이나 복합 모델 저장 시 유용함.
- keep_vars 사용:
- True로 설정하면 값 대신 실제 텐서 객체가 저장됨.
- 파라미터(nn.Parameter)와 버퍼(torch.Tensor)의 구분이 명확하게 보임.
- 계층적 구조: 복잡한 모델에서는 서브모듈의 이름이 자동으로 키의 접두사로 사용됨.
출력 예시
=== 파라미터 (학습 가능한 가중치) ===
linear.weight: torch.Size([5, 10]), requires_grad=True
linear.bias: torch.Size([5]), requires_grad=True
bn.weight: torch.Size([5]), requires_grad=True
bn.bias: torch.Size([5]), requires_grad=True
=== 버퍼 (학습 불가능한 상태 값) ===
bn.running_mean: torch.Size([5]), requires_grad=False
bn.running_var: torch.Size([5]), requires_grad=False
bn.num_batches_tracked: torch.Size([]), requires_grad=False
=== 전체 state_dict 내용 ===
linear.weight: torch.Size([5, 10]) (파라미터)
linear.bias: torch.Size([5]) (파라미터)
bn.weight: torch.Size([5]) (파라미터)
bn.bias: torch.Size([5]) (파라미터)
bn.running_mean: torch.Size([5]) (버퍼)
bn.running_var: torch.Size([5]) (버퍼)
bn.num_batches_tracked: torch.Size([]) (버퍼)
설명
linear.weight:nn.Linear계층의 가중치 파라미터linear.bias:nn.Linear계층의 바이어스 파라미터bn.weight:nn.BatchNorm1d계층의 가중치 파라미터bn.bias:nn.BatchNorm1d계층의 바이어스 파라미터bn.running_mean:nn.BatchNorm1d계층의 running mean 버퍼bn.running_var:nn.BatchNorm1d계층의 running variance 버퍼bn.num_batches_tracked:nn.BatchNorm1d계층에서 배치의 수를 추적하는 버퍼
기타?
일반적으로 PyTorch 모델의 state_dict에 포함된 텐서는 파라미터(named_parameters())와 버퍼(named_buffers())의 두 가지 카테고리로 구분되며, 이 두 가지에 속하지 않는 "기타" 항목은 실제로는 거의 발생하지 않음.
위의 코드에서 "기타" 카테고리를 포함한 이유는 다음과 같은 특별한 경우의 값들을 다루기 위해서임.
- 커스텀 등록된 상태 값: register_state_dict_hook을 사용하여 사용자가 임의로 state_dict에 추가한 값.
- 미래 버전 호환성: PyTorch의 향후 버전에서 새로운 유형의 상태 값이 추가될 가능성에 대비한 값.
- 외부 모듈 또는 확장 기능: 일부 PyTorch 확장 또는 외부 라이브러리가 표준 파라미터나 버퍼가 아닌 다른 형태의 상태를 추가할 수도 있음.
- 모델 양자화 관련 정보: 양자화된 모델의 경우 양자화 스케일이나 제로 포인트와 같은 추가 정보가 저장될 수 있음.
일반적인 PyTorch 모델(예: 제공된 MyModel 클래스)에서는 모든 상태가 parameter나 buffer로 분류되므로 "기타" 카테고리에 속하는 항목은 출력되지 않음.
그러나 확장성을 위해 이 둘에 속하지 않는 값이 있을 수 있음.
기타 항목 추가하기
- 아래의 코드는 state_dict 메서드를 override하여 커스텀 항목 'custom_item'을 추가함.
- 이 항목은 named_parameters()나 named_buffers()에 포함되지 않으므로 "기타" 카테고리로 분류됨.
import torch
import torch.nn as nn
from collections import OrderedDict
# 간단한 신경망 모델 클래스 정의
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 10차원 입력을 5차원으로 변환하는 선형 레이어 정의
self.linear = nn.Linear(10, 5)
# 5차원 특성에 대한 배치 정규화 레이어 정의
self.bn = nn.BatchNorm1d(5)
# 일반적인 속성으로 텐서 추가 (state_dict에 자동으로 포함되지 않음)
self.custom_tensor = torch.ones(3, 4)
def forward(self, x):
# 선형 변환 적용
x = self.linear(x)
# 배치 정규화 적용
x = self.bn(x)
return x
# state_dict 메서드 오버라이드
def state_dict(self, *args, **kwargs):
# 기본 state_dict 획득
state_dict_orig = super().state_dict(*args, **kwargs)
# 커스텀 state_dict 생성 및 원본 복사
state_dict_new = OrderedDict(state_dict_orig)
# 커스텀 항목 추가
state_dict_new['custom_item'] = self.custom_tensor
return state_dict_new
# 모델 인스턴스 생성
model = MyModel()
# 모델의 state_dict 획득
state_dict = model.state_dict()
# 파라미터와 버퍼 구분을 위한 출력 섹션
print("=== 파라미터 (학습 가능한 가중치) ===")
# named_parameters() 메서드를 통한 모든 파라미터 출력
for name, param in model.named_parameters():
print(f"{name}: {param.shape}, requires_grad={param.requires_grad}")
print("\n=== 버퍼 (학습 불가능한 상태 값) ===")
# named_buffers() 메서드를 통한 모든 버퍼 출력
for name, buf in model.named_buffers():
print(f"{name}: {buf.shape}, requires_grad={buf.requires_grad}")
print("\n=== 전체 state_dict 내용 ===")
# 파라미터와 버퍼 키 목록 생성
param_keys = [name for name, _ in model.named_parameters()]
buffer_keys = [name for name, _ in model.named_buffers()]
# state_dict의 각 항목에 대한 유형 구분
for key, value in state_dict.items():
if key in param_keys:
print(f"{key}: {value.shape} (파라미터)")
elif key in buffer_keys:
print(f"{key}: {value.shape} (버퍼)")
else:
print(f"{key}: {value.shape} (기타)")
결과는 다음과 같음.
=== 파라미터 (학습 가능한 가중치) ===
linear.weight: torch.Size([5, 10]), requires_grad=True
linear.bias: torch.Size([5]), requires_grad=True
bn.weight: torch.Size([5]), requires_grad=True
bn.bias: torch.Size([5]), requires_grad=True
=== 버퍼 (학습 불가능한 상태 값) ===
bn.running_mean: torch.Size([5]), requires_grad=False
bn.running_var: torch.Size([5]), requires_grad=False
bn.num_batches_tracked: torch.Size([]), requires_grad=False
=== 전체 state_dict 내용 ===
linear.weight: torch.Size([5, 10]) (파라미터)
linear.bias: torch.Size([5]) (파라미터)
bn.weight: torch.Size([5]) (파라미터)
bn.bias: torch.Size([5]) (파라미터)
bn.running_mean: torch.Size([5]) (버퍼)
bn.running_var: torch.Size([5]) (버퍼)
bn.num_batches_tracked: torch.Size([]) (버퍼)
custom_item: torch.Size([3, 4]) (기타)
활용 방법
모델 저장:
state_dict를 파일에 저장하여 나중에 모델을 복원할 수 있음.
torch.save(model.state_dict(), 'model_state.pth')
모델 로드:
저장된 state_dict를 로드하여 모델을 복원.
model = MyModel()
model.load_state_dict(torch.load('model_state.pth'))
model.eval() # 평가 모드로 전환 (선택 사항)`
파라미터 업데이트:
state_dict를 사용하여 모델의 특정 파라미터를 업데이트할 수 있음.
state_dict['linear.weight'] = torch.ones_like(state_dict['linear.weight'])
model.load_state_dict(state_dict)
결론
state_dict()는 PyTorch에서 모델의 학습 가능한 매개변수와 버퍼의 상태를 관리하는 데 중요한 역할을 당담함.
이를 통해 모델을 쉽게 저장하고 로드할 수 있으며, 모델 파라미터의 직접적인 접근 및 수정도 가능합니다.