본문 바로가기
목차
ML

Detection Task에서 collate_fn의 입·출력 구조와 default_collate의 한계

by ds31x 2025. 12. 15.
728x90
반응형

Detection task를 위한 collate_fn

Detection 테스크의 경우, DataLoader에서 collate_fn을 일반적으로 다음과 같이 변경해줘야 함.

def collate_fn(batch): 
    # batch: [(img, target), (img, target), ...]
    imgs, targets = zip(*batch)
    return list(imgs), list(targets)


loader = DataLoader(
    train_ds, 
    batch_size=2, 
    shuffle=True, 
    collate_fn=collate_fn,
)

detection task의 경우, 각 image 별로 가지고 있는 object의 수가 다를 수 있음.

  • 그 결과 targets["boxes"] 의 첫번째 차원(바운딩박스의 갯수)이 샘플(img)마다 달라지며,
  • 이는 default_collate 가 배치 처리를 위해 내부적으로 사용하는 tensor.stack 호출시 에러로 이어짐.

 

*batch 와 같은 astrisk 연산자를 이용한 unpacking 에 대한 자세한건 다음을 참고:


Detection task에서 torch가 일반적으로 가정하는 collate_fn의 입출력 구성.

 

다음의 gist 파일을 실행하면서 같이 살펴볼 것:

https://gist.github.com/dsaint31x/94202acc63f0867ff9c9a83bd476021e

 

dl_object_detection_collate_fn.ipynb

dl_object_detection_collate_fn.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com


detection의 collate_fn 입력:

Dataset에서 __getitem__(idx)batch_size만큼 호출하여 모은 결과물.

Detection dataset의 __getitem__(idx) 에서 (image, target) 형태를 반환한다면

collate_fn에서 argument(인자)는 다음과 같음:

batch= [ (img1, target1), (img2, target2), ... ]

 

여기서 target은 다음의 키와 값을 가지는 dict 객체임:

target = {
    # (필수) 바운딩박스: float Tensor, shape [N, 4]
    # 보통 포맷은 XYXY = [x1, y1, x2, y2]
    "boxes":  Tensor[N, 4],

    # (필수) 클래스 라벨: int64 Tensor, shape [N]
    # 각 박스와 1:1 대응
    "labels": Tensor[N],

    # (선택) 이미지 식별자: int64 Tensor, shape [1]
    "image_id": Tensor[1],

    # (선택) 박스 면적: float Tensor, shape [N]
    "area":   Tensor[N],

    # (선택) COCO 스타일 crowd 여부: uint8 또는 bool Tensor, shape [N]
    "iscrowd": Tensor[N],
}
  • N: 해당 이미지 하나에 포함된 object(바운딩 박스)의 개수
  • N은 이미지마다 다를 수 있음
  • "boxes""labels"대부분의 torchvision detection 학습 코드에서 사실상 필수처럼 사용됨.
  • "image_id", "area", "iscrowd"는 COCO-style 평가/학습 루프에서 쓰이는 경우가 많아 자주 등장함.

detection의 collate_fn 출력:

collate_fn 의 callable객체를 통해 위 batch를 모델이 받기 좋은 형태로 “묶어서” 반환.

torchvision 계열 detection 모델에서 흔히 쓰는 형태는:

  • images: List[Tensor] (배치 크기만큼, 각 원소는 이미지 텐서)
  • targets: List[Dict [str, Tensor]] (배치 크기만큼, 각 원소는 타깃 dict)
(images, targets) = collate_fn(batch)
# images  : List[Tensor]             -> [img1, img2, ...]
# targets : List[Dict[str, Tensor]]  -> [target1, target2, ...]

여기서 images 와 targets의 구성은 다음과 같음:

images = List[Tensor]
         ├─ images[0]: Tensor[C, H, W]
         ├─ images[1]: Tensor[C, H, W]
         ├─ ..
         
targets = List[Dict[str, Tensor]]
           ├─ targets[0]["boxes"]   : Tensor[N0, 4]
           ├─ targets[0]["labels"]  : Tensor[N0]
           ├─ targets[0]["image_id"]: Tensor[N0]
           ├─ targets[0]["area"]    : Tensor[N0]
           ├─ targets[0]["iscrowd"] : Tensor[N0]
           ├─ targets[1]["boxes"]   : Tensor[N1, 4]
           ├─ ..

 

특히 targets의 각 요소인 dict는 일반적으로 다음 구조를 가짐 (= targets[i]가 가지는 표준적인 키와 shape)

 

targets[i] = {
    # (필수) 바운딩 박스 좌표
    # float Tensor, shape [Ni, 4]
    # 포맷: [x1, y1, x2, y2] (XYXY)
    "boxes": Tensor[Ni, 4],

    # (필수) 클래스 라벨
    # int64 Tensor, shape [Ni]
    # boxes와 1:1 대응
    "labels": Tensor[Ni],

    # (선택) 이미지 ID
    # int64 Tensor, shape [1]
    "image_id": Tensor[1],

    # (선택) 각 박스의 면적
    # float Tensor, shape [Ni]
    "area": Tensor[Ni],

    # (선택) crowd 여부 (COCO 스타일)
    # uint8 또는 bool Tensor, shape [Ni]
    "iscrowd": Tensor[Ni],
}
  • Ni i번째 이미지에 포함된 object 개수
  • Ni는 이미지마다 서로 다를 수 있음

따라서 targets는 다음과 같은 형태가 됨

targets = [
    { "boxes": Tensor[N1,4], "labels": Tensor[N1], ... },
    { "boxes": Tensor[N2,4], "labels": Tensor[N2], ... },
    ...
]

targets는 dict가 아니라 dict들의 list 임.

 

이 방식은 샘플마다 boxes의 개수가 달라도(= Tensor의 첫 차원이 달라도)
그대로 리스트 구조를 유지하므로,
default_collatestack 제약을 피하면서
detection의 “가변 길이(annotation 수)” 특성 을 자연스럽게 처리할 수 있음.

또한 이는 torchvision.models.detection 계열 모델들이 공식적으로 기대하는 입력 형식이기도 함.


참고: default_collate의 입력과 출력(classification 기준):

default_collate는 classification task를 기준으로 설계되어 있으며,
모든 샘플의 tensor shape가 동일하다는 가정을 전제로 동작함.

2025.04.26 - [Python] - [DL] default collate_fn - PyTorch

 

[DL] default collate_fn - PyTorch

torch.tensor(batch)collate_fn=None 의 collate function 동작PyTorch에서 DataLoader에서 collate_fn=None인 경우, torch.utils.data._utils.collate.default_collate 함수가 기본으로 사용됨.2024.04.09 - [Python] - [PyTorch] Dataset and DataLoade

ds31x.tistory.com


default_collate 입력

Dataset.__getitem__(idx)(image, label) 형태를 반환하는 경우,

batch = [(img1, label1), (img2, label2), ...]
  • img: 보통 동일한 크기의 Tensor[C, H, W]
  • label: int 또는 Tensor[] (scalar)
  • 모든 샘플에 대해 shape가 동일하다는 것이 암묵적으로 가정이 경우에는 에러가 발생하지 않음.

그러나 일반적인 detection task에서 모든 이미지가 동일한 개수의 object를 가진다고 기대하는 것은 현실적으로 어렵기 때문에,
이러한 방식은 일반적인 detection 데이터셋에는 적합하지 않음.


default_collate 출력

default_collate는 내부적으로 torch.stack을 사용하여 각 요소를 하나의 배치 텐서로 묶음.

images, labels = default_collate(batch)

결과는 다음과 같음:

  • images: Tensor[Batch, C, H, W]
  • labels: Tensor[Batch] (또는 Tensor[Batch, ...])

즉,

images = torch.stack([img1, img2, ...], dim=0)
labels = torch.stack([label1, label2, ...], dim=0)

이 방식은 모든 샘플의 텐서 크기가 동일할 때만 동작하며,
boxes처럼 샘플마다 개수가 달라지는 detection annotation에는 적합하지 않음


Example

관련 gist

https://gist.github.com/dsaint31x/06177dacac78c7faeccc334711c16cac

 

dl_detection_collate_fp.ipynb

dl_detection_collate_fp.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com

 

 

2개의 image로 batch를 구성할 때, 각 이미지 내의 object의 숫자가 다른 경우를 예제코드는 다음과 같음:

import torch
from torch.utils.data._utils.collate import default_collate

# batch 안에 들어갈 샘플 2개
sample1 = {
    "image": torch.randn(3, 224, 224),
    "boxes": torch.tensor([[0, 0, 10, 10],
                            [20, 20, 30, 30]], dtype=torch.float32),  # 2개
    "labels": torch.tensor([1, 2])
}

sample2 = {
    "image": torch.randn(3, 224, 224),
    "boxes": torch.tensor([[5, 5, 15, 15]], dtype=torch.float32),    # 1개
    "labels": torch.tensor([3])
}

batch = [sample1, sample2]

# default_collate 호출
default_collate(batch)

발생하는 에러는 전형적으로 다음과 같은 메시지임:

RuntimeError: stack expects each tensor to be equal size, but got
[2, 4] at entry 0 and [1, 4] at entry 1

이는 default_collatedict 객체(target)의 각 key별로 요소들을 모은 후, torch.stack 을 사용하여 묶으려고 하기 때문임.

바운딩 박스에 해당하는 boxes 에 대해 내부적 처리는 대략 다음과 같음:

torch.stack([
    sample1["boxes"],  # shape: [2, 4]
    sample2["boxes"],  # shape: [1, 4]
])

문제는 torch.stack은 합치려는 모든 Tensor 객체의 shape가 같아야 하기 때문에 에러가 발생함.


만약, 모든 sample 들 각각이 가지는 objects의 수가 다음과 같이 같다고 하자.

sample1 = {
    "image": torch.randn(3, 224, 224),
    "boxes": torch.tensor([[0, 0, 10, 10]], dtype=torch.float32),  # 1개
    "labels": torch.tensor([1])
}

sample2 = {
    "image": torch.randn(3, 224, 224),
    "boxes": torch.tensor([[5, 5, 15, 15]], dtype=torch.float32),  # 1개
    "labels": torch.tensor([2])
}

batch = [sample1, sample2]

out = default_collate(batch)

print(out["boxes"].shape)   # torch.Size([2, 1, 4])
print(out["labels"].shape)  # torch.Size([2, 1])

이 경우엔 문제가 발생하지 않지만, 일반적인 detection task에서 기대하기는 어려운 상황임.


같이보면 좋은 자료들

2024.04.09 - [Python] - [PyTorch] Dataset and DataLoader

 

[PyTorch] Dataset and DataLoader

1. Dataset 이란 : PyTorch 의 tensor 와학습에 사용될 일반 raw data (흔히, storage에 저장된 파일들) 사이에 위치하며,raw-data로부터 PyTorch의 module 객체 등이 접근가능한 데이터 셋을 추상화한 객체를 얻게

ds31x.tistory.com

 

728x90