본문 바로가기
목차
Python

[PyTorch] Miscellanceous-Torchvision.transforms.v2.functional

by ds31x 2025. 6. 11.
728x90
반응형

torchvision.transforms.v2.functional 모듈에서 제공하는

이미지 데이터에서 normalization 등의 miscellaneous 에 속하는 함수들을 소개함.

 

관련 gist 파일

https://gist.github.com/dsaint31x/86a6ce6f612f79ef1ac33c442376b518

 

dl_misc-torchvision-transforms-v2-functional.ipynb

dl_misc-torchvision-transforms-v2-functional.ipynb - dl_misc-torchvision-transforms-v2-functional.ipynb

gist.github.com

 


0. Prerequisites

torchvision.transforms.v2.functional 모듈은,

  • 이미지 tensor 객체를 입력으로 받아
  • 해당 이미지에 직접 적용할 수 있는 다양한 이미지 변환 함수들을 제공함.

제공되는 함수들은
torchvision.transfroms.v2 의 Transform 클래스들과 달리,

  • 상태(state)를 가지지 않으며, 입력 텐서(이미지 등)를 인자로 직접 받아 변환된 텐서를 즉시 반환함.
  • 간단하고 직접적이며, 변환을 한 번만 적용하거나 사용자 정의 변환 함수 내에서 다른 함수들을 조합할 때 유용
  • 여러 변환을 순차적으로 적용하는 복잡한 파이프라인을 구축할 경우,
  • 각 함수 호출마다 반복적인 인자 전달이 필요하여 다소 불편할 수 있음.

colab에서 다음의 코드들을 수행시켜서 이미지를 다운로드하고 아래의 예제 코드를 수행할 수 있도로 함:

from torchvision.transforms.v2.functional import (
    crop,
    resized_crop,
    center_crop,
    five_crop,
    ten_crop,
)

img_path = "assets/astronaut.jpg"
img_url = "https://raw.githubusercontent.com/pytorch/vision/main/gallery/assets/astronaut.jpg"

!mkdir -p assets/coco/images
!curl -o assets/astronaut.jpg {img_url}

from torchvision.io import decode_image

original_img = decode_image(img_path)

print(f" {type(original_img) = }\n \
{original_img.dtype = }\n \
{original_img.shape = }")

 

다음은 이미지 출력을 위한 plot 함수임:

더보기
# https://github.com/pytorch/vision/tree/main/gallery/
# 위의 torchvision관련 예제들의 display를 위한 plot함수를 그대로 가져옴.

import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


def plot(imgs, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            boxes = None
            masks = None
            if isinstance(img, tuple):
                img, target = img
                if isinstance(target, dict):
                    boxes = target.get("boxes")
                    masks = target.get("masks")
                elif isinstance(target, tv_tensors.BoundingBoxes):
                    boxes = target
                else:
                    raise ValueError(f"Unexpected target type: {type(target)}")
            img = F.to_image(img)
            if img.dtype.is_floating_point and img.min() < 0:
                # Poor man's re-normalization for the colors to be OK-ish. This
                # is useful for images coming out of Normalize()
                img -= img.min()
                img /= img.max()

            img = F.to_dtype(img, torch.uint8, scale=True)
            if boxes is not None:
                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
            if masks is not None:
                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

            ax = axs[row_idx, col_idx]
            ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

1. v2.functional.normalize(inpt, mean, std[, ...])

  • 역할:
    • 입력 이미지를 평균(mean)과 표준편차(std)를 사용하여
    • 정규화(normalize) 수행.
  • 사용 시점:
    • 모델 학습 전 이미지 데이터를 표준화할 때 사용.
    • 특정 데이터셋의 통계에 맞춰 이미지를 정규화할 때 사용.
  • 주요 파라미터:
    • inpt: 정규화할 입력 이미지
    • mean: 각 채널별 평균값
    • std: 각 채널별 표준편차
# uint8 Tensor를 float31 Tensor로 변환
img_float32 = to_dtype(
    original_img, 
    dtype=torch.float32, 
    scale=True,
    )

# noramlization with mean and std.
normalized_img = normalize(
    img_float32, 
    mean=[0.485, 0.456, 0.406], # ImageNet's mean
    std=[0.229, 0.224, 0.225],  # ImageNet's std
    )


# 결과 확인
print(f"{type(img_float32)    = }")
print(f"{type(normalized_img) = }")
print(f"{img_float32.min().item()    = } \n{img_float32.max().item()    = }")
print(f"{normalized_img.min().item() = } \n{normalized_img.max().item() = }")

# Result
# type(img_float32)    = <class 'torch.Tensor'>
# type(normalized_img) = <class 'torch.Tensor'>
# img_float32.min().item()    = 0.0 
# img_float32.max().item()    = 1.0
# normalized_img.min().item() = -2.1179039478302 
# normalized_img.max().item() = 2.640000104904175

2. v2.functional.erase(inpt, i, j, h, w, v[, ...])

  • 역할:
    • 이미지의 특정 영역을 지정된 값으로
    • 지우기(erase) 수행.
  • 사용 시점:
    • 데이터 증강(data augmentation)을 위해 RandomErase를 적용할 때 사용.
    • 이미지의 일부분을 마스킹하거나 가릴 때 사용.
  • 주요 파라미터:
    • inpt: 지울 영역이 있는 입력 이미지
    • i, j: 지울 영역의 시작 좌표
    • h, w: 지울 영역의 높이와 너비
    • v: 지운 영역을 채울 값
erased_image = erase(original_img, i=50, j=150, h=100, w=100, v=0)
# 결과 확인
print(f"{type(original_img) = }")
print(f"{type(erased_image) = }")
print(f"{original_img.dtype = }")
print(f"{erased_image.dtype = }")

# Result
# type(original_img) = <class 'torch.Tensor'>
# type(erased_image) = <class 'torch.Tensor'>
# original_img.dtype = torch.uint8
# erased_image.dtype = torch.uint8


3. v2.functional.sanitize_bounding_boxes(...[, ...])

  • 역할:
    • 잘못되거나 유효하지 않은 다음의 bounding box를 제거
      • x2 <= x1 또는 y2 <= y1
      • min_size 보다 작은 변의 박스.
      • min_area 보다 작은 면적의 박스
      • 경계 밖의 박스.
    • 해당하는 indexing mask를 반환: 제거된 박스 위치에 False
  • 사용 시점:
    • 객체 탐지 작업에서 bounding box 데이터를 정리할 때 사용.
    • 유효하지 않은 박스들을 필터링할 때 사용.
    • 반드시 clamp_bounding_boxes 를 호출후 사용할 것: 안그러면 원치않는 바운딩박스가 제거되기 쉬움.
    • RandomIoUCrop 를 사용할 경우, 이 변환 이후 반드시 호출할 것(사이에 clamp_bounding_boxes 하는게 좋음.
  • 주요 파라미터:
    • bounding_boxes : Tensor
      • [N, 4] 형태의 shape.
    • format : BoundingBoxFormat 또는 str 
      • "XYXY" or BoundingBoxFormat.XYXY
      • "XYWH" or BoundingBoxFormat.XYWH
      • "CXCYWH" or BoundingBoxFormat.CXCYWH
    • canvas_size : Tuple[int, int]
      • 이미지 크기
    • min_size : float
      • 기본값 1.0 으로 최소 변의 길이임.
      • 이보다 작으면 제거됨.
    • min_area : float
      • 기본값 1.0 으로 최소 면적임.
      • 이보다 작으면 제거됨.
from torchvision.tv_tensors import BoundingBoxes

# 바운딩 박스 데이터 (일부는 유효하지 않음)
boxes_data = torch.tensor([
    [10, 10, 100, 100],   # 유효한 박스
    [50, 50, 40, 40],     # 유효하지 않은 박스 : x1 > x2
    [200, 200, 250, 250], # 유효한 박스
    [0, 0, 0, 0],         # 유효하지 않은 박스 : 크기가 0인 박스
])

boxes = BoundingBoxes(
    boxes_data, 
    format="XYXY", 
    canvas_size=original_img.shape[-2:],
    )

# 유효하지 않은 박스 제거
sanitized_boxes, mask = sanitize_bounding_boxes(boxes)

print(f"Original boxes: {len(boxes)}")
print(f"Valid boxes:    {len(sanitized_boxes)}")
print(f"Valid mask:     {mask}")

# Results
# Original boxes: 4
# Valid boxes:    2
# Valid mask:     tensor([ True, False,  True, False])

4. v2.functional.clamp_bounding_boxes(inpt[, ...])

  • 역할:
    • bounding box가 이미지 경계를 벗어나지 않도록
    • 범위를 제한(clamp).
  • 사용 시점:
    • bounding box 가 이미지 범위 내에 있도록 보장할 때 사용.
    • 변환 후 박스 좌표를 이미지 크기에 맞춰 조정할 때 사용.
    • RandomCrop, RandomResizedCrop 등과 Data augmentation 이후에는 반드시 해줘야 함.
    • 과거 torchvision.ops.clip_boxes_to_image 에 대응.
  • 주요 파라미터:
    • bounding_boxes: torchvision.tv_tensors.BoundingBoxes 제한할 바운딩 박스 데이터
      • [N, 4] 형태의 shape.
      • canvas_size attribute가 반드시 필요함.
from torchvision.tv_tensors import BoundingBoxes

# 이미지 경계를 벗어나는 바운딩 박스
boxes_data = torch.tensor([
    [10, 10, 100, 100],    # 범위 내
    [-10, -10, 50, 50],    # 음수 좌표
    [200, 200, 550, 550]   # 이미지 크기 초과
])

boxes = BoundingBoxes(
    boxes_data, 
    format="XYXY", 
    canvas_size=original_img.shape[-2:],
    )

# 바운딩 박스를 이미지 경계 내로 제한
clamped_boxes = clamp_bounding_boxes(boxes)

print(f"Original boxes:\n{boxes}")
print(f"Clamped boxes:\n{clamped_boxes}")

5. v2.functional.clamp_keypoints(inpt[, ...]) : Not Supported (0.21.0+cu124)

  • 역할:
    • 키포인트가 이미지 경계를 벗어나지 않도록
    • 범위를 제한(clamp).
  • 사용 시점:
    • 포즈 추정 등에서 키포인트가 이미지 범위 내에 있도록 보장할 때 사용.
    • 변환 후 키포인트 좌표를 이미지 크기에 맞춰 조정할 때 사용.
  • 주요 파라미터:
    • inpt: 제한할 키포인트 데이터

6. v2.functional.uniform_temporal_subsample(...)

  • 역할:
    • 시간적 차원에서 균등한 간격으로
    • subsampling수행.
  • 사용 시점:
    • 비디오 데이터에서 프레임을 균등하게 샘플링할 때 사용.
    • 시계열 데이터의 길이를 줄이면서 시간적 특성을 유지할 때 사용.
  • 주요 파라미터:
    • 시간적 데이터 및 샘플링 관련 파라미터들
    • video : Tensor
      • [..., T, C, H, W] : 채널수는 보통 3.
    • num_samples : int
      • 선택할 frame의 개수.
      • 주어진 frame수보다 많으면 nearest neighbor interpolation으로 upsampling.
      • 보통은 작게 주어지면 이 경우 균등 간격으로 num_samples로 subsampling.
# 비디오 텐서 (T, C, H, W) - 30프레임
video = torch.rand(30, 3, 224, 224)

# 30프레임에서 8프레임으로 균등 서브샘플링
subsampled_video = uniform_temporal_subsample(video, num_samples=8)
print(f"Original video frames  : {video.shape[0]}")
print(f"Subsampled video frames: {subsampled_video.shape[0]}")

# 선택된 프레임 인덱스 확인
indices = torch.linspace(0, 29, 8).long()
print(f"Selected frame indices : {indices}")

# Results
# Original video frames  : 30
# Subsampled video frames: 8
# Selected frame indices : tensor([ 0,  4,  8, 12, 16, 20, 24, 29])

7. v2.functional.jpeg(image, quality)

  • 역할:
    • 이미지에 JPEG 압축을 적용하여
    • 이미지 품질을 조절.
  • 사용 시점:
    • 데이터 증강을 위해 이미지 품질을 의도적으로 낮출 때 사용.
    • JPEG 압축 효과를 시뮬레이션할 때 사용.
  • 주요 파라미터:
    • image: JPEG 압축을 적용할 이미지
      • dtype가 반드시 uint8 이어야 함.
      • GPU에서 동작하지 않음. CPU에서만 실행할 것.
      • 채널수가 1 또는 3만 지원.
      • Tensor인 경우, [..., C, H, W] shape지원.
    • quality: JPEG 압축 품질 (낮을수록 더 많이 압축)
      • int형으로 1~100
from math import pi
from PIL import Image
import numpy as np
from torchvision.transforms.v2.functional import to_pil_image


# 다양한 JPEG 품질로 압축
high_quality   = jpeg(original_img, quality=95)  # 고품질
medium_quality = jpeg(original_img, quality=50)  # 중간품질
low_quality    = jpeg(original_img, quality=10)   # 저품질

print(f"Original image shape: {original_img.shape}/{type(original_img)}")
print(f"High quality shape  : {high_quality.shape}/{type(high_quality)}")
print(f"Medium quality shape: {medium_quality.shape}/{type(medium_quality)}")
print(f"Low quality shape   : {low_quality.shape}/{type(low_quality)}")

# PIL Image와 함께 사용
pil_from_tensor = to_pil_image(original_img)
jpeg_pil = jpeg(pil_from_tensor, quality=75)
print(f"with PIL Image   : {jpeg_pil.size}/{type(jpeg_pil)}")

# Result
# Original image shape: torch.Size([3, 512, 512])/<class 'torch.Tensor'>
# High quality shape  : torch.Size([3, 512, 512])/<class 'torch.Tensor'>
# Medium quality shape: torch.Size([3, 512, 512])/<class 'torch.Tensor'>
# Low quality shape   : torch.Size([3, 512, 512])/<class 'torch.Tensor'>
# with PIL Image   : (512, 512)/<class 'PIL.Image.Image'>

 

728x90