본문 바로가기
목차
Python

torchvision.datasets.ImageFolder 사용하기.

by ds31x 2025. 6. 17.
728x90
반응형

torchvision.datasets.ImageFolder는 PyTorch에서  Image Classification Task를 위한 Dataset을 쉽게 구성할 수 있게 해주는 클래스임.

 

정확한 Full qualified name은 torchvision.datasets.folder.ImageFolder 임: torchvision.datasets 에 reexport되어 있음.

torch.utils.data.Dataset
        ↑
torchvision.datasets.vision.VisionDataset
        ↑
torchvision.datasets.folder.ImageFolder

 

original API documentation:

https://docs.pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html

 

ImageFolder — Torchvision 0.22 documentation

Shortcuts

docs.pytorch.org

 

관련 gist:

https://gist.github.com/dsaint31x/8dfdf1b4d133a47861a332235e00f033

 

dl_torchvision_datasets_imagefolder.ipynb

dl_torchvision_datasets_imagefolder.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com


Signature

class torchvision.datasets.ImageFolder(
    root: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    loader: Callable[[str], Any] = default_loader,
    is_valid_file: Optional[Callable[[str], bool]] = None
)
  • root (string): 데이터셋의 루트 디렉토리 경로
  • transform (callable, optional): 이미지에 적용할 변환 함수
  • target_transform (callable, optional): 타겟(레이블)에 적용할 변환 함수
  • loader (callable): 이미지 파일을 로드하는 함수 (기본값: PIL.Image.open())
  • is_valid_file (callable, optional): 파일이 유효한지 확인하는 함수

이는 classification을 위한 Dataset을 반환함.


Directory Structure for Classification

root 디렉토리 (아래 그림에선 ./ds_cls 임)의 구조는 다음과 같음:

이 경우, train과 test를 위한 set가 존재하고, 각 4개의 class로 나뉨.


ImageFolder 객체의 주요 속성.

주요 attributes 와 methods는 다음과 같음:

 

attrubutes

  • meta information
    • classes: 클래스 이름 리스트
    • class_to_idx: 클래스 이름을 인덱스로 매핑하는 딕셔너리
  • data관련
    • imgs: (이미지 경로, 클래스 인덱스) 튜플의 리스트. 과거의 samples 라고 보면 됨.
    • samples: imgs와 동일. Dataset 의 핵심 데이터 구조임.
    • targets: 모든 샘플의 타겟(class_index) 리스트
  • Loading과 Transform 관련
    • loader : image가 있는 path를 접근 가능한 객체로 (PIL Image / Tensor 등)
    • transform : 이미지에 적용되는 transform 객체
    • target_transform : 타겟(lable)에 적용되는 transform 객체.
      • mask 나 boundbox에 적용하는 geometric transform이 아닌 one-hot encoding으로 변환 등을 수행.
      • mask, bbox, keypoint 에 대한 transform은 torchvision.transforms.v2 의 Transform들을 이용한다.
      • detection / segmentation 용으로 사용하기 위해 __getitem__(idx)를 오버라이딩 하는 경우엔 지정하지 않는 것이 좋음.
    • is_valid_file : 파일 필터링 용 함수. ImageFolder 객체 생성되는 과정에 다음과 같이 호출됨. is_valid_file(path) == True 인 경우만 samples 에 추가됨.

methods

  • __len__() : 전체 sample의 수를 반환.
  • __add__(): dataset 객체 간의 concatenation으로 사용됨.
  • __getitem__(idx)
    • dataset[idx] 를 사용할 경우, 호출되는 method임.
    • 내부적으로 다음의 호출이 이루어짐.
      • path, target = self.sample[idx] :  path는 파일 경로, target은 class_to_idx 로 매겨진 정수 Label
      • sample = self.loader(path) : path를 로딩. PIL의 Image로 반환되는것이 기본이나, Tensor를 반환하도록 loader를 바꿀수도 있음.
      • 이후 transform이 설정된 경우, sample = self.transform(sample) 이 호출됨.
        • transforms.v2 의 composit input (복합입력)을 사용하는 경우는 Tensor객체를 사용해야함.
        • 초기 변환은 PIL의 image도 되지만, v2.ToDType 이나 v2.Normalize, 그리고 composit input으로 boxex, masks, keypoints를 사용한다면  transform에서 v2.ToImage()로 tensor로 바꾸고 나서 사용해야함.
      • 이후 target_transform이 설정된 경우, target = self.target_ransform(sample) 이 호출됨.
      • 이후 sample 과 target을 Tuple로 묶어서 반환.
    •  ImageFolder를 Detection이나 Segmentation으로 확장해서 쓸 경우, __getitem__(idx)를 오버라이딩하여 사용.
      • 다음으로 over-riding:
        • def __getitem__(self, idx: int) -> Tuple[Any, Dict[str, Any]]
      • classification의 경우는 다음임: def __getitem__(self, idx: int) -> Tuple[Any, Dict[str, int]

 

참고로, ImageFolder는 transforms를 지원하지 않음.

  • classification은 label 에 image에 가해진 변환이 같이 가해질 필요가 없음. 
  • 즉, composite input을 사용하지 않음.
  • v2의 composite input을 사용한다면, 다음과 같은 Transform객체를 transforms에 적용하고,
  • wrapper나 커스텀 Dataset에서 __getitem__에서 (image, target)을 명시적으로 처리해야 한다
v2.Compose([
    v2.ToImage(),                 # PIL → Tensor (tv_tensors.Image)
    v2.ToDtype(torch.float32),
    v2.RandomHorizontalFlip(),    # composite input 동기화
])

 


예제


from torchvision.datasets import ImageFolder
from pathlib import Path

t_dir = './'
dataset  = 'ds_cls'

dataset_path = Path(t_dir) / Path(dataset)

train_dataset_path = dataset_path / Path('train')

x_dataset = ImageFolder(
    root = train_dataset_path,
)

print(f"{type(x_dataset)=}")

print(f"{len(x_dataset) = }")
print(f"{x_dataset.classes      = }")
print(f"{x_dataset.class_to_idx = }")

# Results
# type(x_dataset)=<class 'torchvision.datasets.folder.ImageFolder'>
# len(x_dataset) = 12
# x_dataset.classes      = ['EOSINOPHIL', 'LYMPHOCYTE', 'MONOCYTE', 'NEUTROPHIL']
# x_dataset.class_to_idx = {'EOSINOPHIL': 0, 'LYMPHOCYTE': 1, 'MONOCYTE': 2, 'NEUTROPHIL': 3}

# ------------------------
from PIL import Image

img, label = x_dataset[0]

print(f"이미지 타입: {type(img)}")

if isinstance(img, Image.Image):
    print(f"이미지 모드: {img.mode}")
    print(f"이미지 크기: {img.size}")
    print(f"밴드 수: {len(img.getbands())}")
    
# Result
# 이미지 타입: <class 'PIL.Image.Image'>
# 이미지 모드: RGB
# 이미지 크기: (640, 480)
# 밴드 수: 3

# ------------------------
import matplotlib.pyplot as plt
import numpy as np

plt.figure(figsize=(3,3))
plt.title(x_dataset.classes[x_dataset[0][1]])
plt.imshow(np.array(x_dataset[0][0]))

 

다음은 간단한 테스트를 위한 파일들과 디렉토리 구조를 만들어서 압축한 파일임.

test.zip
0.33MB


같이 보면 좋은 자료들

2024.04.09 - [Python] - [PyTorch] Dataset and DataLoader

 

[PyTorch] Dataset and DataLoader

Dataset 이란PyTorch 의 tensor 와학습에 사용될 일반 raw data (흔히, storage에 저장된 파일들) 사이에 위치하며,raw-data로부터 PyTorch의 module 객체 등이 접근가능한 데이터 셋을 추상화한 객체를 얻게 해주

ds31x.tistory.com

2025.06.17 - [Python] - [torchvision] transforms.v2, transforms.v2.functional, 그리고 kernel

 

[torchvision] transforms.v2, transforms.v2.functional, 그리고 kernel

torchvision.transforms는 PyTorch에서 제공하는 이미지 preprocessing 및 data augmentation을 위한 module. 이 모듈은 현재 v2 서브모듈의 사용을 권함:v2 transforms는 image뿐만 아니라bounding boxes, masks, videos도 변환할

ds31x.tistory.com


 

728x90