본문 바로가기
목차
Python

[PyTorch] torchvision.transforms 사용법 - transforms란?

by ds31x 2025. 1. 12.
728x90
반응형

PyTorch의 torchvision.transforms:
이미지 전처리와 데이터 증강을 위한 도구

torchvision.transforms

  • PyTorch에서 제공하는 이미지 전처리 data augmentation을 위한 module.
  • 이 모듈은 이미지 데이터를 이용한 딥러닝 모델의 학습 효율을 높이고 데이터 준비 과정을 단순화하는 데 사용됨.

현재는 torchvision.transforms.v2 를 대신 사용하는 것이 권장됨(torchvision 0.15가 공개된 2023년 3월 이후):

2025.06.17 - [Python] - [torchvision] transforms.v2, transforms.v2.functional, 그리고 kernel

 

[torchvision] transforms.v2, transforms.v2.functional, 그리고 kernel

torchvision.transforms는 PyTorch에서 제공하는 이미지 preprocessing 및 data augmentation을 위한 module. 이 모듈은 현재 v2 서브모듈의 사용을 권함:v2 transforms는 image뿐만 아니라bounding boxes, masks, videos도 변환할

ds31x.tistory.com

 


0. PyTorch에서 transform이란?

transform은

  • 개별 sample instance 단위로 적용이 되며,
  • Dataset 체의 __getitem__ 메서드 내부에서 호출되어 사용된다.
일반적으로, Dataset 의 생성자에
transform (or transforms) 파라미터 명으로 넘겨지며,
각 sample instance 단위로 적용됨.

 

DataLoadercollate_fn 에 설정된 "collate function"batch 단위 로 적용되는 것과 차이가 있음을 명심할 것.

 

PyTorch에서

  • Transform은 데이터 인스턴스를 입력받아, 이에 변환을 가한 데이터 인스턴스를 반환하는 일종의 callable 객체이며,
  • Dataset 객체의 __getitem__() 인스턴스 메서드에서 호출된다.
  • 최소한 __call__() 메서드만 지원하면 custom transform로 사용가능함.
transform은
어떤 형태의 callable 객체든 사용할 수 있음.
하지만, torch.nn.Module을 상속하여 구현하는 것이 최근엔 보다 권장됨! 

 

다음은 많이 사용되는 transforms 관련 모듈임.

  • torchvision.transforms
    • 이미지 처리에 특화된 transform 제공
    • PIL Image나 Tensor 이미지용
  • torchaudio.transforms
    • 오디오 처리용 transform
    • Spectrogram, MFCC 등
  • torchtext.transforms
    • 텍스트 처리용 transform
    • 토크나이징, 인코딩 등
PyTorch Domain Libraries
참고로, torchvision, torchaudio, torchtext
PyTorch Domain Libraries (간단하게 PyTorch Libraries라고도 부름)로서 특정 분야에 특화된 라이브러리임.
최근에 torchdata(데이터 로딩), torchserve(모델 서빙), torchrec(추천시스템) 등의 더 많은 domain libraries가 추가되었음.

 


1. torchvision.transforms란 무엇인가?

torchvision.transforms의 주요 역할:

  • 이미지 전처리: 크기 조정, 자르기, 회전 등등.
  • Data Augmentation (데이터 증강)
  • 전처리 파이프라인 구축: 여러 변환을 조합하여 효율적인 데이터 전처리 루틴(각각의 instance 단위로 적용)을 생성.

 

callable 객체를 위한 Transform 클래스들(torch.nn.Module의 subclass들)을 사용하는 방식이 일반적으로 권장되나,
함수 형태로 제공되는 torchvision.transforms.v2.functional 모듈을 사용하는 경우도 많다.

이는 torch.nn 모듈에 정의된 클래스를 이용하여 layer(torch용어로는 module)를 사용하는 방법과
torch.nn.functional
모듈의 함수들로 이용하는 방법이 존재하는 것과 유사함.

1-1. 지원하는 input type

  • PyTorch의 Tensor 객체 (torchvision.tv_tensors 모듈의 TVTensor객체 포함)
  • PIL의 Image객체
  • NumPy의 ndarray객체

performance(성능)을 위해선, PyTorch의 Tensor 객체를 사용하는 것이 권장됨.

이들간의 conversion을 위한 conversion transform들을 지원함.

주로v2.ToImagev2.ToDtype 로 구성된 Compose 사용하여 conversion을 수행.


1-2. Expected Value Range

  • PyTorch의 Tensor 객체는 torch.float32 (GPU지원 위해)의 [0.0, 1.0]을 기본으로 사용함.
  • 하지만, raw image의 경우, torch.uint8의 [0,255]가 기본임.
  • value range 변환도 주로 v2.ToDtype를 사용

1-3. Shape Convention.

  • Tensor Image는 Channels, Height, Width 의 shape를 가정함.
  • Batch Size 를 N으로 표시시 (N,C,H,W)를 사용.
  • v2의 경우엔 C,H,W 앞에 다양한 leading dimensions가 놓일 수 있음: (..., C, H, W)

2. 주요 Transform 클래스와 사용법

 

기존의 torchvision.transforms 대신에
torchvision.transforms.v2 를 사용하는 것이
2023년 3월에 릴리즈 된 Torchvision 0.15 이후부터 권장된다.

 

2-1. 이미지 크기 조정

transforms.Resize:

  • 이미지를 지정된 크기로 변경.
  • transforms.v2.Resize 로 대체하여 사용하는 것을 권장함.
# 목표 크기 (height, width)를 지정: aspect ratio가 변경될 수 있음.
transform = transforms.Resize((128, 128))

# callable객체이므로 다음과 같이 함수처럼 호출하여 사용.
resized_image = transform(image)

transforms.CenterCrop:

  • 중앙을 기준으로 이미지를 cropping (자름).
  • transforms.v2.CenterCrop 으로 대체하여 사용하는 것을 권장함.
# 이미지 중앙에서 지정된 크기 (height, width)만큼 잘라내기: 
# - aspect ratio 유지됨.
# - crop 크기가 원본보다 클 경우 패딩 추가
transform = transforms.CenterCrop((100, 100))

# callable객체이므로 다음과 같이 함수처럼 호출하여 사용.
cropped_image = transform(image)

transforms.RandomCrop:

  • 임의의 위치에서 이미지를 자름.
  • padding 으로 crop 전에 padding을 수행 가능.
  • transforms.v2.RandomCrop 으로 대체하여 사용하는 것을 권장함.
# 이미지에서 랜덤한 위치에서 지정된 크기 (height, width)만큼 잘라내기: 
# - Data Augmentation 으로 사용됨.
# - padding=4: 원본 이미지 상하좌우에 4픽셀씩 패딩 추가 후 크롭
# - 매번 호출할 때마다 다른 랜덤 위치에서 crop 수행
transform = transforms.RandomCrop((100, 100), padding=4)

# callable객체이므로 다음과 같이 함수처럼 호출하여 사용.
random_cropped_image = transform(image)

2-2. 텐서 변환 (Conversion)

transforms.ToTensor: (현재는 새로 작성하는 코드에선 사용치 않는게 권장됨)

  • 이미지를 PyTorch 텐서로 변환.
  • 픽셀 값을 [0.0, 1.0]normalization(정규화):
    • 입력 데이터가  Pillow 의 Image 객체인 경우 수행됨.
    • NumPy의 ndarray이면서 float 형이 아닐 경우 수행됨.
  • 입력데이터가 Pillow 의 Image 객체인 경우,
    • 입력이 Width, Height, Channel 이라고 가정하고
    • 출력은 Channel, Width, Height로 변경됨.
transform = transforms.ToTensor()
tensor_image = transform(image)

 

ToTensordeprecated 된 상태나 다름없음 (너무 많은 기능을 가지고 있음). 
v2.ToImagev2.ToDtype 구성된 v2.Compose 사용이 권장된다 (v0.16 이후).

 

다음의 코드를 참고:

# 대체 방법 (최신)
v2.Compose([
    v2.ToImage(),                           # PIL/numpy → tensor 변환
    v2.ToDtype(torch.float32, scale=True)   # [0,255] → [0.0,1.0] float 변환 및 스케일링
])

transforms.Normalize:

  • 텐서를 normalization하기 위해 standardization을 수행.
  • 각 채널에 대해 (x - mean) / std를 적용.
  • v2.Normalize 를 권장.
transform = transforms.Normalize(
    mean=[0.5, 0.5, 0.5], 
    std=[0.5, 0.5, 0.5],
)
normalized_image = transform(tensor_image)

2-3. Data Augmentation (데이터 증강)

transforms.RandomHorizontalFlip:

  • 이미지 좌우 반전(Horizontal Flip)을 확률적으로 적용.
  • v2.RandomHorizontalFlip 으로 대체를 권함.
# 이미지를 수평(좌우)으로 stochastic하게 뒤집기: Data augmentation.
# - p=0.5: 50% 확률로 좌우 반전 수행
# 참고:
# - p=0.0: 절대 뒤집지 않음 (항상 원본)
# - p=1.0: 항상 뒤집음 (deterministic)
# - p=0.5: 절반 확률로 뒤집음 (일반적인 설정)
transform = transforms.RandomHorizontalFlip(p=0.5)

# callable객체이므로 다음과 같이 함수처럼 호출하여 사용.
flipped_image = transform(image)

# ----------
# 권장 방식.
from torchvision.transforms import v2

transform = v2.RandomHorizontalFlip(p=0.5)
flipped_image = transform(image)

transforms.RandomRotation:

  • 이미지를 랜덤 각도로 회전.
    • 회전 시 빈 공간은 기본적으로 검은색(0)으로 채움
    • 이미지 크기는 유지되므로 모서리 부분이 잘릴 수 있음
  • v2.RandomRotation 으로 대체를 권함.
# 이미지를 랜덤한 각도로 회전: 데이터 증강용으로 기하학적 변형 제공.
# - degrees=30: -30도에서 +30도 사이의 랜덤한 각도로 회전
# - degrees=(15, 45): 최소 15도, 최대 45도 범위 지정 가능
# - degrees=(-90, 90): 음수/양수로 시계/반시계 방향 제어
transform = transforms.RandomRotation(degrees=30)

# callable객체이므로 다음과 같이 함수처럼 호출하여 사용.
rotated_image = transform(image)

 

다음은 v2.RandomRotation의 사용예임.

from torchvision.transforms import v2
import torch

# 기본 사용법
transform = v2.RandomRotation(
    degrees=30,                    # ±30도 범위
    interpolation=v2.InterpolationMode.BILINEAR,  # 보간법
    expand=False,                  # 이미지 크기 확장 여부
    center=None,                   # 회전 중심점
    fill=0                         # 빈 공간 채울 값/색상
)

transforms.ColorJitter:

  • brightness(밝기), contrast(대비), saturation(채도), hue(색상)를 랜덤하게 조정.
  • Parameters: 0에서 1 사이의 값을 사용하며, 지정된 범위 내에서 무작위로 변환이 적용됨.
    • brightness: 밝기 조정 범위 (예: 0.2는 원본 대비 0.8~1.2배 사이의 밝기로 변환)
    • contrast: 대비 조정 범위
    • saturation: 채도 조정 범위
    • hue: 색상 조정 범위
# color 속성을 랜덤하게 변경하는 data augmentation: 
transform = transforms.ColorJitter(brightness=0.2, contrast=0.3)

# callable 객체이므로 다음과 같이 함수처럼 호출하여 사용
jittered_image = transform(image)
  • brightness=0.2: 밝기를 원본의 80%~120% 범위에서 랜덤 조정, 기본값 0.
  • contrast=0.3: 대비를 원본의 70%~130% 범위에서 랜덤 조정, 기본값 0.
  • saturation=0.1: 채도를 원본의 90%~110% 범위에서 랜덤 조정, 기본값 0.
  • hue=0.1 : 색상을 원본에 대해 [-0.1, 0.1] 범위 ( ±36도 내 회전)에서 랜덤 조정, 기본값 0.

v2.ColorJitter 으로 대체한 코드는 다음과 같음:

# v2 버전으로 대체
from torchvision.transforms import v2

transform = v2.ColorJitter(brightness=0.2, contrast=0.3)
jittered_image = transform(image)

2-4. 기타

transforms.Grayscale:

  • 이미지를 Gray Scale (L)로 변환.
  • v2.Lambda로 대체하는 것을 권장함.
transform = transforms.Grayscale(num_output_channels=1)
gray_image = transform(image)

 

v2로 대체한 경우는 다음과 같음:

import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F


# ------------
# 단채널로...
def grayscale_bt601(x: torch.Tensor) -> torch.Tensor:
    # x: [3, H, W], float tensor
    r, g, b = x[0], x[1], x[2]
    y = 0.299 * r + 0.587 * g + 0.114 * b
    return y.unsqueeze(0)  # [1, H, W]

transform = v2.Compose([
    v2.ToImage(), # 이후 tensor로 
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(grayscale_bt601),
])


# # 보다 나은 구현
# # 아래의 rgb_to_grayscale은 Tensor와 tv_tensor.Image 만 지원(PIL 의 이미지 지원X) 
# transform = v2.Compose([
#     v2.ToImage(),
#     v2.ToDtype(torch.float32, scale=True),
#     v2.Lambda(lambda x: v2.functional.rgb_to_grayscale(x, 1)),
# ])

# # 더 나은 구현
# class ToGray(v2.Transform):
#     def forward(self, img):
#         return F.rgb_to_grayscale(img, num_output_channels=1)
# transform = v2.Compose([
#     v2.ToImage(),
#     v2.ToDtype(torch.float32, scale=True),
#     ToGray(),
# ])        

gray_image = transform(image)

# -----------
# 채널 유지.
def grayscale_bt601_3ch(x):
    r, g, b = x[0], x[1], x[2]
    y = 0.299 * r + 0.587 * g + 0.114 * b
    return y.unsqueeze(0).repeat(3, 1, 1)

transforms.Lambda:

  • 사용자 정의 변환을 구현한 lambda expression을 적용.
  • lambda를 사용한 익명함수를 transforms.Lambda에서 사용시 다음의 이슈가 있음을 주의할 것:
    • Multiprocessing 의 경우 문제가 발생: DataLoader에서 num_workers를 1개 이상 사용 불가.
    • 이는 pickle로 serialization이 안되기 때문임.
    • 익명함수는 prototyping 등의 테스트에만 사용할 것.
  • 가급적 이름을 가지는 일반 함수를 사용할 것.
# 익명함수를 transform으로 wrapping
# PIL Image객체 img를 받아, ccw로 45도 회전처리.
transform = transforms.Lambda(lambda img: img.rotate(45))

rotated_image = transform(image)

 

v2.Lambda로 대체한 코드는 다음과 같음.

from torchvision.transforms import v2

# 기존 코드
# transform = transforms.Lambda(lambda img: img.rotate(45))

# v2 직접 변환
transform = v2.Lambda(lambda img: img.rotate(45))
rotated_image = transform(image)

 

익명함수 대신 named function (def 키워드 이용한 일반 함수)을 사용하는 것을 권함:
사실 이보다 더 좋은 것은 custom transform 클래스를 이용하는 것임.

# 함수형 접근 (transform 객체 생성 없이)
rotated_image = v2.functional.rotate(
    image, 
    angle=45,  # 반시계방향 45도
    interpolation=v2.InterpolationMode.BILINEAR,
    expand=False,  # 원본 크기 유지
    fill=0         # 빈 공간을 검은색으로 채움
)

# 함수로 래핑
def rotate_45_ccw(img):
    return v2.functional.rotate(img, angle=45)

# Lambda 대신 네임드 함수 사용
transform = v2.Lambda(rotate_45_ccw)  # 직렬화 가능

 

다음은 Lambda transform을 테스트하는 전체 코드임.

import PIL
from PIL import Image
import torchvision.transforms as transforms
import numpy as np

np_img = np.zeros((10, 10, 3), dtype=np.uint8)
np_img[5,:,:] = 255

img = Image.fromarray(np_img)

# # 샘플 이미지 불러오기
# img = Image.open("IMG_2826.jpeg")

# Lambda transform 정의 및 적용
transform = transforms.Lambda(
    lambda img: img.rotate(
        20,resample=PIL.Image.Resampling.BILINEAR
    )
)
rotated_image = transform(img)

# 결과 확인
plt.imshow(np.array(rotated_image))

 

https://dsaint31.tistory.com/236

 

PIL과 opencv에서의 image 변환.

PIL과 opencv에서의 image 변환.필요성tensorflow 나 pytorch등에서의 이미지 로딩의 경우,일반적으로, PIL.Image.Image를 기본적으로 이미지를 위한 class 타입으로 사용함.from tensorflow.keras.preprocessing import image

dsaint31.tistory.com

 

2023.07.07 - [Python] - [Python] lambda expression and map, filter, reduce.

 

[Python] lambda expression and map, filter, reduce.

Lambda expression (or Lambda Function, Anonymous Function)Python 에서 lambda function (or lambda expression)은 anonymous function(익명함수)를 만드는데 사용됨.function 형태로 code구현의 재사용을 해야하긴 하지만, def문을 이

ds31x.tistory.com


3. 변환 파이프라인 구축

torchvision.transforms.Compose를 통해 여러 transform 객체들을 연결하여 pipeline(파이프라인)을 만들 수 있음.

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

transformed_image = transform(image)

 

다음 url에서 pipeline의 개념을 다시 한번 볼 것:https://dsaint31.tistory.com/829

 

[ML] scikit-learn: Pipeline 사용법

Scikit-learn의 Pipeline은 여러 데이터 처리 과정을 하나로 묶어 효율적으로 실행할 수 있게 해주는 Class.(일반적인) Pipeline 이란?Pipeline은 일반적으로 (데이터) 처리 과정 또는 기계학습 등에서 “여러

dsaint31.tistory.com


4. Custom Transform 작성법

torchvision 0.15.0이 2023년 3월에 릴리스되면서 transforms.v2가 정식 도입되었고,
그 시점부터 image·mask·box·keypoint를 함께 다루는 변환에는
nn.Module보다 v2.Transform 상속이 권장
되고 있음

4-1. 클래스 기반 커스텀 변환

__call__ 메서드를 가진 클래스를 작성하여 새로운 transform을 정의 가능 (과거 방식).

class CustomTransform:
    def __call__(self, img):
        return img.point(lambda x: x * 1.5)  # 밝기를 1.5배 증가

 

보통 torch.nn.Module 을 상속하여 "custom transform 클래스를 정의"하는 것이 권해짐. (v2 도입 이전 방식)

  • (__call__() 메서드 또는) forward() 메서드에서 get_params() 메서드를 호출하여 변환에 필요한 파라미터들의 값을 얻음.
  • 즉, get_params() 메서드에서 변환에 필요한 파라미터들을 반환하고, 이들을 arguments로 이용하여 forward에서 변환이 되도록 수행함.
  • get_params() 메서드는 torchvision의 transform 구현에서 사용하는 helper 메서드torch.nn.Module의 공식 method는 아님.
더보기

https://ds31x.tistory.com/238

 

[PyTorch] Custom Model 과 torch.nn.Module의 메서드들.

Custom Model 만들기0. nn.Module torch.nn.Module은 PyTorch에서 모든 신경망 모델과 계층의 기반이 되는 클래스임.Custom Model (사용자 정의 모델)부터 Built-in Layer(nn.Linear, nn.Conv2d, etc.)까지 전부 nn.Module을 상속

ds31x.tistory.com

 

forward() 메서드 내에서 get_params() 를 호출하여 변환에 파라미터를 반환받는 방식이 일반적이지만...
__call__() 메서드  내에서 get_params() 를 호출하고, 이를 forward() 호출시 인자로 넘겨주도록 overriding도 가능.

 

get_params 패턴 도입에 따른 장점:

  • batch 처리 일관성: batch 전체에 동일한 파라미터 적용 가능함
    • v2에서 제공하는 transform들은 batch 차원을 포함한 tensor 객체 및 tv_tensors를 공식적으로 지원함.
  • 조건부 변환: 이미지 속성에 따른 적응적 파라미터 생성가능함.
    • v2.transform의 경우 make_params(flat_inputs) 를 통해 
    • 이미지의 크기, dtype, mask 존재 여부 등을 확인하고 이에 맞는 적응형 파라미커 생성 가능.
  • 파라미터 의존성: composite input(복합 입력) 변환에서 상호 연관된 파라미터 관리 가능.
    • rotation angle을 box 회전으로 연관시키거나
    • crop 영역과 resize 비율 등을 연관시키는 등의 처리 가능.
  • 디버깅 지원: 파라미터 로깅 및 재현 가능한 변환
  • 성능 최적화:
    • 파라미터 계산과 변환 로직 분리로 효율성 향상 
    • 중복 파라미터 계산을 제거하기 쉽고,composite input (image, mask, bbox, keypoints) 에 대해 동일 연산 재사용이 보다 쉬움.

보다 효율적인 부분을 위해선, v2.Transform을 상속하는게 나음

torchvision v0.15(Transform v2 도입) 이후부터,
composite input (image, mask, box, keypoint)에 대해 동일한 랜덤 파라미터를 자동으로 동기화하고
(tv_tensors, batch, dtype/shape 안전성까지 포함해) 프레임워크 차원에서 일관되게 처리할 수 있다는 점 에서
nn.Module 보다 v2.Transform 상속이 권장됨.

 

다음 코드는 get_params() 헬퍼 메서드를 도입한 CustomTransform 클래스를 만드는 간단한 예임:

import random
import torch
from torchvision.transforms import v2


class AdvancedCustomTransformImageOnly(torch.nn.Module):
    """ v2 transforms의 표준 패턴을 따르는 커스텀 변환
    image-only 간단 버전
    - nn.Module 표준 호출 유지
    - get_params()는 helper
    - forward()에서 params 생성
    """

    def __init__(self, brightness_range=(0.8, 1.2), rotation_range=(-30, 30), p_blur=0.5):        
        super().__init__()
        self.brightness_range = brightness_range
        self.rotation_range = rotation_range
        self.p_blur = p_blur

    def get_params(self):
        """
        변환에 필요한 랜덤 파라미터들을 미리 계산
        forward() 에서 호출됨
        """    
        return {
            "brightness_factor": random.uniform(*self.brightness_range),
            "rotation_angle": random.uniform(*self.rotation_range),
            "apply_blur": (random.random() < self.p_blur),
        }

    def forward(self, img):
        """
        실제 변환 수행 - get_params에서 생성된 파라미터 사용
        동일한 파라미터로 배치의 모든 이미지에 동일한 변환 적용 가능
        """
        params = self.get_params()

        img = v2.functional.adjust_brightness(img, params["brightness_factor"])
        img = v2.functional.rotate(img, angle=params["rotation_angle"])

        if params["apply_blur"]:
            img = v2.functional.gaussian_blur(img, kernel_size=5, sigma=1.0)

        return img

 

__call__() 에서 사용되는 방식 (비추천).

더보기
import torch
from torchvision.transforms import v2
import random

class AdvancedCustomTransform(torch.nn.Module):
    """v2 transforms의 표준 패턴을 따르는 커스텀 변환"""
    
    def __init__(self, brightness_range=(0.8, 1.2), rotation_range=(-30, 30), p_blur=0.5):
        super().__init__()
        self.brightness_range = brightness_range
        self.rotation_range = rotation_range
        self.apply_blur = p_blur
    
    def get_params(self):
        """
        변환에 필요한 랜덤 파라미터들을 미리 계산
        forward() 호출 전에 __call__에서 자동 호출됨
        """
        brightness_factor = random.uniform(*self.brightness_range)
        rotation_angle = random.uniform(*self.rotation_range)
        apply_blur = random.random() > 0.5
        
        return {
            'brightness_factor': brightness_factor,
            'rotation_angle': rotation_angle, 
            'apply_blur': apply_blur
        }
    
    def forward(self, img, **params):
        """
        실제 변환 수행 - get_params에서 생성된 파라미터 사용
        동일한 파라미터로 배치의 모든 이미지에 동일한 변환 적용 가능
        """
        # 1. 밝기 조정
        img = v2.functional.adjust_brightness(img, params['brightness_factor'])
        
        # 2. 회전
        img = v2.functional.rotate(img, angle=params['rotation_angle'])
        
        # 3. 조건부 블러
        if params['apply_blur']:
            img = v2.functional.gaussian_blur(img, kernel_size=5, sigma=1.0)
        
        return img
    
    def __call__(self, img):
        """
        표준 호출 인터페이스
        1. get_params() 호출하여 랜덤 파라미터 생성
        2. forward(img, **params) 호출하여 변환 수행
        """
        params = self.get_params()
        return self.forward(img, **params)

# 사용법
advanced_transform = AdvancedCustomTransform()

 

v2를 고려한 nn.Module을 상속하여 작성한 full code (단, torchvision은 v2.Transform 을 상속한 구현방식을 보다 권장)

더보기
import math
import random
from typing import Dict, Optional, Tuple

import torch
from torchvision.transforms.v2 import functional as F

try:
    # torchvision >= 0.15 Version 호환성 확인
    # tv_tensors 모듈을 통해 BoundingBoxes, Mask 등 전용 Data Type 지원 여부 체크
    from torchvision import tv_tensors
except Exception:
    tv_tensors = None


class AdvancedCustomTransform(torch.nn.Module):
    """
    Object Detection 및 Segmentation을 위한 Custom Data Augmentation 클래스.
    
    [Feature Description]
      - Image 변환 시 Target(Box, Mask, Keypoint)의 Geometric Synchronization(기하학적 동기화) 보장
      - Brightness, Blur, Rotation 등 다양한 Augmentation 지원
      - Bounding Box Rotation 시 AABB(Axis-Aligned Bounding Box) 재계산 로직 적용
    """

    def __init__(
        self,
        brightness_range: Tuple[float, float] = (0.8, 1.2), # Brightness 조절 비율 범위
        rotation_range: Tuple[float, float] = (-30, 30),    # Rotation Angle 범위 (Unit: Degree)
        p_blur: float = 0.5,                                # Blur 적용 Probability(확률)
        blur_kernel_size: int = 5,                          # Blur Kernel Size (홀수 권장)
        blur_sigma: float = 1.0,                            # Gaussian Blur의 Sigma(표준편차)
        default_box_format: str = "xyxy",                   # 기본 Box Format (xyxy, xywh, cxcywh)
    ):
        super().__init__()
        self.brightness_range = brightness_range
        self.rotation_range = rotation_range
        self.p_blur = p_blur
        self.blur_kernel_size = blur_kernel_size
        self.blur_sigma = blur_sigma
        self.default_box_format = default_box_format

    # ------------------------------------------------------------------
    # 1. Random Parameter Generation
    # ------------------------------------------------------------------
    def get_params(self) -> Dict[str, object]:
        """
        단일 Forward Pass 내에서 공유할 Random Parameter 생성.
        Image와 Target이 동일한 Angle과 Setting으로 변환되도록 Synchronization 보장.
        """
        return {
            "brightness_factor": random.uniform(*self.brightness_range),
            "rotation_angle": random.uniform(*self.rotation_range),
            "apply_blur": (random.random() < self.p_blur),
        }

    # ------------------------------------------------------------------
    # 2. Bounding Box Format Utility
    # ------------------------------------------------------------------
    @staticmethod
    def _boxes_to_xyxy(boxes: torch.Tensor, fmt: str) -> torch.Tensor:
        """
        입력된 Box Format을 연산에 용이한 XYXY Format으로 변환.
        """
        fmt = fmt.lower()
        if fmt == "xyxy":
            return boxes
        if fmt == "xywh":
            # xywh: x_min, y_min, width, height
            x, y, w, h = boxes.unbind(dim=-1)
            return torch.stack([x, y, x + w, y + h], dim=-1)
        if fmt == "cxcywh":
            # cxcywh: center_x, center_y, width, height
            cx, cy, w, h = boxes.unbind(dim=-1)
            return torch.stack(
                [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1
            )
        raise ValueError(f"지원하지 않는 Box Format: {fmt}")

    @staticmethod
    def _xyxy_to_boxes(xyxy: torch.Tensor, fmt: str) -> torch.Tensor:
        """
        XYXY Format을 다시 원본 Format으로 Restore(복원).
        """
        fmt = fmt.lower()
        x1, y1, x2, y2 = xyxy.unbind(dim=-1)
        if fmt == "xyxy":
            return xyxy
        if fmt == "xywh":
            return torch.stack([x1, y1, x2 - x1, y2 - y1], dim=-1)
        if fmt == "cxcywh":
            return torch.stack(
                [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], dim=-1
            )
        raise ValueError(f"지원하지 않는 Box Format: {fmt}")

    # ------------------------------------------------------------------
    # 3. Geometric Transformation Utility (Rotation)
    # ------------------------------------------------------------------
    @staticmethod
    def _rotate_points_xy(xy: torch.Tensor, angle_deg: float, W: int, H: int) -> torch.Tensor:
        """
        Image Center를 기준으로 2D Coordinate Rotation(좌표 회전) 수행.
        """
        # Angle 변환 (Degree -> Radian)
        angle = math.radians(float(angle_deg))
        cos_a = math.cos(angle)
        sin_a = math.sin(angle)

        # Image Center Coordinate 계산
        cx = (W - 1) / 2.0
        cy = (H - 1) / 2.0

        # Center 기준으로 Coordinate Shift (Origin으로 이동)
        x = xy[..., 0].to(torch.float32) - cx
        y = xy[..., 1].to(torch.float32) - cy

        # Rotation Matrix 적용
        xr = x * cos_a - y * sin_a + cx
        yr = x * sin_a + y * cos_a + cy

        # 결과 Coordinate Merge 및 Image Boundary Clamping (0 ~ W-1, H-1)
        out = torch.stack([xr, yr], dim=-1)
        out[..., 0] = out[..., 0].clamp(0, W - 1)
        out[..., 1] = out[..., 1].clamp(0, H - 1)
        return out.to(xy.dtype)

    @classmethod
    def _rotate_xyxy_to_enclosing_xyxy(
        cls, xyxy: torch.Tensor, angle_deg: float, W: int, H: int
    ) -> torch.Tensor:
        """
        Rotated Box를 감싸는 Enclosing Box(AABB, Axis-Aligned Bounding Box) 계산.
        
        [Logic Flow]
        1. Box의 4개 Corner Coordinate 생성
        2. 각 Corner에 Rotation Transform 수행
        3. 변환된 Coordinate들의 Min/Max 값을 계산하여 새로운 Axis-Aligned Box 생성
        """
        x1, y1, x2, y2 = xyxy.unbind(dim=-1)

        # 4개 Corner Coordinate Stack (Top-Left, Top-Right, Bottom-Right, Bottom-Left)
        corners = torch.stack(
            [
                torch.stack([x1, y1], dim=-1),
                torch.stack([x2, y1], dim=-1),
                torch.stack([x2, y2], dim=-1),
                torch.stack([x1, y2], dim=-1),
            ],
            dim=1,
        )

        # Coordinate Rotation 적용
        rc = cls._rotate_points_xy(corners, angle_deg, W, H)

        # 새로운 XYXY Coordinate 도출 (Min/Max 값 추출)
        x_min = rc[..., 0].min(dim=1).values
        y_min = rc[..., 1].min(dim=1).values
        x_max = rc[..., 0].max(dim=1).values
        y_max = rc[..., 1].max(dim=1).values

        return torch.stack([x_min, y_min, x_max, y_max], dim=-1)

    @classmethod
    def _rotate_keypoints_coco(
        cls, kpts: torch.Tensor, angle_deg: float, W: int, H: int
    ) -> torch.Tensor:
        """
        COCO Format Keypoint (x, y, v) Rotation 변환.
        v(visibility) 값은 Preserve(유지)하고 x, y Coordinate만 변환.
        """
        xy = kpts[..., :2]
        v = kpts[..., 2:].clone()
        
        # Coordinate Rotation 적용
        xy_r = cls._rotate_points_xy(xy, angle_deg, W, H)
        
        return torch.cat([xy_r, v.to(kpts.dtype)], dim=-1)

    # ------------------------------------------------------------------
    # 4. Forward (Execution)
    # ------------------------------------------------------------------
    def forward(self, img: torch.Tensor, target: Optional[Dict] = None):
        """
        img: Input Image Tensor [C, H, W]
        target: Target Dictionary (boxes, masks, keypoints 포함)
        """
        
        # Random Parameter 획득 (Synchronization)
        params = self.get_params()

        H = int(img.shape[-2])
        W = int(img.shape[-1])
        angle = float(params["rotation_angle"])

        # ---------------- Image Augmentation ----------------
        # Rotation 적용 (expand=False: 원본 Resolution 유지)
        img = F.rotate(img, angle=angle, expand=False)
        # Brightness Adjustment
        img = F.adjust_brightness(img, params["brightness_factor"])
        # Conditional Blur 적용
        if params["apply_blur"]:
            img = F.gaussian_blur(img, kernel_size=self.blur_kernel_size, sigma=self.blur_sigma)

        # Target이 없는 경우 Image만 Return
        if target is None:
            return img

        out_t = dict(target)

        # ---------------- Masks Transformation ----------------
        if "masks" in out_t and out_t["masks"] is not None:
            masks = out_t["masks"]

            # tv_tensors.Mask Type 처리
            if tv_tensors is not None and isinstance(masks, tv_tensors.Mask):
                # Mask Value(Class ID) 보존을 위해 NEAREST Interpolation 사용
                out_t["masks"] = F.rotate(
                    masks,
                    angle=angle,
                    expand=False,
                    interpolation=F.InterpolationMode.NEAREST,
                )
            else:
                # 일반 Tensor Type 처리
                if masks.ndim == 2:
                    masks = masks.unsqueeze(0)

                rotated = []
                for m in masks:
                    # 개별 Mask Channel별 Rotation 수행
                    rr = F.rotate(
                        m.unsqueeze(0),
                        angle=angle,
                        expand=False,
                        interpolation=F.InterpolationMode.NEAREST,
                    )
                    rotated.append(rr.squeeze(0))
                out_t["masks"] = torch.stack(rotated, dim=0)

        # ---------------- Boxes Transformation (AABB) ----------------
        if "boxes" in out_t and out_t["boxes"] is not None:
            boxes = out_t["boxes"]

            # tv_tensors.BoundingBoxes Type 처리
            if tv_tensors is not None and isinstance(boxes, tv_tensors.BoundingBoxes):
                box_fmt = boxes.format.value.lower()
                # 1. XYXY Format 변환
                xyxy = self._boxes_to_xyxy(boxes.as_subclass(torch.Tensor), box_fmt)
                # 2. Rotation 및 Enclosing Box(AABB) 계산
                rotated_xyxy = self._rotate_xyxy_to_enclosing_xyxy(xyxy, angle, W, H)
                # 3. Object Regeneration
                out_t["boxes"] = tv_tensors.BoundingBoxes(
                    rotated_xyxy,
                    format=tv_tensors.BoundingBoxFormat.XYXY,
                    canvas_size=boxes.canvas_size,
                )
            else:
                # 일반 Tensor Type 처리
                box_fmt = out_t.get("box_format", self.default_box_format)
                # 1. XYXY Format 변환
                xyxy = self._boxes_to_xyxy(boxes, box_fmt)
                # 2. Rotation 및 Enclosing Box(AABB) 계산
                rotated_xyxy = self._rotate_xyxy_to_enclosing_xyxy(xyxy, angle, W, H)
                # 3. 원본 Format으로 Restore
                out_t["boxes"] = self._xyxy_to_boxes(rotated_xyxy, box_fmt)

        # ---------------- Keypoints Transformation ----------------
        if "keypoints" in out_t and out_t["keypoints"] is not None:
            out_t["keypoints"] = self._rotate_keypoints_coco(
                out_t["keypoints"], angle, W, H
            )

        return img, out_t

torchvision에서는

  • compoiste input(image,mask,box,keypoint)을 일관되게 처리하기 위해 v2.Transform 상속을 권장하지만,
  • nn.Module을 사용해도 기본적인 호환성에는 큰 문제가 없으므로
  • 초보자에게는 nn.Module 상속 방식도 충분히 합리적인 선택임.

 


4-2. function based custom transfrom

단순한 Transform은 function(함수)로 구현할 수 있음.

  • transforms.Lambda를 활용하여 간단히 추가 가능.
  • 또는 functional 모듈의 함수들을 이용할 수도 있음.
def custom_transform(img):
    """PIL 이미지를 흑백으로 변환하는 함수"""
    return img.convert("L")  # PIL의 convert("L") - RGB를 Grayscale로 변환

transform = transforms.Compose([
    transforms.Resize((128, 128)),        # 이미지 크기를 128x128로 조정
    transforms.Lambda(custom_transform),  # 커스텀 함수를 transform으로 래핑 
                                          # 문제: pickle 불가, multiprocessing 호환 안됨
    transforms.ToTensor()                 # PIL → Tensor 변환 + [0,255] → [0,1] 정규화
])

 

역시 v2로 대체하는 것이 권장됨:

from torchvision.transforms import v2

# v2 권장 방식: Global 함수 + Lambda (최후 수단)
def safe_grayscale(img):
    """Global scope의 함수 - pickle 가능"""
    return v2.functional.rgb_to_grayscale(img, num_output_channels=1)

transform_v2_lambda = v2.Compose([
    v2.Resize((128, 128)),                    # 크기 조정
    v2.Lambda(safe_grayscale),                # 글로벌 함수는 pickle 가능 
                                              # (하지만 클래스 방식 권장)
    v2.ToDtype(torch.float32, scale=True)     # Tensor 변환 + 정규화
])

 

가급적이면 CustomTransform 클래스로 만들어서 처리하는 것이 권장됨.

from torchvision.transforms import v2

# v2 권장 방식: 커스텀 클래스 (고급 사용자용)
class GrayscaleTransform(torch.nn.Module):
    """RGB를 흑백으로 변환하는 안전한 커스텀 변환"""
    
    def forward(self, img):
        # v2.functional 사용으로 PIL/Tensor 모두 처리
        return v2.functional.rgb_to_grayscale(img, num_output_channels=1)

transform_v2_custom = v2.Compose([
    v2.Resize((128, 128)),                    # 크기 조정
    GrayscaleTransform(),                     # 커스텀 흑백 변환 (pickle 가능, multiprocessing 호환)
    v2.ToDtype(torch.float32, scale=True)     # Tensor 변환 + 정규화
])

 

물론, v2모듈에서 기본으로 지원하는 기능이면 기본으로 지원하는 클래스를 사용하는 것이 가장 좋음.

앞서의 예처럼 PIL이미지를 Gray Scale로 바꾸는 예는 v2 모듈의 클래스들로 충분히 처리 가능함.

from torchvision.transforms import v2

# v2 권장 방식: 내장 기능 활용 (가장 권장)
transform_v2_builtin = v2.Compose([
    v2.Resize((128, 128)),                    # 이미지 크기를 128x128로 조정 (PIL/Tensor 모두 지원)
    v2.Grayscale(num_output_channels=1),      # RGB → Grayscale 변환 (내장 기능, Lambda보다 안전)
    v2.ToDtype(torch.float32, scale=True)     # Tensor 변환 + 정규화 (ToTensor 대체)
])

5. DataLoader와 결합하여 GPU로 데이터 옮기기

torchvision.transforms

  • Dataset 결합하여 대규모 데이터셋에 자동으로 변환을 적용.
  • DataLoaderDataset 객체를 통해 데이터를 로드할 때, CPU에서 동작함.
  • 때문에 Dataset 객체의 __getitem__ 메서드 내부에서 호출되는 Transform 객체도 CPU에서 동작.
  • 때문에, DataLoader 가 Transform이 적용된 batch를 반환하면 이를 GPU로 이동시켜야 한다.

때문에, 학습 중 GPU를 효율적으로 활용하려면 최종 데이터(Transform이 적용된 Tensor 객체)를 GPU로 이동시켜야함.


5-1. 데이터셋 및 DataLoader 정의

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
dataset = ImageFolder(
            root='path_to_images', 
            transform=transform,
          )
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, pin_memory=True)

2025.06.17 - [Python] - torchvision.datasets.ImageFolder 사용하기.

 

torchvision.datasets.ImageFolder 사용하기.

torchvision.datasets.ImageFolder는 PyTorch에서 Image Classification Task를 위한 Dataset을 쉽게 구성할 수 있게 해주는 클래스임. original API documentation:https://docs.pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolde

ds31x.tistory.com


5-2. GPU로 데이터 전송 및 학습 루프

import torch
import torch.nn as nn
import torch.optim as optim

# 간단한 모델 정의
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 128 * 128, 10)
).to('cuda')  # 모델을 GPU로 이동

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 학습 루프
for epoch in range(5):
    for images, labels in dataloader:
        images = images.to('cuda', non_blocking=True)
        labels = labels.to('cuda', non_blocking=True)

        # 순전파 및 역전파
        outputs = model(images)

        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

    print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')

6. 실전 팁

  • Data Augmentation과 Pre-Processing 분리:
    • Data Augmentation은 Training 단계에서만 사용
    • Evaluation/Test 단계에서는 Pre-Processing만 적용.
    • 가장 명확한 방법은 각 단계별로 개별 Dataset객체를 사용하는 것임
    • 개별 Dataset은 각기 다른 transform을 설정
  • 고정 메모리와 비동기 전송 활용:
    • pin_memory=True
    • non_blocking=True를 사용해
    • GPU 전송 성능을 최적화.
  • GPU 활용 여부 확인:
    • 학습 전에 GPU가 사용 가능한지 확인.
    • device = 'cuda' if torch.cuda.is_available() else 'cpu'

같이 보면 좋은 자료

2025.06.17 - [Python] - [torchvision] transforms.v2, transforms.v2.functional, 그리고 kernel

 

[torchvision] transforms.v2, transforms.v2.functional, 그리고 kernel

torchvision.transforms는 PyTorch에서 제공하는 이미지 preprocessing 및 data augmentation을 위한 module. 이 모듈은 현재 v2 서브모듈의 사용을 권함:v2 transforms는 image뿐만 아니라bounding boxes, masks, videos도 변환할

ds31x.tistory.com

 

2025.06.16 - [Python] - [PyTorch] torchvision.transforms.v2 - Summary (작성중)

 

[PyTorch] torchvision.transforms.v2 - Summary (작성중)

다음의 공식문서를 기반으로 정리한 것임.https://docs.pytorch.org/vision/main/auto_examples/transforms/plot_transforms_illustrations.html#sphx-glr-auto-examples-transforms-plot-transforms-illustrations-py Illustration of transforms — Torch

ds31x.tistory.com

 

2024.04.09 - [Python] - [PyTorch] Dataset and DataLoader

 

[PyTorch] Dataset and DataLoader

Dataset 이란PyTorch 의 tensor 와학습에 사용될 일반 raw data (흔히, storage에 저장된 파일들) 사이에 위치하며,raw-data로부터 PyTorch의 module 객체 등이 접근가능한 데이터 셋을 추상화한 객체를 얻게 해주

ds31x.tistory.com


 

728x90