본문 바로가기
목차
Python

[DL] default collate_fn - PyTorch

by ds31x 2025. 4. 26.
728x90
반응형

torch.tensor(batch)

collate_fn=None 의 collate function 동작

PyTorch에서 DataLoader에서 collate_fn=None인 경우,
torch.utils.data._utils.collate.default_collate 함수가 기본으로 사용됨.

2024.04.09 - [Python] - [PyTorch] Dataset and DataLoader

 

[PyTorch] Dataset and DataLoader

1. Dataset 이란 : PyTorch 의 tensor 와학습에 사용될 일반 raw data (흔히, storage에 저장된 파일들) 사이에 위치하며,raw-data로부터 PyTorch의 module 객체 등이 접근가능한 데이터 셋을 추상화한 객체를 얻게

ds31x.tistory.com


PyTorch DataLoader의 기본 collate_fn은 다음과 같이 Element's dtype 에 따라 다르게 동작:

1. 처리 기준

  • dataset.__getitem__(idx)반환값 구조를 기준으로 batch를 구성.
    • 대표적 __getitem__(idx)의 반환값 구조: (image, label), {"img":..., "label":...}, [x, y, z]
  • batch의 각 샘플에서 같은 위치에 있는 요소들을 모아 처리함
  • 즉, tuple의 top-level item 단위로 타입을 검사하고 해당 타입에 따라 collate규칙을 적용.

2. 처리 순서

    • batch(= __getitem__ 반환값들의 list객)의 첫 원소 타입을 확인
    • 해당 타입에 맞는 collate 규칙을 선택 .
    • 컨테이너(Sequence/Mapping)면 재귀적으로 하위 요소에 대해 반복 처리
    • 처리 결과는 입력 구조를 보존하려고 하되,
      • Sequence의 경우 결과 컨테이너는 보통 list 로 나오는 경향이 있음(재귀 처리 결과)

3. 타입별 처리(간략버전)

  • torch.Tensor: torch.stack으로 쌓음
  • numpy.ndarray: numpy array를 tensor로 변환 후 torch.stack
  • 숫자 (int, float): batch를 그대로 tensor로 변환 (stack 아님에 유의)
  • str:  tensor가 아닌 문자열을 싸고 있던 container 객체로 묶임
  • Sequence (list, tuple): 각 위치별로 요소를 재귀적으로 처리 후 list로 묶음.
  • Mapping (dict 등): key 별로 요소를 재귀적으로 처리.
  • Custom: default_collate_fn_map에 등록된 function 으로 처리

batch 객체는 기본적으로 list 객체임.

단, Dataset이 dict로 item을 반환하는 경우만 dict 객체가 됨.

 

이들에 대한 ipynb 파일은 다음과 같음:

https://gist.github.com/dsaint31x/617a4d76a32c7afda45058095f6b7e56

 

dl_default_collate_fn.ipynb

dl_default_collate_fn.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com


1. tensor 데이터

  • 이미 tensor인 경우, 여러 샘플을 새로운 차원(batch)을 추가하여 stack.
  • 예: 세 개의 shape=(3, 4) tensors → shape=(3, 3, 4) tensor

주로 Dataset의 __getitem__ 에서 반환되는 tuple 또는 list의 top-level item이 tensor 객체인 경우에 해당.

 

주로 feature vector 데이터에서 적용되는 방식임.

torch.stack(batch, dim=0)

2. numpy array

  • ndarray객체 자체가 tensor로 변환한 후 batch 차원이 추가하여 stack.

주로 Dataset의 __getitem__ 에서 반환되는 tuple 또는 list의 top-level item이 ndarray 객체인 경우에 해당.

 

주로 feature vector 데이터에서 적용되는 방식이며, 사실상 tensor 데이터인 경우와 동일.

torch.as_tensor(ndarray) → torch.stack

3. 숫자형 데이터 (int, float 등)

  • 여러 샘플을 하나의 tensor로 변환 후 batch 차원으로 stack.
  • 예:1, 2, 3tensor([1, 2, 3])

주로 Dataset의 __getitem__ 에서 반환되는 tuple 또는 list의 top-level item이 숫자형 데이터인 경우에 해당.

 

주로 label 데이터에서 흔히 적용되는 방식.

batch = [dataset[0], dataset[1], dataset[2], ...]
# batch[i]는 모두 같은 구조여야 함

4. 문자열 (str)

  • 다른 데이터형들과 달리, tuple로 묶임: 참고로 Dataset에서 문자열만 단독으로 반환하는 경우는 list로 묶임- batch는 list 객체임.
  • 예:'a', 'b', 'c'('a', 'b', 'c')

주로, Dataset의 __getitem__ 에서 반환되는 tuple 또는 list의 top-level item이 str 객체인 경우에 해당.

["a", "b", "c"]

 

그리 많이 사용되지는 않음:

ML에서는 str 객체가 해당하는 embedding representation( seq. of embedding vector)로 변경되는게 일반적임.


5. list나 tuple

  • position-wise unzip을 수행하고 나서 각 item별로 재귀적으로 처리됨.
  • 각 위치의 요소들끼리 batch 크기로 묶임: top-level data타입에 따라 tensor 혹은 tuple 객체가 됨.
  • 예:(1, 2, 'a'), (3, 4, 'b')(tensor([1, 3]),tensor([2, 4]), ['a', 'b'])

주로,

Dataset의 __getitem__ 에서 반환되는 tuple 또는 list 인 경우와

또는 반환되는 tupelist의 top-level item이 다시 list 또는 tuple 객체인 경우에 해당: nested list or nested tuple

 

listtuple 자체가 tensor가 되는 것이 아니라, 각 위치의 요소들이 별도의 tensor로 처리되고,
이들 tensor가 최종적으로는 list로 묶임.

[(a1, b1), (a2, b2), (a3, b3)]
→ [
    collate([a1, a2, a3]),
    collate([b1, b2, b3])
]

6. dictionary

  • key(키)별로 값들을 묶습니다
  • 예:{'x': 1, 'y': 'a'}, {'x': 2, 'y': 'b'}{'x': tensor([1, 2]), 'y': ['a', 'b']}

주로,

Dataset의 __getitem__ 에서 dict로 반환되는 경우에 해당됨.

 

listtuple에서 index 로 묶이던 것이 key로 수행된다고 생각하면 됨.

 

즉, 각 key의 요소들이 별도로 처리됨.

[
  {"x": a1, "y": b1},
  {"x": a2, "y": b2}
]
→ {
    "x": collate([a1, a2]),
    "y": collate([b1, b2])
}
728x90

'Python' 카테고리의 다른 글

[PySide] QtCore.QSettings 사용법  (0) 2025.05.12
[Programming] SOLID 원칙  (0) 2025.04.28
[Py] importlib.metadata: Package 정보 확인  (0) 2025.04.23
[Programming] Control Flow 와 Control Structure  (1) 2025.04.23
[Py] import 의 종류.  (0) 2025.04.18