Dataset 이란
- PyTorch 의
tensor
와 - 학습에 사용될 일반 raw data (흔히, storage에 저장된 파일들) 사이에 위치하며,
raw-data로부터 PyTorch의 기본데이터형 인 tensor
를 얻게 해주는 역할을 수행하는 class임.
일반적으로 task에 따라
다양한 형태의 raw-data가 있기 때문에,
각 task별로 Custom Dataset을 만드는 경우가 잦음.
Custom Dataset 만들기
torch.util.data
모듈의 Dataset
을 상속하고,
다음의 methods를 overriding해야만 함.
__init__(self)
:Dataset
인스턴스에 대한 생성자로 데이터셋에 대한 초기화를 담당.- raw-data에 따라 parameters를 자유롭게 추가할 수 있음.
__len__(self)
:Dataset
인스턴스 내에 있는 샘플 갯수를 반환하도록 구현.
__getitem__(self, idx)
:- argument로 넘어오는
idx
에 해당하는 샘플을 반환하도록 구현. dataset[idx]
등으로Dataset
인스턴스인dataset
에 대해idx
에 해당하는 데이터셋 샘플에 접근할 때 호출됨.- 하나의 샘플은 input feature data와 label을 묶어서(
tuple
) 반환.
- argument로 넘어오는
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = torch.tensor(data).float()
self.labels = torch.tensor(labels).float()
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
label = self.labels[idx]
return sample, label
x = [0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0]
y = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]
dataset = CustomDataset(x, y)
Dataset
은 iterable이기 때문에, 다음과 같이 각 샘플을 확인할 수 있음.
i = iter(dataset)
print(next(i))
# print(i.__next__())
단, collections.abc 의 Iterable을 직접 상속하고 있지는않음
issubclass(Dataset, Itearble)로 확인시 False.
2024.04.15 - [분류 전체보기] - [Python] collections.abc
https://dsaint31.tistory.com/501
DataLoader
Dataset을 통해 데이터 로딩을 수행하는 Class.
Dataset과 달리 사용법만 익혀도 일반적인 경우에는 충분함.
Dataset
으로부터 실제 Training Loop등에 training (mini)batch를 묶어서 효율적으로 제공해주는 역할을 수행*
- raw-data로부터
tensor
를 얻어내는 과정을 data loading 이라고 부르는데, - 일반적으로 storage에서 데이터를 읽어들이는 loading은 상대적으로 느린 속도를 보이기 때문에 병렬처리 등이 필요함.
- 동시에 training 에서 데이터 샘플의 순서를 섞어주는 등의 shffle 기능들도 필요한데,
- PyTorch는 이를
DataLoader
를 통해 제공해줌.
DataLoader는 Dataset 을 제공해 줄 경우,
해당 Dataset을 통해 이루어지는 데이터 로딩을
병렬처리 및 shuffle, minibatch로 나누는 기능 등을 추가하여 수행하도록 도와줌.
일반적으로, DataLoader를 구현할 필요는 없고, Task별 rawdata에 대한 Dataset 클래스를 만들면 됨.
Dataset으로부터 DataLoader 만들기.
data_loader = DataLoader(
dataset, # torch.utils.data.Dataset의 instance
batch_size, # batch의 샘플수
shffule, # boolean, 셔플링을 할지 여부(순서를 랜덤하게)
num_workers, # 데이터로딩에 사용되는 sub-process의 수 (CPU의 core수를 넘으면 안됨.)
pin_memory, # boolean, GPU memory 영역을 예약할지 여부(pin).
drop_last, # boolean, 마지막 batch가 샘플의 수가 맞지 않을 경우 dorp할지 여부.
)
Example of DataLoader
from torch.utils.data import DataLoader
data_loader = DataLoader(
dataset,
batch_size = 4,
shuffle = True,
)
for batch_idx, (data, labels) in enumerate(data_loader):
print(f'{batch_idx=}')
print(f'{data.shape} | {data=}')
print(f'{labels.shape} | {labels=}')
# training ...
'Python' 카테고리의 다른 글
[PyTorch] CustomANN Example: From Celsius to Fahrenheit (0) | 2024.04.12 |
---|---|
[PyTorch] torch.nn.init (0) | 2024.04.11 |
[Python] pathlib.Path 사용하기. (0) | 2024.03.31 |
[DL] Tensor: Random Tensor 만들기 (NumPy, PyTorch) (0) | 2024.03.29 |
[DL] Define and Run vs. Define by Run (0) | 2024.03.28 |