본문 바로가기
목차
Python

[PyTorch] Dataset and DataLoader

by ds31x 2024. 4. 9.
728x90
반응형

1. Dataset 이란 : 

  • PyTorch 의 tensor
  • 학습에 사용될 일반 raw data (흔히, storage에 저장된 파일들) 사이에 위치하며,
  • raw-data로부터 PyTorch의 module 객체 등이 접근가능한 데이터 셋을 추상화한 객체를 얻게 해주는 역할을 수행.
Dataset의 가장 핵심적 역할은
데이터에 대한 index 기반 접근 인터페이스를 제공하는 것임.


참고: 전처리는 transforms 와 transform, target_transform: 

이미지의 전처리는 torchvision.trnasforms 모듈 또는 torchvision.transforms.v2 모듈의 transform 클래스들의 객체를 이용한다.차이가 있는데,

  • transform은 input feature vector(or image)에만 적용.
  • target_transform은 target(or label)에만 적용.
  • transforms는 input과 target 에 둘 다 적용됨(v2부터 사용됨.)

VOC dataset에선 transforms가 지원이 되나 COCO 를 포함한 대부분의 경우는 transform과 target_transform 으로 분리된 방식만을 지원하고 있음. 단, v2를 쓴다면 transforms를 사용하는게 좋을 듯.

2025.01.12 - [Python] - [PyTorch] torchvision.transforms 사용법 - transform이란?

 

[PyTorch] torchvision.transforms 사용법 - transforms란?

PyTorch의 torchvision.transforms:이미지 전처리와 데이터 증강을 위한 도구torchvision.transforms는PyTorch에서 제공하는 이미지 전처리 및 data augmentation을 위한 module.이 모듈은 이미지 데이터를 이용한 딥러

ds31x.tistory.com

 

 


DataLoader와 같이 동작하여, PyTorch의 기본데이터형tensor를 얻게 해주는 역할을 수행하는 class임.

 

일반적으로 task에 따라
다양한 형태의 raw-data가 있기 때문에,
각 task별로 Custom Dataset을 만드는 경우가 잦음.

 

사실, DataLoadercollate_fn에 지정된 함수(collate_fn=None인 경우 동작하는 기본함수)에서 tensor로의 변환이 이루어지므로 Dataset에서 반드시 tensor로 변환하지 않아도 됨

(주의 : ndarray까지는 되나, PIL Image는 기본 collate_fn이 처리 못함) .

 

하지만 가급적 Dataset에서 일찍이 처리해주는게 보다 권장됨.


1-1. Custom Dataset 만들기 : 

torch.utils.data 모듈의 Dataset을 상속하고,
다음의 methods를 overriding해야만 함.

  1. __init__(self) :
    • Dataset 인스턴스에 대한 생성자 로 데이터셋에 대한 초기화를 담당.
    • raw-data에 따라 parameters를 자유롭게 추가할 수 있음.
  2. __len__(self) :
    • Dataset 인스턴스 내에 있는 샘플 갯수를 반환하도록 구현.
  3. __getitem__(self, idx) :
    • argument로 넘어오는 idx 에 해당하는 샘플을 반환하도록 구현.
    • dataset[idx] 등으로 Dataset인스턴스인 dataset에 대해 idx에 해당하는 데이터셋 샘플에 접근할 때 호출됨.
    • 하나의 샘플 인스턴스의 input feature data와 label을 묶어서(tuple) 반환 하는 것이 일반적.
      • dict 로 넘기는 것도 잘 동작함.
    •  tuple과 dict의 top level item들은 보통 tensor 객체 로 처리하는게 좋음.
      • 엄밀히 말하면, DataLoader의 기본 collate_fn에 설정된 collate function 에 의해 이들 item들을 적절히 batch로 묶고 동시에 tensor로 변환이 이루어지기 때문에 꼭 Dataset에서 tensor로 변환하지 않아도 됨.
      • 항상 tensor로 변환되는 건 아니고 item의 타입에 따라 다름:
        • 숫자 데이터: 묶어서 tensor로 변환
        • 문자열(str): 그대로 유지하고 tuple로 묶음. 
        • list/tuple: 각 위치별로 독립적인 tensor로 변환 (주의 필요) 하고 이를 list로 묶음
        • 이에 대해선 "같이보면 좋은 자료들"에서의 "dl_default_collate_fn.ipynb" 및 "아래 글" 참고.
        • 2025.04.26 - [Python] - [DL] collate_fn - PyTorch
      • 하지만, 가급적 tensor로 변환시켜서
        model의 입력되는 feature vector에 해당하는 tensor 객체와 label에 해당하는 tensor 객체를 반환하도록 하는 것을 권장함.
    • 샘플 인스턴스 단위의 pre-processing 이 필요한 경우, transform 객체를 생성자에서 설정하고 이를 통해 __getitem__ 메서드에서 처리하는게 일반적임.

1-2. Example : 

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(
        self,
        data,
        labels,
        transform=None,         # sample 전용: sample -> sample
        target_transform=None,  # target 전용: target -> target
        transforms=None,        # joint: (sample, target) -> (sample, target)
    ):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

        self.transform = transform
        self.target_transform = target_transform
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]

        # 1) joint transforms가 있으면 (sample, target) 함께 변환
        if self.transforms is not None:
            sample, target = self.transforms(sample, target)
            return sample, target

        # 2) 아니면 각각 따로 변환
        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target


# --------------------
# 사용 예시
# --------------------
x = [0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0]
y = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]

# (예) sample만 변환
def sample_transform(s):
    return (s - 10.0) / 5.0

# (예) target만 변환
def target_transform(t):
    return t / 100.0

# (예) sample, target을 동시에 변환 (예: 둘 다 스케일링)
def joint_transforms(s, t):
    s2 = (s - 10.0) / 5.0
    t2 = t / 100.0
    return s2, t2

dataset_sep = CustomDataset(x, y, transform=sample_transform, target_transform=target_transform)
dataset_joint = CustomDataset(x, y, transforms=joint_transforms)

print(dataset_sep[0])    # (sample 변환, target 변환)
print(dataset_joint[0])  # (joint 변환)

 

Datasetiterable로도 다룰 수 때문에, 다음과 같이 각 샘플을 확인할 수 있음.

i = iter(dataset)
print(next(i))
# print(i.__next__())

 

단, collections.abc 의 Iterable을 직접 상속하고 있지는않음
issubclass(Dataset, Itearble)로 확인시 False.

 

2024.04.15 - [분류 전체보기] - [Python] collections.abc

 

[Python] collections.abc

2023.10.06 - [Pages] - [Python] Collections collections.abc 와 Python의 DataStructure. Python의 Data structure는 실제적으로 collections.abc 라는 abstract base class (abc) mechanism를 통한 hierarchy로 구성된다. 일반적으로 list, tuple

ds31x.tistory.com

https://dsaint31.tistory.com/501

 

[Python] Iterable and Iterator, plus Generator

Iterable and Iteartor, and GeneratorIterable for 문에서 in 뒤에 위치하여 iterate (반복, 순회)가 가능한 object를 가르킴.__iter__() 라는 special method를 구현하고 있으며, 이를 통해 자신에 대한 iterator object를 반환

dsaint31.tistory.com

 


2. DataLoader : 

Dataset을 통해
데이터 로딩을 수행하는 Class.

 

Dataset과 달리 사용법만 익혀도 일반적인 경우에는 충분함: CPU에서 동작.

 

Dataset으로부터 실제 Training Loop등에 training (mini)batch를 묶어서 효율적으로 제공해주는 역할을 수행*

  • raw-data로부터 tensor 를 얻어내는 과정을 data loading 이라고 부르는데,
  • 일반적으로 storage에서 데이터를 읽어들이는 loading은 상대적으로 느린 속도를 보이기 때문에 병렬처리 등이 필요함.
  • 동시에 training 에서 데이터 샘플의 순서를 섞어주는 등의 shuffle 기능들도 필요한데,
  • PyTorch는 이를 DataLoader를 통해 제공해줌.

DataLoader는 Dataset 을 제공해 줄 경우,
해당 Dataset을 통해 이루어지는 데이터 로딩을
병렬처리shuffle, minibatch로 나누는 기능 등을 추가하여 수행하도록 도와줌.

일반적으로, DataLoader를 구현할 필요는 없고, Task별 rawdata에 대한 Dataset 클래스를 만들면 됨.


2-1. Dataset으로부터 DataLoader 만들기 : 

data_loader = DataLoader(
    dataset,     # torch.utils.data.Dataset의 instance
    batch_size,  # batch의 샘플수
    shuffle,     # boolean, 셔플링을 할지 여부(순서를 랜덤하게)
    num_workers, # 데이터로딩에 사용되는 sub-process의 수 (CPU의 core수를 넘으면 안됨.)
    pin_memory,  # boolean, GPU memory 영역을 예약할지 여부(pin).
    drop_last,   # boolean, 마지막 batch가 샘플의 수가 맞지 않을 경우 dorp할지 여부.
    collate_fn,  # callable, 샘플 리스트를 배치로 변환하는 함수
                 # None: 기본 collate_fn 사용 (텐서 자동 스택, 딕셔너리 처리 등)
                 # 커스텀 함수: 가변 길이 패딩, 배치 단위 전처리 등에 사용
    )

 

  • collate_fn 에 할당되는 collate function은 batch를 argument로 전달받아, 이에 대해 처리를 하고 다시 batch를 반환함.
    • batch 단위로 필요한 전처리 를 구현하는데 사용됨: transform이 instance 단위의 전처리인 것과 차이점을 가짐.
    • 기본 collate_fn의 경우(default_collate), __getitem__ 메서드의 반환 구조(tuple, list, dict 등)를 보존하면서, 각 index 또는 key 별 요소들을 모아 타입에 따라 tensor로 변환하거나 재귀적으로 묶어서 batch를 구성함.
      • tupledict의 item을 보통 stack 처리하여 batch 차원이 앞에 놓이는 최종 batch 텐서를 생성함.
      •  자세한 건 "같이보면 좋은 자료들"에서의 "dl_default_collate_fn.ipynb"참고.
      • 기본적으로 torch.utils.data._utils.collate.default_collate 함수가 collate_fn=None 인 경우 수행됨.
      • 단, detection이나 segment 같이 stack이 먹히지 않는 경우(묶을 요소들이 다른 shape를 가질 경우)에는 default_collate를 쓰면 안됨. 아래의 "Detection Task에서 collate_fn의 입·출력 구조와 default_collate의 한계" 글 참고.
    • feature vector, label 의 tuple이 반환되도록 Dataset이 구현되는 것이 가장 쉽게 이용가능한 형태임.
      • 사실, 키 또는 index 위치별로 batch 크기에 맞춰 tensor를 각각 만들고 이들을 list로 묶어서 반환.
      • batch크기로 묶을 때, 각각의 키 또는 index의 데이터들이 숫자인 경우 tensor로 묶고, 문자열인 경우 tuple로 묶어냄.

2025.12.15 - [ML] - Detection Task에서 collate_fn의 입·출력 구조와 default_collate의 한계

 

Detection Task에서 collate_fn의 입·출력 구조와 default_collate의 한계

Detection task를 위한 collate_fnDetection 테스크의 경우, DataLoader에서 collate_fn을 일반적으로 다음과 같이 변경해줘야 함.def collate_fn(batch): # batch: [(img, target), (img, target), ...] imgs, targets = zip(*batch) return lis

ds31x.tistory.com


2-2. batch 정보를 이용한 전처리: collate_fn

기본 signature는 다음과 같음:

def collate_fn(batch: list[Any]) -> Any:
    ...

parameter 인 batch는 Dataset 객체의 __getitem__(idx)의 반환값들의 list임.

batch = [
    dataset[0],
    dataset[1],
    dataset[2],
    ...
]
더보기

classification의 경우엔 다음의 signautre가 일반적

  • __getitem__ -> (image, label)
  • image: Tensor
  • label: int
def collate_fn(
    batch: list[tuple[torch.Tensor, int]]
) -> tuple[torch.Tensor, torch.Tensor]:
    ...

 

detection / segmentation의 경우는 다음과 같음:

def collate_fn(
    batch: list[tuple[torch.Tensor, dict]]
) -> tuple[list[torch.Tensor], list[dict]]:
    ...

 

이 경우 collate_fn을 다음으로 지정해야 한다.

def collate_fn(batch): 
    # batch: [(img, target), (img, target), ...]
    imgs, targets = zip(*batch)
    return list(imgs), list(targets)
    
loader = DataLoader(
    train_ds, 
    batch_size=2, 
    shuffle=True, 
    collate_fn=collate_fn,
)

2025.12.15 - [ML] - Detection Task에서 collate_fn의 입·출력 구조와 default_collate의 한계

 

Detection Task에서 collate_fn의 입·출력 구조와 default_collate의 한계

Detection task를 위한 collate_fnDetection 테스크의 경우, DataLoader에서 collate_fn을 일반적으로 다음과 같이 변경해줘야 함.def collate_fn(batch): # batch: [(img, target), (img, target), ...] imgs, targets = zip(*batch) return lis

ds31x.tistory.com


collate function의 기본 동작의 자세한 건 다음을 참고:

간략하게 살펴보면 다음과 같음:

2-2-1. Dataset이 tuple로 반환한 경우:

__getitem__ -> (image, label) # image [C,H,W], label (int)

batch 구성 후:

(
    Tensor[N, C, H, W],   # image
    Tensor[N]             # label
)

 

2-2-2. Dataset이 dict를 반환한 경우:

__getitem__ -> {"img": image, "label": label}

batch 구성 후:

{
    "img": Tensor[N, C, H, W],
    "label": Tensor[N]
}

 

2-2-3. Dataset이 list를 반환한 경우

__getitem__ -> [image, label]

batch 구성 후:

[
    Tensor[N, C, H, W],
    Tensor[N]
]

 

주로 다음의 경우에 많이 collate_fn이 사용됨.

  • batch 단위 normalization (Batch Normalization Layer와 유사.)
  • 가변 길이의 데이터를 batch내의 최대 길이로 padding 시켜, 같은 길이로 만드는 전처리.
  • batch 에서 outliar 제거.
  • batch 통계치 계산.
  • 동적으로 batch 구성 변경.

2-3. Example of DataLoader : 

from torch.utils.data import DataLoader

data_loader = DataLoader(
    dataset,
    batch_size = 4,
    shuffle = True,
)

for batch_idx, (data, labels) in enumerate(data_loader):
    print(f'{batch_idx=}')
    print(f'{data.shape} | {data=}')
    print(f'{labels.shape} | {labels=}')

    # training ...

3. Example

Dataset과 Dataloader를 직접 만들어본 간단한 예제 gist임.

ImageFolder처럼 동작하는 형태를 취함

https://gist.github.com/dsaint31x/b0ecd77af41c5666ee674d47d9a44bf6

 

dl_custom_dataset_like_imagefolder.ipynb

dl_custom_dataset_like_imagefolder.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com

 


같이보면 좋은 자료들

2025.06.17 - [Python] - torchvision.datasets.ImageFolder 사용하기.

 

torchvision.datasets.ImageFolder 사용하기.

torchvision.datasets.ImageFolder는 PyTorch에서 이미지 분류 작업을 위한 데이터셋을 쉽게 로드할 수 있게 해주는 클래스임. https://docs.pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html ImageFolder — T

ds31x.tistory.com

2025.01.12 - [Python] - [PyTorch] torchvision.transform 사용법

 

[PyTorch] torchvision.transforms 사용법 - transforms란?

PyTorch의 torchvision.transforms:이미지 전처리와 데이터 증강을 위한 도구torchvision.transforms는PyTorch에서 제공하는 이미지 전처리 및 data augmentation을 위한 module.이 모듈은 이미지 데이터를 이용한 딥러

ds31x.tistory.com

https://gist.github.com/dsaint31x/617a4d76a32c7afda45058095f6b7e56

 

dl_default_collate_fn.ipynb

dl_default_collate_fn.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com

 

2025.04.26 - [Python] - [DL] collate_fn - PyTorch

 

[DL] default collate_fn - PyTorch

collate_fn=None 의 collate function 동작PyTorch에서 DataLoader에서 collate_fn=None인 경우, torch.utils.data._utils.collate.default_collate 함수가 기본으로 사용됨.2024.04.09 - [Python] - [PyTorch] Dataset and DataLoader [PyTorch] Dataset

ds31x.tistory.com

2025.12.15 - [ML] - Detection Task에서 collate_fn의 입·출력 구조와 default_collate의 한계

 

Detection Task에서 collate_fn의 입·출력 구조와 default_collate의 한계

Detection task를 위한 collate_fnDetection 테스크의 경우, DataLoader에서 collate_fn을 일반적으로 다음과 같이 변경해줘야 함.def collate_fn(batch): # batch: [(img, target), (img, target), ...] imgs, targets = zip(*batch) return lis

ds31x.tistory.com


 

728x90