본문 바로가기
카테고리 없음

[DL] Torch: Save and Load Model

by ds31x 2024. 5. 16.

Torch: Save and Load Model

PyTorch에서 model을 저장하는 방법은 크게 두 가지임.

  1. 모델의 Parameters (= weights and bias)를 저장 (Structure 등은 저장되지 않음).
  2. model 전체를 저장하는 방법 (Parameters와 Structure 함께)

일반적으로 권장되는 방법은 1번임.

1번의 경우,

  • 비록 모델의 구조를 정의하고 있는 class 의 instance 코드 상에서 생성하고,
  • 이 instance로 로딩을 수행해줘야 하지만,
  • 해당 class의 소스를 정확히 가지고 있을 경우 PyTorch 버전 등에 상관없이
  • 이전과 동일한 모델을 load를 통해 얻을 수 있음.

위에서 정확히 가지고 있다라는 애기는 save할 때와 load할 때의 모델 클래스의 definition이 동일해야함을 의미함.


1. 모델의 parameters를 저장하고 읽어들이기

이 방식은 torch.nn.Module의 인스턴스가 가지고 있는
state_dict (collections.OrderedDict 의 인스턴스)를 사용한다.

 

state_dict

  • Module에 해당하는 모델
    • 학습 가능한 모든 parameters의 상태와
    • 속성(attribute)으로 가지고 있는 buffers의 상태를
  • 저장하고 있는 OrderedDict의 객체임.

말 그대로,

state_dict는 현재 Module 객체의 상태를 가지고 있으며,

이를 통해 현재 모델의 상태를 저장하고 이를 통해 해당 상태로 복원할 수 있음.


1.1. Module의 관련 methods

1.1.1. state_dict() 메서드

현재 Module 인스턴스의 상태 에 해당하는 OrderedDict 객체 state_dict를 반환.

.state_dict(desitnation=None, prefix='', keep_vars=False)

  • keep_vars 는 기본값이 False로 buffers와 parameters 의 값만을 추출할지를 결정
    • keep_vars=True 인 경우, 값 대신 tensor객체로 데이터 버퍼를 가지고 있는 dictionary가 반환됨.
      • value 가 파라메터인 경우엔 Parameter 로 얻어지고,
      • value 가 버퍼인 경우엔 Tensor 로 얻어짐.
    • keep_vars=True 인 경우, 메모리 사용량이 커지고, 매우 느리고 복잡한 동작이 이루어지지만. 다음의 장점을 가짐.
      • 모델 디버깅: 모델 상태를 조사하고 특정 매개 변수나 버퍼의 값을 변경해야 하는 경우 유용
      • 모델 커스터마이징: 모델을 불러온 후 특정 매개 변수나 버퍼의 값을 변경해야 하는 경우 유용
      • 모델 저장 및 불러오기 확장: 모델 저장 및 불러오기 프로세스를 확장하고 추가적인 정보를 저장해야 하는 경우 유용
    • 하지만, PyTorch의 버전이 정확히 맞아야만 동작할 수 있는 등의 제한점을 가짐.
    • 저장의 용도로는 keep_vars=False를 사용하는 게 좋음.

1.1.2. load_state_dict( state_dict ) 메서드

현재 Module 인스턴스의 상태를 argument로 넘겨진 OrderedDict 객체 state_dict를 이용 하여 설정함.

 

이 메서드의 반환값은

torch.nn.modules.module._IncompatibleKeys 객체로서,

Module 인스턴스에 state_dict를 로드할 때 호환되지 않는 키들의 정보를 가지고 있다.

 

다음과 같은 2개의 attributes를 가지며 이를 통해 모델의 Parameters를 복원하는데 발생한 문제를 해결하기 위한 조치를 취할 수 있음.

  • missing_keys : 로드하려는 state_dict에는 있으나 load_state_dict메서드를 호출한 Module 객체에는 없는 키들.
  • unexpected_keys : Module 객체에는 있으나 인자로 넘겨진 state_dict에는 없는 키들.

1.1.3. torch.save(state_dict,'file_path') 와 state_dict=torch.load('file_path')`

1.1.1과 1.1.2의 state_dict 객체는 torch 모듈의 saveload 함수를 이용하여 파일로 저장(serialization)되거나 로딩(instantiazation)됨.


1.3 Example

다음의 예제 코드는 state_dict 를 이용하여 Parameters만을 저장하고 로드하는 방식을 보여준다.

# 필요한 library와 모듈 import
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict

# 간단한 linear regression model 정의.
# SimpleModel0와 SimpleModel1은 
# 똑같은 구조이나 파라메터들의 초기값만 다름.
class SimpleModel0(nn.Module):

  def __init__(self, n_in_f, n_out_f):

    super().__init__()

    self.l0 = nn.Linear(n_in_f, n_out_f)

    const_weight = 1.
    const_bias = 0.5
    init.constant_(self.l0.weight, const_weight)
    if self.l0.bias is not None:
      init.constant_(self.l0.bias, const_bias)

  def forward(self, x):
    return self.l0(x)

class SimpleModel1(nn.Module):

  def __init__(self, n_in_f, n_out_f):

    super().__init__()

    self.l0 = nn.Linear(n_in_f, n_out_f)

    const_weight = 2.
    const_bias = 1.5

    init.constant_(self.l0.weight, const_weight)
    if self.l0.bias is not None:
      init.constant_(self.l0.bias, const_bias)

  def forward(self, x):
    return self.l0(x)


# 모델 객체를 생성하고, 이에 대한 파라메터 확인후 
# 파라메터만 저장.
model = SimpleModel(3,1)
print(list(model.named_parameters()))
torch.save(model.state_dict(), 'model_params.pth')

# 새로운 모델 객체를 생성.
# 해당 모델 객체는 구조는 같으나, 파라메터들의 초기값은 다름.
n_model = SimpleModel1(3,1)

print('===============')
for old, new in zip(model.parameters(), n_model.parameters()):

  if not torch.equal(old,new):
    print('model and n_model w/ default init do not have parameters with the same values!')
    break
else:
  print('model and n_model w/ default init have parameters with the same values!')
print('===============')

# 이전 저장한 parameters에 대한 state_dict를 
# 로드하고 해당 state_dict로 새로만든 모델의
# 파라메터를 설정하고 이전 모델과 비교.
# load parameters and restore old parameters into new model
loaded_params_ordered_dict = torch.load('model_params.pth')
print(f'{type(loaded_params_ordered_dict)=}') # collections.OredredDict

ret_v = n_model.load_state_dict(loaded_params_ordered_dict)
print(f'{type(ret_v)}: {ret_v}')

print('===============')
for old, new in zip(model.parameters(), n_model.parameters()):

  if not torch.equal(old,new):
    print('model and n_model do not have parameters with the same values!')
    break
else:
  print('model and n_model have parameters with the same values!')

위의 코드에서는 2개의 모델이 같은 구조를 가지고 있으며,
state_dict를 통해 똑같은 파라메터를 가지도록 처리하고 이를 확인하는 예제 코드임.

 

load_state_dict 메서드의 반환값인 _IncompatibleKeys 객체는

아무 문제가 없을시 해당 객체의 __str__() 메서드를 통해 <All keys matched successfully> 라는 문자열을 출력함.

 

2024.05.16 - [분류 전체보기] - [DL] PyTorch: Tensor 비교하기.

 

[DL] PyTorch: Tensor 비교하기.

[DL] PyTorch: Tensor 비교하기PyTorch에서 nn.Parameter 또는 tensor 객체 두 개가 같은 값을 가지는지 확인하는 방법은 텐서의 모든 요소가 동일한지 확인하는 것임. 이를 위해 제공되는 다음과 같은 함수 2

ds31x.tistory.com

 


2. model 전체를 저장하고 로드 방법

이 방법은 모델의 전체 클래스와 인스턴스 상태를 한번에 저장함.

단, pickle에 의존하여 인스턴스를 직렬화하기 때문에 이 경우에도 모델 클래스 정의가 필요함.

  • 대응하는 모델의 클래스가 import되어 있던지
  • load 전에 파일에 저장된 클래스의 정의와 같이 앞서 정의되어 있어야 한다.

이 방법은

Python 버전 및 라이브러리 버전이 바뀌거나,

해당 모델 클래스가 의존하는 모듈이나 패키지의 버전이 변경되는 경우,

모델 클래스를 정의한 Python파일에서 클래스의 이름이 바뀌거나 모델의 구조나 메서드 또는 레이어가 추가 삭제되는 경우 등에서

로딩이 제대로 되지 못하는 문제가 발생할 수 있음.

때문에 state_dict를 이용하는 방법이 보다 선호됨.


다음 예제는 모델을 통째로 저장하는 방법을 간단히 보여주는 예제임.

 

또한, 저장한 모델을 로드한 후 원본 모델과 로드된 모델의 파라미터가 동일한지 확인하는 방법을 예시로 보여줌.

# 필요한 모듈 import
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict

# 사용할 간단한 linear regression model 정의
class SimpleModel0(nn.Module):

  def __init__(self, n_in_f, n_out_f):

    super().__init__()

    init_weigths = torch.ones( (n_in_f, n_out_f) )
    init_bias = torch.zeros( (n_out_f,) )

    self.l0 = nn.Linear(n_in_f, n_out_f)

    const_weight = 2.
    const_bias = 1.5

    init.constant_(self.l0.weight, const_weight)
    if self.l0.bias is not None:
      init.constant_(self.l0.bias, const_bias)

  def forward(self, x):
    return self.l0(x)

# 저장할 모델 생성.
model = SimpleModel0(3,1)

# 모델 저장
torch.save(model, 'model.pth')

# 저장된 model 로드.
n_model = torch.load('model.pth')
print(f'{type(n_model)=}, {n_model}')

# 두 모델의 parameters비교.
for old, new in zip(model.parameters(), n_model.parameters()):
  if not torch.equal(old,new):
    print('model and n_model do not have parameters with the same values!')
    break
else:
  print('model and n_model have parameters with the same values!')

필요한 모듈 import

  • torch: PyTorch의 기본 모듈
  • nn: PyTorch의 신경망 모듈
  • init: PyTorch의 초기화 모듈
  • OrderedDict: 정렬된 딕셔너리 자료구조를 위한 모듈

사용할 간단한 linear regression model 정의

  • SimpleModel0 클래스는 nn.Modul을 상속받아 정의한 간단한 선형 회귀 모델.
  • __init__ 생성자에서 입력 피처 수(n_in_f)와 출력 피처 수(n_out_f)를 받아 linear layer self.l0를 정의.
  • init.constant_ 를 사용하여 linear layer의 parameter를 특정 상수값으로 초기.
  • forward 메서드는 입력 데이터를 받아 linear transform을 수행.

모델 저장

  • model = SimpleModel0(3, 1): 입력 피처 수가 3, 출력 피처 수가 1인 모델 인스턴스를 생성.
  • torch.save(model, 'model.pth'): 생성한 모델을 'model.pth' 파일에 저장합니다.

저장된 model 로드 및 두 모델의 parameters비교.

  • n_model = torch.load('model.pth'): 저장된 모델을 'model.pth' 파일에서 로드.
  • print(f'{type(n_model)=}, {n_model}'): 로드된 모델의 타입과 내용을 출력.
  • for old, new in zip(model.parameters(), n_model.parameters()): 원본 모델(model)과 로드된 모델(n_model)의 파라미터를 순회하며 비교.
  • torch.equal(old, new): 두 파라미터 텐서가 동일한지 확인.

같이 읽어보면 좋은 자료들

https://tutorials.pytorch.kr/beginner/basics/saveloadrun_tutorial.html

 

모델 저장하고 불러오기

파이토치(PyTorch) 기본 익히기|| 빠른 시작|| 텐서(Tensor)|| Dataset과 Dataloader|| 변형(Transform)|| 신경망 모델 구성하기|| Autograd|| 최적화(Optimization)|| 모델 저장하고 불러오기 이번 장에서는 저장하기나

tutorials.pytorch.kr

 

https://tutorials.pytorch.kr/recipes/recipes/saving_multiple_models_in_one_file.html#

 

PyTorch에서 여러 모델을 하나의 파일에 저장하기 & 불러오기

여러 모델을 저장하고 불러오는 것은 이전에 학습했던 모델들을 재사용하는데 도움이 됩니다. 개요: GAN이나 시퀀스-투-시퀀스(sequence-to-sequence model), 앙상블 모델(ensemble of models)과 같이 여러 torch

tutorials.pytorch.kr

 

https://gist.github.com/dsaint31x/ed823b5d11206ba58e9da9ced41baa8b

 

dl_torch_save_load.ipynb

dl_torch_save_load.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com

 

728x90