torchvision.datasets.ImageFolder는 PyTorch에서 Image Classification Task를 위한 Dataset을 쉽게 구성할 수 있게 해주는 클래스임.
정확한 Full qualified name은 torchvision.datasets.folder.ImageFolder 임: torchvision.datasets 에 reexport되어 있음.
torch.utils.data.Dataset
↑
torchvision.datasets.vision.VisionDataset
↑
torchvision.datasets.folder.ImageFolder
original API documentation:
https://docs.pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html
ImageFolder — Torchvision 0.22 documentation
Shortcuts
docs.pytorch.org
관련 gist:
https://gist.github.com/dsaint31x/8dfdf1b4d133a47861a332235e00f033
dl_torchvision_datasets_imagefolder.ipynb
dl_torchvision_datasets_imagefolder.ipynb. GitHub Gist: instantly share code, notes, and snippets.
gist.github.com
Signature
class torchvision.datasets.ImageFolder(
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None
)
root(string): 데이터셋의 루트 디렉토리 경로transform(callable, optional): 이미지에 적용할 변환 함수target_transform(callable, optional): 타겟(레이블)에 적용할 변환 함수loader(callable): 이미지 파일을 로드하는 함수 (기본값: PIL.Image.open())is_valid_file(callable, optional): 파일이 유효한지 확인하는 함수
이는 classification을 위한 Dataset을 반환함.
Directory Structure for Classification
root 디렉토리 (아래 그림에선 ./ds_cls 임)의 구조는 다음과 같음:

이 경우, train과 test를 위한 set가 존재하고, 각 4개의 class로 나뉨.
ImageFolder 객체의 주요 속성.
주요 attributes 와 methods는 다음과 같음:
attrubutes
- meta information
classes: 클래스 이름 리스트class_to_idx: 클래스 이름을 인덱스로 매핑하는 딕셔너리
- data관련
imgs: (이미지 경로, 클래스 인덱스) 튜플의 리스트. 과거의 samples 라고 보면 됨.samples:imgs와 동일. Dataset 의 핵심 데이터 구조임.targets: 모든 샘플의 타겟(class_index) 리스트
- Loading과 Transform 관련
loader: image가 있는 path를 접근 가능한 객체로 (PIL Image / Tensor 등)transform: 이미지에 적용되는 transform 객체target_transform: 타겟(lable)에 적용되는 transform 객체.- mask 나 boundbox에 적용하는 geometric transform이 아닌 one-hot encoding으로 변환 등을 수행.
- mask, bbox, keypoint 에 대한 transform은 torchvision.transforms.v2 의 Transform들을 이용한다.
- detection / segmentation 용으로 사용하기 위해 __getitem__(idx)를 오버라이딩 하는 경우엔 지정하지 않는 것이 좋음.
is_valid_file: 파일 필터링 용 함수. ImageFolder 객체 생성되는 과정에 다음과 같이 호출됨. is_valid_file(path) == True 인 경우만 samples 에 추가됨.
methods
- __len__() : 전체 sample의 수를 반환.
- __add__(): dataset 객체 간의 concatenation으로 사용됨.
- __getitem__(idx)
- dataset[idx] 를 사용할 경우, 호출되는 method임.
- 내부적으로 다음의 호출이 이루어짐.
- path, target = self.sample[idx] : path는 파일 경로, target은 class_to_idx 로 매겨진 정수 Label
- sample = self.loader(path) : path를 로딩. PIL의 Image로 반환되는것이 기본이나, Tensor를 반환하도록 loader를 바꿀수도 있음.
- 이후 transform이 설정된 경우, sample = self.transform(sample) 이 호출됨.
- transforms.v2 의 composit input (복합입력)을 사용하는 경우는 Tensor객체를 사용해야함.
- 초기 변환은 PIL의 image도 되지만, v2.ToDType 이나 v2.Normalize, 그리고 composit input으로 boxex, masks, keypoints를 사용한다면 transform에서 v2.ToImage()로 tensor로 바꾸고 나서 사용해야함.
- 이후 target_transform이 설정된 경우, target = self.target_ransform(sample) 이 호출됨.
- 이후 sample 과 target을 Tuple로 묶어서 반환.
- ImageFolder를 Detection이나 Segmentation으로 확장해서 쓸 경우, __getitem__(idx)를 오버라이딩하여 사용.
- 다음으로 over-riding:
def __getitem__(self, idx: int) -> Tuple[Any, Dict[str, Any]]
- classification의 경우는 다음임:
def __getitem__(self, idx: int) -> Tuple[Any, Dict[str, int]
- 다음으로 over-riding:
참고로, ImageFolder는 transforms를 지원하지 않음.
- classification은 label 에 image에 가해진 변환이 같이 가해질 필요가 없음.
- 즉, composite input을 사용하지 않음.
- v2의 composite input을 사용한다면, 다음과 같은 Transform객체를 transforms에 적용하고,
- wrapper나 커스텀 Dataset에서 __getitem__에서 (image, target)을 명시적으로 처리해야 한다
v2.Compose([
v2.ToImage(), # PIL → Tensor (tv_tensors.Image)
v2.ToDtype(torch.float32),
v2.RandomHorizontalFlip(), # composite input 동기화
])
예제
from torchvision.datasets import ImageFolder
from pathlib import Path
t_dir = './'
dataset = 'ds_cls'
dataset_path = Path(t_dir) / Path(dataset)
train_dataset_path = dataset_path / Path('train')
x_dataset = ImageFolder(
root = train_dataset_path,
)
print(f"{type(x_dataset)=}")
print(f"{len(x_dataset) = }")
print(f"{x_dataset.classes = }")
print(f"{x_dataset.class_to_idx = }")
# Results
# type(x_dataset)=<class 'torchvision.datasets.folder.ImageFolder'>
# len(x_dataset) = 12
# x_dataset.classes = ['EOSINOPHIL', 'LYMPHOCYTE', 'MONOCYTE', 'NEUTROPHIL']
# x_dataset.class_to_idx = {'EOSINOPHIL': 0, 'LYMPHOCYTE': 1, 'MONOCYTE': 2, 'NEUTROPHIL': 3}
# ------------------------
from PIL import Image
img, label = x_dataset[0]
print(f"이미지 타입: {type(img)}")
if isinstance(img, Image.Image):
print(f"이미지 모드: {img.mode}")
print(f"이미지 크기: {img.size}")
print(f"밴드 수: {len(img.getbands())}")
# Result
# 이미지 타입: <class 'PIL.Image.Image'>
# 이미지 모드: RGB
# 이미지 크기: (640, 480)
# 밴드 수: 3
# ------------------------
import matplotlib.pyplot as plt
import numpy as np
plt.figure(figsize=(3,3))
plt.title(x_dataset.classes[x_dataset[0][1]])
plt.imshow(np.array(x_dataset[0][0]))

다음은 간단한 테스트를 위한 파일들과 디렉토리 구조를 만들어서 압축한 파일임.
같이 보면 좋은 자료들
2024.04.09 - [Python] - [PyTorch] Dataset and DataLoader
[PyTorch] Dataset and DataLoader
Dataset 이란PyTorch 의 tensor 와학습에 사용될 일반 raw data (흔히, storage에 저장된 파일들) 사이에 위치하며,raw-data로부터 PyTorch의 module 객체 등이 접근가능한 데이터 셋을 추상화한 객체를 얻게 해주
ds31x.tistory.com
2025.06.17 - [Python] - [torchvision] transforms.v2, transforms.v2.functional, 그리고 kernel
[torchvision] transforms.v2, transforms.v2.functional, 그리고 kernel
torchvision.transforms는 PyTorch에서 제공하는 이미지 preprocessing 및 data augmentation을 위한 module. 이 모듈은 현재 v2 서브모듈의 사용을 권함:v2 transforms는 image뿐만 아니라bounding boxes, masks, videos도 변환할
ds31x.tistory.com
'Python' 카테고리의 다른 글
| Pillow 사용법 - Basic 01 (2) | 2025.06.30 |
|---|---|
| [torchvision] torchvision.utils.save_image and torchvision.io.encode_jpeg, torchvision.io.encode_png (0) | 2025.06.17 |
| [torchvision] image로부터 torch.tensor 객체 얻기 (0) | 2025.06.17 |
| [torchvision] transforms.v2, transforms.v2.functional, 그리고 kernel (0) | 2025.06.17 |
| [PyTorch] Conversion-Torchvision.transfroms.v2 (0) | 2025.06.16 |