본문 바로가기
Python

[PyTorch] Dataset and DataLoader

by ds31x 2024. 4. 9.

Dataset 이란

  • PyTorch 의 tensor
  • 학습에 사용될 일반 raw data (흔히, storage에 저장된 파일들) 사이에 위치하며,

raw-data로부터 PyTorch의 기본데이터형tensor를 얻게 해주는 역할을 수행하는 class임.

일반적으로 task에 따라
다양한 형태의 raw-data가 있기 때문에,
각 task별로 Custom Dataset을 만드는 경우가 잦음.


Custom Dataset 만들기

torch.util.data 모듈의 Dataset을 상속하고,
다음의 methods를 overriding해야만 함.

  1. __init__(self) :
    • Dataset 인스턴스에 대한 생성자로 데이터셋에 대한 초기화를 담당.
    • raw-data에 따라 parameters를 자유롭게 추가할 수 있음.
  2. __len__(self) :
    • Dataset 인스턴스 내에 있는 샘플 갯수를 반환하도록 구현.
  3. __getitem__(self, idx) :
    • argument로 넘어오는 idx 에 해당하는 샘플을 반환하도록 구현.
    • dataset[idx] 등으로 Dataset인스턴스인 dataset에 대해 idx에 해당하는 데이터셋 샘플에 접근할 때 호출됨.
    • 하나의 샘플은 input feature data와 label을 묶어서(tuple) 반환.
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data).float()
        self.labels = torch.tensor(labels).float()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

x = [0.5,  14.0, 15.0, 28.0, 11.0,  8.0,  3.0, -4.0,  6.0, 13.0, 21.0]
y = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]

dataset = CustomDataset(x, y)

 

 

Dataset은 iterable이기 때문에, 다음과 같이 각 샘플을 확인할 수 있음.

i = iter(dataset)
print(next(i))
# print(i.__next__())

 

단, collections.abc 의 Iterable을 직접 상속하고 있지는않음
issubclass(Dataset, Itearble)로 확인시 False.

 

2024.04.15 - [분류 전체보기] - [Python] collections.abc

 

[Python] collections.abc

2023.10.06 - [Pages] - [Python] Collections collections.abc 와 Python의 DataStructure. Python의 Data structure는 실제적으로 collections.abc 라는 abstract base class (abc) mechanism를 통한 hierarchy로 구성된다. 일반적으로 list, tuple

ds31x.tistory.com

https://dsaint31.tistory.com/501

 

[Python] Iterable and Iterator, plus Generator

Iterable for 문에서 in 뒤에 위치하여 iterate (반복, 순회)가 가능한 object를 가르킴. __iter__() 라는 special method를 구현하고 있으며, 이를 통해 자신에 대한 iterator object를 반환할 수 있음. __iter__() special

dsaint31.tistory.com

 


DataLoader

Dataset을 통해 데이터 로딩을 수행하는 Class.

Dataset과 달리 사용법만 익혀도 일반적인 경우에는 충분함.

 

Dataset으로부터 실제 Training Loop등에 training (mini)batch를 묶어서 효율적으로 제공해주는 역할을 수행*

  • raw-data로부터 tensor 를 얻어내는 과정을 data loading 이라고 부르는데,
  • 일반적으로 storage에서 데이터를 읽어들이는 loading은 상대적으로 느린 속도를 보이기 때문에 병렬처리 등이 필요함.
  • 동시에 training 에서 데이터 샘플의 순서를 섞어주는 등의 shffle 기능들도 필요한데,
  • PyTorch는 이를 DataLoader를 통해 제공해줌.

DataLoader는 Dataset 을 제공해 줄 경우,
해당 Dataset을 통해 이루어지는 데이터 로딩을
병렬처리 및 shuffle, minibatch로 나누는 기능 등을 추가하여 수행하도록 도와줌.

일반적으로, DataLoader를 구현할 필요는 없고, Task별 rawdata에 대한 Dataset 클래스를 만들면 됨.


Dataset으로부터 DataLoader 만들기.

data_loader = DataLoader(
    dataset,     # torch.utils.data.Dataset의 instance
    batch_size,  # batch의 샘플수
    shffule,     # boolean, 셔플링을 할지 여부(순서를 랜덤하게)
    num_workers, # 데이터로딩에 사용되는 sub-process의 수 (CPU의 core수를 넘으면 안됨.)
    pin_memory,  # boolean, GPU memory 영역을 예약할지 여부(pin).
    drop_last,   # boolean, 마지막 batch가 샘플의 수가 맞지 않을 경우 dorp할지 여부.
    )

Example of DataLoader

from torch.utils.data import DataLoader

data_loader = DataLoader(
    dataset,
    batch_size = 4,
    shuffle = True,
)

for batch_idx, (data, labels) in enumerate(data_loader):
    print(f'{batch_idx=}')
    print(f'{data.shape} | {data=}')
    print(f'{labels.shape} | {labels=}')

    # training ...