
torch.tensor(batch)
collate_fn=None 의 collate function 동작
PyTorch에서 DataLoader에서 collate_fn=None인 경우, torch.utils.data._utils.collate.default_collate 함수가 기본으로 사용됨.
2024.04.09 - [Python] - [PyTorch] Dataset and DataLoader
[PyTorch] Dataset and DataLoader
1. Dataset 이란 : PyTorch 의 tensor 와학습에 사용될 일반 raw data (흔히, storage에 저장된 파일들) 사이에 위치하며,raw-data로부터 PyTorch의 module 객체 등이 접근가능한 데이터 셋을 추상화한 객체를 얻게
ds31x.tistory.com
PyTorch DataLoader의 기본 collate_fn은 다음과 같이 Element's dtype 에 따라 다르게 동작:
1. 처리 기준
- dataset.__getitem__(idx)의 반환값 구조를 기준으로 batch를 구성.
- 대표적 __getitem__(idx)의 반환값 구조: (image, label), {"img":..., "label":...}, [x, y, z]
- batch의 각 샘플에서 같은 위치에 있는 요소들을 모아 처리함
- 즉, tuple의 top-level item 단위로 타입을 검사하고 해당 타입에 따라 collate규칙을 적용.
2. 처리 순서
- batch(= __getitem__ 반환값들의 list객)의 첫 원소 타입을 확인
- 해당 타입에 맞는 collate 규칙을 선택 .
- 컨테이너(Sequence/Mapping)면 재귀적으로 하위 요소에 대해 반복 처리
- 처리 결과는 입력 구조를 보존하려고 하되,
- Sequence의 경우 결과 컨테이너는 보통 list 로 나오는 경향이 있음(재귀 처리 결과)
3. 타입별 처리(간략버전)
torch.Tensor:torch.stack으로 쌓음numpy.ndarray: numpy array를 tensor로 변환 후torch.stack- 숫자 (
int,float): batch를 그대로 tensor로 변환 (stack 아님에 유의) str: tensor가 아닌 문자열을 싸고 있던 container 객체로 묶임- Sequence (
list,tuple): 각 위치별로 요소를 재귀적으로 처리 후 list로 묶음. - Mapping (
dict등): key 별로 요소를 재귀적으로 처리. - Custom:
default_collate_fn_map에 등록된 function 으로 처리
batch 객체는 기본적으로 list 객체임.
단, Dataset이 dict로 item을 반환하는 경우만 dict 객체가 됨.
이들에 대한 ipynb 파일은 다음과 같음:
https://gist.github.com/dsaint31x/617a4d76a32c7afda45058095f6b7e56
dl_default_collate_fn.ipynb
dl_default_collate_fn.ipynb. GitHub Gist: instantly share code, notes, and snippets.
gist.github.com
1. tensor 데이터
- 이미
tensor인 경우, 여러 샘플을 새로운 차원(batch)을 추가하여 stack. - 예: 세 개의
shape=(3, 4)tensors →shape=(3, 3, 4)tensor
주로 Dataset의 __getitem__ 에서 반환되는 tuple 또는 list의 top-level item이 tensor 객체인 경우에 해당.
주로 feature vector 데이터에서 적용되는 방식임.
torch.stack(batch, dim=0)
2. numpy array
ndarray객체 자체가tensor로 변환한 후 batch 차원이 추가하여 stack.
주로 Dataset의 __getitem__ 에서 반환되는 tuple 또는 list의 top-level item이 ndarray 객체인 경우에 해당.
주로 feature vector 데이터에서 적용되는 방식이며, 사실상 tensor 데이터인 경우와 동일.
torch.as_tensor(ndarray) → torch.stack
3. 숫자형 데이터 (int, float 등)
- 여러 샘플을 하나의 tensor로 변환 후 batch 차원으로 stack.
- 예:
1, 2, 3→tensor([1, 2, 3])
주로 Dataset의 __getitem__ 에서 반환되는 tuple 또는 list의 top-level item이 숫자형 데이터인 경우에 해당.
주로 label 데이터에서 흔히 적용되는 방식.
batch = [dataset[0], dataset[1], dataset[2], ...]
# batch[i]는 모두 같은 구조여야 함
4. 문자열 (str)
- 다른 데이터형들과 달리, tuple로 묶임: 참고로 Dataset에서 문자열만 단독으로 반환하는 경우는 list로 묶임- batch는 list 객체임.
- 예:
'a', 'b', 'c'→('a', 'b', 'c')
주로, Dataset의 __getitem__ 에서 반환되는 tuple 또는 list의 top-level item이 str 객체인 경우에 해당.
["a", "b", "c"]
그리 많이 사용되지는 않음:
ML에서는 str 객체가 해당하는 embedding representation( seq. of embedding vector)로 변경되는게 일반적임.
5. list나 tuple
- position-wise unzip을 수행하고 나서 각 item별로 재귀적으로 처리됨.
- 각 위치의 요소들끼리 batch 크기로 묶임: top-level data타입에 따라 tensor 혹은 tuple 객체가 됨.
- 예:
(1, 2, 'a'), (3, 4, 'b')→(tensor([1, 3]),tensor([2, 4]), ['a', 'b'])
주로,
Dataset의 __getitem__ 에서 반환되는 tuple 또는 list 인 경우와
또는 반환되는 tupe과 list의 top-level item이 다시 list 또는 tuple 객체인 경우에 해당: nested list or nested tuple
list나 tuple 자체가 tensor가 되는 것이 아니라, 각 위치의 요소들이 별도의 tensor로 처리되고,
이들 tensor가 최종적으로는 list로 묶임.
[(a1, b1), (a2, b2), (a3, b3)]
→ [
collate([a1, a2, a3]),
collate([b1, b2, b3])
]
6. dictionary
- key(키)별로 값들을 묶습니다
- 예:
{'x': 1, 'y': 'a'}, {'x': 2, 'y': 'b'}→{'x': tensor([1, 2]), 'y': ['a', 'b']}
주로,
Dataset의 __getitem__ 에서 dict로 반환되는 경우에 해당됨.
list와 tuple에서 index 로 묶이던 것이 key로 수행된다고 생각하면 됨.
즉, 각 key의 요소들이 별도로 처리됨.
[
{"x": a1, "y": b1},
{"x": a2, "y": b2}
]
→ {
"x": collate([a1, a2]),
"y": collate([b1, b2])
}'Python' 카테고리의 다른 글
| [PySide] QtCore.QSettings 사용법 (0) | 2025.05.12 |
|---|---|
| [Programming] SOLID 원칙 (0) | 2025.04.28 |
| [Py] importlib.metadata: Package 정보 확인 (0) | 2025.04.23 |
| [Programming] Control Flow 와 Control Structure (1) | 2025.04.23 |
| [Py] import 의 종류. (0) | 2025.04.18 |