본문 바로가기
목차
Python

[PyTorch] torch.save 와 torch.load - tensor 위주

by ds31x 2025. 4. 8.
728x90
반응형

PyTorch에서 tensor 저장 및 불러오기:
torch.save와 torch.load 사용법

PyTorch에서는 학습된 모델과 개별 tensor나 여러 tensor들의 집합을 저장하고 불러올 수 있음.

이 문서에서는 torch.savetorch.load 함수를 사용하여 tensor를 저장하고 불러오는 방법을 소개함.

 

모델 저장 및 로드는 다음을 참고

2024.05.16 - [분류 전체보기] - [DL] Torch: Save and Load Model

 

[DL] Torch: Save and Load Model

Torch: Save and Load ModelPyTorch에서 model을 저장하는 방법은 크게 두 가지임.모델의 Parameters (= weights and bias)를 저장 (Structure 등은 저장되지 않음).model 전체를 저장하는 방법 (Parameters와 Structure 함께)

ds31x.tistory.com


1. torch.save와 torch.load 기본 문법

1.1 torch.save 함수

torch.save(
    obj, 
    f, 
    pickle_module=pickle, 
    pickle_protocol=DEFAULT_PROTOCOL, 
    _use_new_zipfile_serialization=True,
)
  • obj: 저장할 객체 (tensor, tensor를 item으로 가지는 dict 등)
  • f: 파일명 또는 파일 객체
  • pickle_module: 직렬화에 사용할 모듈 (기본값: Python의 pickle)
  • pickle_protocol: pickle 프로토콜 버전
  • _use_new_zipfile_serialization: zip 파일 형식으로 저장할지 여부 (PyTorch 1.6부터 기본값: True)

1.2 torch.load 함수

torch.load(
    f, 
    map_location=None, 
    pickle_module=pickle, 
    **pickle_load_args,
)
  • f: 파일명 또는 파일 객체
  • map_location: tensor를 불러올 device(장치) 지정 (CPU/GPU)
  • pickle_module: 역직렬화에 사용할 module
  • pickle_load_args: pickle 모듈에 전달할 추가 인자

2. 단일 tensor 저장 및 불러오기

가장 기본적인 사용법은 하나의 tensor를 저장하고 불러오는 것임.

import torch

# 단일 텐서 생성
tensor = torch.tensor([1, 2, 3, 4, 5])
print(f"원본 텐서: {tensor}")

# 텐서 저장
torch.save(tensor, 'single_tensor.pt')

# 텐서 불러오기
loaded_tensor = torch.load('single_tensor.pt')
print(f"불러온 텐서: {loaded_tensor}")

 

출력:

원본 텐서: tensor([1, 2, 3, 4, 5])
불러온 텐서: tensor([1, 2, 3, 4, 5])

2.1 다양한 속성을 가진 tensor 저장하기

텐서의 데이터 타입, 크기, 장치 정보 등 다양한 속성이 함께 저장됨.

단, 계산 그래프는 저장되지 않음 (아래 "2.3 requires_grad=True 인 tensor 저장하기"절 참고).

# 다양한 속성을 가진 텐서 생성
float_tensor = torch.randn(3, 4, dtype=torch.float32)
long_tensor = torch.randint(0, 10, (2, 5), dtype=torch.int64)

print(f"Float 텐서 - 타입: {float_tensor.dtype}, 크기: {float_tensor.shape}")
print(f"Long 텐서 - 타입: {long_tensor.dtype}, 크기: {long_tensor.shape}")

# 텐서 저장
torch.save(float_tensor, 'float_tensor.pt')
torch.save(long_tensor, 'long_tensor.pt')

# 텐서 불러오기
loaded_float = torch.load('float_tensor.pt')
loaded_long = torch.load('long_tensor.pt')

print(f"불러온 Float 텐서 - 타입: {loaded_float.dtype}, 크기: {loaded_float.shape}")
print(f"불러온 Long 텐서 - 타입: {loaded_long.dtype}, 크기: {loaded_long.shape}")

2.2 GPU 메모리에 존재하던 tensor 저장 및 다른 device 로 로딩하기

텐서가 어떤 장치(CPU/GPU)에 있었는지 정보도 함께 저장되며, 기본적으로 해당 장치에 로드됨.

# 장치 설정
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GPU 텐서 생성 (GPU가 없으면 CPU에 생성됨)
gpu_tensor = torch.randn(3, 4).to(device)
print(f"원본 텐서 장치: {gpu_tensor.device}")

# 텐서 저장
torch.save(gpu_tensor, 'gpu_tensor.pt')

# CPU로 불러오기
# - map_location 으로 지정하지 않으면, 원래 device로 로드됨.
cpu_tensor = torch.load('gpu_tensor.pt', map_location='cpu') 
print(f"CPU로 불러온 텐서 장치: {cpu_tensor.device}")

# 다시 GPU로 불러오기 (GPU가 있는 경우)
if torch.cuda.is_available():
    gpu_loaded = torch.load('gpu_tensor.pt')  # 원래 장치로 불러옴
    print(f"GPU로 불러온 텐서 장치: {gpu_loaded.device}")

    # 특정 GPU로 불러오기
    specific_gpu = torch.load('gpu_tensor.pt', map_location='cuda:0')
    print(f"특정 GPU로 불러온 텐서 장치: {specific_gpu.device}")

2.3 requires_grad=True 인 tensor 저장하기

텐서의 값은 저장되지만 연산 그래프와 gradient 정보는 저장되지 않음.

load 이후엔 requires_grad 를 다시 True로 지정해야 함.

# requires_grad 속성이 있는 텐서 생성
grad_tensor = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
print(f"원본 텐서 requires_grad: {grad_tensor.requires_grad}")

# 텐서 저장
torch.save(grad_tensor, 'grad_tensor.pt')

# 텐서 불러오기
loaded_grad = torch.load('grad_tensor.pt')
print(f"불러온 텐서 requires_grad: {loaded_grad.requires_grad}")  # False가 출력됨

# 필요하다면 requires_grad 다시 설정
loaded_grad.requires_grad_(True)
print(f"설정 후 requires_grad: {loaded_grad.requires_grad}")  # True가 출력됨

3. 여러 tensors를 dict 객체로 저장하기

실제 작업에서는 여러 텐서를 함께 저장해야 하는 경우가 많음.

  • dict을 사용하여 여러 텐서를 저장하는 방법이 주로 사용됨.
  • 엄밀히 말하면 dictionary보다 OrderedDict 를 PyTorch에선 모델 저장 및 로드에 사용함:
  • Python 3.7부터는 일반 dict도 순서를 보장하므로 읽어들일 때는 dict를 사용하는 추세

2025.04.04 - [Python] - [Py] collection.OrderedDict

 

[Py] collection.OrderedDict

OrderedDict는 삽입된 순서를 보존하는 기능을 추가한 일종의 dict임.collections 모듈에서 제공.Python 3.7부터 built-in dict도 삽입 순서를 보존하게 되었음.하지만 OrderedDict는 그 외에 몇 가지 중요한 추가

ds31x.tistory.com


3.1 기본적인 tensors를 dict로 저장하기

# 여러 텐서 생성
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(5, 6)
tensor3 = torch.randn(2, 2)

# 텐서들을 사전에 담기
tensor_dict = {
    'tensor1': tensor1,
    'tensor2': tensor2,
    'tensor3': tensor3
}

# 사전 저장
torch.save(tensor_dict, 'tensor_dict.pt')

# 사전 불러오기
loaded_dict = torch.load('tensor_dict.pt')

# 불러온 사전 확인
for name, tensor in loaded_dict.items():
    print(f"{name} - 크기: {tensor.shape}")

 

출력:

tensor1 - 크기: torch.Size([3, 4])
tensor2 - 크기: torch.Size([5, 6])
tensor3 - 크기: torch.Size([2, 2])

3.2 다양한 정보를 포함한 복잡한 dict 객체 저장

tensor뿐만 아니라 scalar 값, list, 다른 dict 등 다양한 Python 객체를 포함한 복잡한 구조도 저장할 수 있음.

torch.save 와 torch.load는 내부적으론 pickle 모듈에 의존함.

import numpy as np
from datetime import datetime

# 다양한 객체 생성
tensor_a = torch.randn(3, 4)
tensor_b = torch.ones(2, 3)
numpy_array = np.random.randn(4, 4)
python_list = [1, 2, 3, 4, 5]
nested_dict = {'a': 10, 'b': 20}
current_time = datetime.now()

# 복잡한 사전 구조 만들기
complex_dict = {
    'tensors': {
        'a': tensor_a,
        'b': tensor_b
    },
    'numpy_data': numpy_array,
    'python_data': {
        'list': python_list,
        'nested': nested_dict
    },
    'metadata': {
        'created_at': current_time,
        'description': '여러 객체를 포함한 복잡한 사전 예제'
    }
}

# 복잡한 사전 저장
torch.save(complex_dict, 'complex_dict.pt')

# 복잡한 사전 불러오기
loaded_complex = torch.load('complex_dict.pt')

# 불러온 사전 구조 확인
print(f"텐서 A 크기: {loaded_complex['tensors']['a'].shape}")
print(f"NumPy 배열 크기: {loaded_complex['numpy_data'].shape}")
print(f"Python 리스트: {loaded_complex['python_data']['list']}")
print(f"생성 시간: {loaded_complex['metadata']['created_at']}")

 

출력:

텐서 A 크기: torch.Size([3, 4])
NumPy 배열 크기: (4, 4)
Python 리스트: [1, 2, 3, 4, 5]
생성 시간: 2023-04-07 15:30:45.123456

3.3 파일 크기를 줄이기 위한 압축 사용

큰 tensor나 많은 tensors를 저장할 때는 파일 크기를 줄이기 위해 압축을 사용할 수 있음.

import torch

# 큰 텐서 생성
large_tensor = torch.randn(1000, 1000)

# 기본 방식으로 저장
torch.save(large_tensor, 'large_tensor_default.pt')

# 압축된 파일 객체 생성 및 저장
import gzip
with gzip.open('large_tensor_compressed.pt.gz', 'wb') as f:
    torch.save(large_tensor, f)

# 파일 크기 확인
import os
default_size = os.path.getsize('large_tensor_default.pt')
compressed_size = os.path.getsize('large_tensor_compressed.pt.gz')

print(f"기본 파일 크기: {default_size / 1024 / 1024:.2f} MB")
print(f"압축 파일 크기: {compressed_size / 1024 / 1024:.2f} MB")
print(f"압축률: {compressed_size / default_size * 100:.2f}%")

# 압축 파일 불러오기
with gzip.open('large_tensor_compressed.pt.gz', 'rb') as f:
    loaded_large = torch.load(f)

print(f"불러온 텐서 크기: {loaded_large.shape}")

4. 실용적인 사용 사례

4.1 학습 데이터 pre-processing 결과 저장

데이터 전처리 결과를 저장하여 다음 학습에 재사용할 수 있음.

import torch
import time

# 가상의 시간이 오래 걸리는 전처리 함수
def preprocess_data(raw_data):
    print("데이터 전처리 시작...")
    time.sleep(2)  # 실제로는 시간이 오래 걸리는 처리
    processed = raw_data * 2 + 1
    normalized = (processed - processed.mean()) / processed.std()
    return normalized

# 원본 데이터
raw_data = torch.randn(10000, 100)

# 전처리
start_time = time.time()
processed_data = preprocess_data(raw_data)
print(f"전처리 시간: {time.time() - start_time:.2f}초")

# 전처리 결과 저장
preprocessed_dict = {
    'data': processed_data,
    'preprocessing_info': {
        'mean': processed_data.mean().item(),
        'std': processed_data.std().item(),
        'timestamp': time.time()
    }
}
torch.save(preprocessed_dict, 'preprocessed_data.pt')

# 나중에 전처리 결과 불러오기
loaded_preprocessed = torch.load('preprocessed_data.pt')
print(f"불러온 데이터 크기: {loaded_preprocessed['data'].shape}")
print(f"전처리 통계: 평균 = {loaded_preprocessed['preprocessing_info']['mean']:.4f}, 표준편차 = {loaded_preprocessed['preprocessing_info']['std']:.4f}")

4.2 Mini-batch Caching

학습 중 mini-batch를 caching하여 디버깅이나 재현성 확인에 사용할 수 있음.

# 미니배치 저장 함수
def save_minibatch(inputs, targets, batch_idx):
    torch.save({
        'inputs': inputs,
        'targets': targets,
        'batch_idx': batch_idx,
        'timestamp': time.time(),
    }, f'minibatch_{batch_idx}.pt')

# 학습 루프 (예시)
batch_size = 32
for batch_idx in range(3):  # 예시를 위해 3개 배치만
    # 가상의 미니배치 생성
    inputs = torch.randn(batch_size, 10)
    targets = torch.randint(0, 2, (batch_size,))

    # 특정 조건에서 미니배치 저장 (예: 오류 발생 시)
    if batch_idx == 1:  # 예시를 위해 두 번째 배치 저장
        save_minibatch(inputs, targets, batch_idx)
        print(f"미니배치 {batch_idx} 저장됨")

    # 학습 코드...

# 저장된 미니배치 불러오기
loaded_batch = torch.load('minibatch_1.pt')
print(f"불러온 미니배치 정보:")
print(f"- 입력 크기: {loaded_batch['inputs'].shape}")
print(f"- 타겟 크기: {loaded_batch['targets'].shape}")
print(f"- 배치 인덱스: {loaded_batch['batch_idx']}")

4.3 Embedding Vector 저장

word나 image의 embedding vector 를 계산하여 저장해두면 나중에 다시 사용할 수 있음.

# 단어 임베딩 예제
vocab = ["apple", "banana", "orange", "grape", "watermelon"]
embedding_dim = 50

# 가상의 임베딩 계산 (실제로는 모델에서 계산)
word_embeddings = {word: torch.randn(embedding_dim) for word in vocab}

# 단어별 임베딩 저장
for word, embedding in word_embeddings.items():
    torch.save(embedding, f'embedding_{word}.pt')

# 모든 임베딩을 한 번에 저장
torch.save(word_embeddings, 'all_embeddings.pt')

# 필요한 임베딩만 불러오기
apple_embedding = torch.load('embedding_apple.pt')
print(f"사과 임베딩 크기: {apple_embedding.shape}")

# 모든 임베딩 불러오기
all_embeddings = torch.load('all_embeddings.pt')
print(f"단어 목록: {list(all_embeddings.keys())}")
print(f"바나나 임베딩: {all_embeddings['banana'][:5]}")  # 처음 5개 값만 출력

5. 고급 기능 및 주의사항

5.1 map_location 을 활용한 다양한 device 지정

map_location 파라미터를 사용하여 tensor를 다양한 장치로 로딩 가능 (기본은 저장 당시 있던 device).

# 장치 정보가 있는 텐서 저장
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tensor_on_device = torch.randn(3, 4).to(device)
torch.save(tensor_on_device, 'tensor_on_device.pt')

# 다양한 방식으로 불러오기
# 1. 문자열로 장치 지정
tensor1 = torch.load('tensor_on_device.pt', map_location='cpu')

# 2. 장치 객체 사용
cpu_device = torch.device('cpu')
tensor2 = torch.load('tensor_on_device.pt', map_location=cpu_device)

# 3. 장치 매핑 사전 사용
tensor3 = torch.load('tensor_on_device.pt', map_location={'cuda:0': 'cpu'})

# 4. 함수로 지정
def cpu_map_fn(storage, loc):
    return storage.cpu()

tensor4 = torch.load('tensor_on_device.pt', map_location=cpu_map_fn)

# 결과 확인
print(f"불러온 텐서 장치 (문자열): {tensor1.device}")
print(f"불러온 텐서 장치 (장치 객체): {tensor2.device}")
print(f"불러온 텐서 장치 (매핑 사전): {tensor3.device}")
print(f"불러온 텐서 장치 (함수): {tensor4.device}")

5.2 pickle 프로토콜 버전 지정

pickle 프로토콜 버전을 지정하여 호환성과 성능을 조절할 수 있음.

import pickle

# 다양한 pickle 프로토콜 버전으로 저장
data = {'tensor': torch.randn(100, 100)}

# 프로토콜 2 (Python 2.3부터 지원, 하위 호환성 좋음)
torch.save(data, 'data_proto2.pt', pickle_protocol=2)

# 프로토콜 4 (Python 3.4부터 지원, 큰 객체 처리 더 효율적)
torch.save(data, 'data_proto4.pt', pickle_protocol=4)

# 최신 프로토콜 (Python 3.8부터 프로토콜 5 지원)
torch.save(data, 'data_latest.pt', pickle_protocol=pickle.HIGHEST_PROTOCOL)

# 파일 크기 비교
import os
size_proto2 = os.path.getsize('data_proto2.pt')
size_proto4 = os.path.getsize('data_proto4.pt')
size_latest = os.path.getsize('data_latest.pt')

print(f"프로토콜 2 파일 크기: {size_proto2} bytes")
print(f"프로토콜 4 파일 크기: {size_proto4} bytes")
print(f"최신 프로토콜 파일 크기: {size_latest} bytes")

 

pickle에 대한 참고 자료

2024.11.27 - [Python] - [Py] Serialization of Python: pickle

 

[Py] Serialization of Python: pickle

1. Python의 pickle 모듈Python의 pickle 모듈은 Python 객체를 직렬화(serialize)하여 파일 또는 메모리에 저장.저장된 데이터를 다시 역직렬화(deserialize)하여 원래 객체로 복원.데이터를 영구 저장하거나 네

ds31x.tistory.com


5.3 Custom pickle 모듈 사용

기본 pickle 모듈 대신 다른 모듈(예: dill, cloudpickle)을 사용할 수 있음.

# dill 모듈 사용 예 (먼저 설치 필요: pip install dill)
import dill

# 람다 함수를 포함한 데이터 (일반 pickle로는 직렬화 불가)
data_with_lambda = {
    'tensor': torch.randn(3, 3),
    'transform_fn': lambda x: x * 2 + 1
}

# dill 모듈로 저장
torch.save(data_with_lambda, 'data_with_lambda.pt', pickle_module=dill)

# dill 모듈로 불러오기
loaded_data = torch.load('data_with_lambda.pt', pickle_module=dill)

# 불러온 함수 사용
input_tensor = torch.ones(2, 2)
transformed = loaded_data['transform_fn'](input_tensor)
print(f"변환 결과: {transformed}")

5.4 파일 객체 사용

당연한 이야기지만, 파일명 대신 file 객체를 사용하여 저장할 수 있음.

# 파일 객체 사용 예
tensor = torch.randn(5, 5)

# 파일 객체로 저장
with open('tensor_file_obj.pt', 'wb') as f:
    torch.save(tensor, f)

# 파일 객체로 불러오기
with open('tensor_file_obj.pt', 'rb') as f:
    loaded_tensor = torch.load(f)

print(f"불러온 텐서 크기: {loaded_tensor.shape}")

# 메모리 파일 객체 사용 (파일 시스템 없이 메모리에서 직렬화/역직렬화)
import io
buffer = io.BytesIO()
torch.save(tensor, buffer)

# 버퍼 위치를 처음으로 되돌림
buffer.seek(0)

# 메모리에서 직접 불러오기
loaded_from_memory = torch.load(buffer)
print(f"메모리에서 불러온 텐서: {loaded_from_memory[:2, :2]}")  # 일부만 출력

6. 권장 사항 및 best practices

6.1 여러 tensors 저장 시 사전 구조화

여러 tensors를 저장할 때는 체계적인 사전 구조를 사용하는 것을 추천함

# 체계적인 구조로 여러 텐서 저장
tensor_collection = {
    'data': {
        'train': {
            'inputs': torch.randn(100, 10),
            'targets': torch.randint(0, 2, (100,))
        },
        'val': {
            'inputs': torch.randn(20, 10),
            'targets': torch.randint(0, 2, (20,))
        }
    },
    'metadata': {
        'description': '학습 및 검증 데이터',
        'input_dim': 10,
        'output_dim': 2,
        'train_size': 100,
        'val_size': 20,
        'created_at': time.time()
    }
}

torch.save(tensor_collection, 'structured_data.pt')

6.2 가급적 meta data 도 포함시켜 저장할 것

tensor와 함께 이를 설명해주는 meta data를 저장하는 것이 권장됨.

# 메타데이터 포함 예
data_tensor = torch.randn(1000, 50)
metadata = {
    'shape': data_tensor.shape,
    'dtype': str(data_tensor.dtype),
    'mean': data_tensor.mean().item(),
    'std': data_tensor.std().item(),
    'min': data_tensor.min().item(),
    'max': data_tensor.max().item(),
    'created_at': datetime.now().isoformat(),
    'description': '1000개 샘플의 50차원 특성 벡터'
}

torch.save({
    'tensor': data_tensor,
    'metadata': metadata
}, 'tensor_with_metadata.pt')

# 메타데이터만 먼저 확인 (텐서 전체를 불러오지 않음)
with open('tensor_with_metadata.pt', 'rb') as f:
    # 파일의 시작 부분만 불러와서 메타데이터 확인
    # 실제로는 더 복잡한 파일 구조 파싱이 필요할 수 있음
    loaded = torch.load(f)
    print("메타데이터:")
    for key, value in loaded['metadata'].items():
        print(f"- {key}: {value}")

6.3 파일 이름 규칙 정하기

체계적인 파일 이름 규칙을 사용하면 여러 텐서 파일을 관리하기 쉬움.

# 날짜와 설명이 포함된 파일 이름
import datetime

def save_tensor_with_naming(tensor, description):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    shape_str = '_'.join(str(dim) for dim in tensor.shape)
    filename = f"tensor_{description}_{shape_str}_{timestamp}.pt"

    torch.save({
        'tensor': tensor,
        'description': description,
        'shape': tensor.shape,
        'timestamp': timestamp
    }, filename)

    return filename

# 여러 텐서 저장
features = torch.randn(100, 20)
filename = save_tensor_with_naming(features, "features")
print(f"저장된 파일: {filename}")

labels = torch.randint(0, 5, (100,))
filename = save_tensor_with_naming(labels, "labels")
print(f"저장된 파일: {filename}")

6.4 대용량 tensor 처리

매우 큰 텐서를 저장하고 불러올 때는 메모리 효율성을 고려할 것.

# 대용량 텐서 처리 예
import gc

# 큰 텐서 생성
large_tensor = torch.randn(10000, 10000)  # 약 400MB
print(f"텐서 크기: {large_tensor.shape}, 메모리: {large_tensor.element_size() * large_tensor.nelement() / 1024 / 1024:.2f} MB")

# 저장
print("텐서 저장 중...")
torch.save(large_tensor, 'large_tensor.pt')

# 메모리 해제
del large_tensor
gc.collect()
print("원본 텐서 메모리 해제됨")

# 불러오기 전 정보 확인
import os
file_size = os.path.getsize('large_tensor.pt') / 1024 / 1024
print(f"파일 크기: {file_size:.2f} MB")

# 큰 텐서 불러오기
print("텐서 불러오는 중...")
loaded_large = torch.load('large_tensor.pt')
print(f"불러온 텐서 크기: {loaded_large.shape}")

7. 문제 해결 및 주의사항 (작성중)

7.1 버전 호환성 문제

다른 PyTorch 버전에서 저장한 텐서를 불러올 때 호환성 문제가 발생할 수 있음.

# 버전 호환성 문제 처리 예시
try:
    incompatible_tensor = torch.load('tensor_from_other_version.pt')
except Exception as e:
    print(f"텐서 불러오기 오류: {e}")

    # 다른 방법으로 시도
    try:
        # pickle_module과 인코딩 설정 변경
        incompatible_tensor = torch.load(
            'tensor_from_other_version.pt',
            pickle_module=pickle,
            encoding='latin1'  # 구버전 Python 2 호환성을 위해
        )
        print("대체 방법으로 불러오기 성공")
    except Exception as e:
        print(f"대체 방법도 실패: {e}")
728x90

'Python' 카테고리의 다른 글

[DL] PyTorch-Hook  (0) 2025.04.10
[DL] torch.nn.Linear 에 대하여  (1) 2025.04.10
[Py] 연습문제-carriage return + time.sleep  (0) 2025.04.07
[OpenCV] macOS에서 Qt 지원하도록 빌드.  (0) 2025.04.05
[Py] collections.ChainMap  (0) 2025.04.04