PyTorch의
torchvision.transforms
:
이미지 전처리와 데이터 증강을 위한 도구
torchvision.transforms
는
- PyTorch에서 제공하는 이미지 전처리 및 데이터 증강을 위한 module.
- 이 모듈은 이미지 데이터를 이용한 딥러닝 모델의 학습 효율을 높이고 데이터 준비 과정을 단순화하는 데 사용됨.
1. torchvision.transforms란 무엇인가?
torchvision.transforms
의 주요 역할:
- 이미지 전처리: 크기 조정, 자르기, 회전.
- Data Augmentation (데이터 증강)
- 전처리 파이프라인 구축: 여러 변환을 조합하여 효율적인 데이터 전처리 루틴을 생성.
2. 주요 변환 메서드와 사용법
2-1. 이미지 크기 조정
transforms.Resize
:
- 이미지를 지정된 크기로 변경.
transform = transforms.Resize((128, 128))
resized_image = transform(image)
transforms.CenterCrop
:
- 중앙을 기준으로 이미지를 cropping (자름).
transform = transforms.CenterCrop((100, 100))
cropped_image = transform(image)
transforms.RandomCrop
:
- 임의의 위치에서 이미지를 자름.
padding
으로 crop 전에 padding을 수행.
transform = transforms.RandomCrop((100, 100), padding=4)
random_cropped_image = transform(image)
2-2. 텐서 변환
transforms.ToTensor
:
- 이미지를 PyTorch 텐서로 변환.
- 픽셀 값을
[0, 1]
로 normalization(정규화):- 입력 데이터가 Pillow 의 Image 객체인 경우 수행됨.
- NumPy의 ndarray이면서 float 형이 아닐 경우 수행됨.
- 입력데이터가 Pillow 의 Image 객체인 경우,
- 입력이 Width, Height, Channel 이라고 가정하고
- 출력은 Channel, Width, Height로 변경됨.
transform = transforms.ToTensor()
tensor_image = transform(image)
transforms.Normalize
:
- 텐서를 normalization하기 위해 standardization을 수행.
- 각 채널에 대해
(x - mean) / std
를 적용.
transform = transforms.Normalize(
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
)
normalized_image = transform(tensor_image)
2-3. 데이터 증강
transforms.RandomHorizontalFlip
:
- 이미지 좌우 반전을 확률적으로 적용.
transform = transforms.RandomHorizontalFlip(p=0.5)
flipped_image = transform(image)
transforms.RandomRotation
:
- 이미지를 랜덤 각도로 회전합니다.
transform = transforms.RandomRotation(degrees=30)
rotated_image = transform(image)
transforms.ColorJitter
:
- 밝기, 대비, 채도, 색조를 조정합니다.
- Parameters: 0에서 1 사이의 값을 사용하며, 지정된 범위 내에서 랜덤하게 변환이 적용됨.
brightness
: 밝기 조정 범위 (예: 0.2는 원본 대비 0.8~1.2배 사이의 밝기로 변환)contrast
: 대비 조정 범위saturation
: 채도 조정 범위hue
: 색조 조정 범위
transform = transforms.ColorJitter(brightness=0.2, contrast=0.3)
jittered_image = transform(image)
2-4. 특수 변환
transforms.Grayscale
:
- 이미지를 흑백으로 변환합니다.
transform = transforms.Grayscale(num_output_channels=1)
gray_image = transform(image)
transforms.Lambda
:
- 사용자 정의 변환을 구현한 lambda expression을 적용.
transform = transforms.Lambda(lambda img: img.rotate(45))
rotated_image = transform(image)
다음은 Lambda transform을 테스트하는 전체 코드임.
import PIL
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
np_img = np.zeros((10, 10, 3), dtype=np.uint8)
np_img[5,:,:] = 255
img = Image.fromarray(np_img)
# # 샘플 이미지 불러오기
# img = Image.open("IMG_2826.jpeg")
# Lambda transform 정의 및 적용
transform = transforms.Lambda(
lambda img: img.rotate(
20,resample=PIL.Image.Resampling.BILINEAR
)
)
rotated_image = transform(img)
# 결과 확인
plt.imshow(np.array(rotated_image))
https://dsaint31.tistory.com/236
2023.07.07 - [Python] - [Python] lambda expression and map, filter, reduce.
3. 변환 파이프라인 구축
torchvision.transforms.Compose
를 통해 여러 변환을 연결하여 pipeline(파이프라인)을 만들 수 있음.
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.2, contrast=0.3),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transformed_image = transform(image)
https://dsaint31.tistory.com/829
4. 커스텀 변환 작성법
4-1. 클래스 기반 커스텀 변환
__call__
메서드를 가진 클래스를 작성하여 새로운 변환을 정의.
class CustomTransform:
def __call__(self, img):
return img.point(lambda x: x * 1.5) # 밝기를 1.5배 증가
4-2. 함수 기반 커스텀 변환
단순한 변환은 함수로 구현할 수 있음.
transforms.Lambda
를 활용하여 간단히 추가 가능.
def custom_transform(img):
return img.convert("L") # 흑백으로 변환
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.Lambda(custom_transform),
transforms.ToTensor()
])
5. DataLoader와 결합하여 GPU로 데이터 옮기기
torchvision.transforms
는
- DataLoader와 결합하여 대규모 데이터셋에 자동으로 변환을 적용.
- DataLoader는 Dataset 객체를 통해 데이터를 로드할 때, CPU를 이용함.
- Transform 자체는 CPU에서 동작함.
- DataLoader 가 Transform이 적용된 batch를 반환하면 이를 GPU로 이동시켜야 한다.
때문에, 학습 중 GPU를 효율적으로 활용하려면 최종 데이터(Transform이 적용된 Tensor 객체)를 GPU로 이동시켜야함.
5-1. 데이터셋 및 DataLoader 정의
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
dataset = ImageFolder(root='path_to_images', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, pin_memory=True)
5-2. GPU로 데이터 전송 및 학습 루프
import torch
import torch.nn as nn
import torch.optim as optim
# 간단한 모델 정의
model = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(16 * 128 * 128, 10)
).to('cuda') # 모델을 GPU로 이동
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 학습 루프
for epoch in range(5):
for images, labels in dataloader:
images = images.to('cuda', non_blocking=True)
labels = labels.to('cuda', non_blocking=True)
# 순전파 및 역전파
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')
6. 실전 팁
- Data Augmentation과 Pre-Processing 분리:
- Data Augmentation은 Training 단계에서만 사용
- Evaluation/Test 단계에서는 Pre-Processing만 적용.
- 고정 메모리와 비동기 전송 활용:
pin_memory=True
와non_blocking=True
를 사용해- GPU 전송 성능을 최적화.
- GPU 활용 여부 확인:
- 학습 전에 GPU가 사용 가능한지 확인.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
같이 보면 좋은 자료
2024.04.09 - [Python] - [PyTorch] Dataset and DataLoader
'Python' 카테고리의 다른 글
[Py] io.StringIO 와 io.BytesIO (0) | 2024.12.03 |
---|---|
[Py] Serialization of Python: pickle (1) | 2024.11.27 |
[Py] Context Manager: with statement! (0) | 2024.11.27 |
[Py] Higher-order Function (고차함수) (1) | 2024.11.20 |
[Py] 숫자 야구 게임: structured programming, type annotation, and OOP (0) | 2024.11.20 |