본문 바로가기
카테고리 없음

[DL] MultiLabelSoftMarginLoss: Multi-label classification

by ds31x 2024. 5. 30.

MultiLabelSoftMarginLoss

PyTorch 의 MultiLabelSoftMarginLoss는 Multi-label Classification 을 위한 loss function임.

 

Multi-Label Classification은 하나의 sample이 여러 개의 class 에 동시에 속할 수 있는 classification 문제임.

https://dsaint31.me/mkdocs_site/ML/ch02/ml_cls_types/#multilabel-classification

 

BME228

The Types of Classification Multiclass Classification Multinomial Classification이라고도 불림. Binary Classification 의 generalization으로 한 sample에 하나의 label값이 주어지지만, 해당 label값의 종류가 여러 class인 경우를

dsaint31.me


Loss function 의 역할

Loss function 은

  • model의 prediction이 training dataset과 얼마나 잘 일치하는지 (good fit)를
  • scalar 수치로 나타내는 function.

이 값을 최소화하는 방향으로 model의 parameters를 조정함으로써, model의 performance을 개선하게 됨.


MultiLabelSoftMarginLoss의 작동 원리

MultiLabelSoftMarginLoss는 각 클래스 레이블을 각각 독립적인 binary classification probelme으로 취급.

  • 이 function은 model의 logit vector 출력(클래스 각각에 대한 예측 값)을 입력으로 함.
  • 내부적으로 logistic function 를 적용하여 이를 [0, 1] 범위의 확률로 변환함.
  • 이 변환된 확률과 실제 레이블(각 클래스에 대한 0 또는 1의 값)을 사용하여 Binary Cross Entropy Loss를 각각 계산하여 합침.

수식

$$L(y, \hat{y}) = -\frac{1}{C} \sum_{i=1}^C [y_i \cdot \log(\sigma(\hat{y}_i)) + (1 - y_i) \cdot \log(1 - \sigma(\hat{y}_i))]$$

  • \( \sigma \) : logistic function (~sigmoid)
  • \( \hat{y}_i \) : predicted logit score.
  • \( y_i \) : label (or target)임.

class weight 적용: for imbalanced class.

MultiLabelSoftMarginLoss는 다른 loss functions처럼 선택적으로 각 클래스에 다른 가중치를 부여할 수 있는 기능을 제공함.

  • 이는 특정 클래스가 다른 클래스보다 더 중요하거나,
  • 데이터셋 내에서 클래스 불균형을 해소할 필요가 있을 때 유용.

코드 예제

import torch
import torch.nn as nn

# 모델 예측 로짓과 실제 레이블
logits = torch.tensor([[0.5, -1.0, 3.0], [1.5, -2.0, 0.0]], requires_grad=True)
labels = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float)

# 클래스별 가중치 설정
weights = torch.tensor([0.5, 2.0, 1.5], dtype=torch.float)

# 손실 함수 초기화 및 가중치 적용
loss_function = nn.MultiLabelSoftMarginLoss(weight=weights)

# 손실 계산
loss = loss_function(logits, labels)
print("Loss:", loss.item())

다른 구현물 과 비교.

PyTorch에서 sigmoidBCELoss 의 조합을 통해서도 mulit-label classification를 수행할 수 있음.

이 방법은 각 클래스에 대한 독립적인 이진 분류 문제로 접근하는 것으로, 모델이 각 클래스에 대한 확률을 직접 출력하도록 함.

Binary classification 과의 차이점은 출력이 vector라는 점일 뿐, 거의 유사한 형태임.

 

차이점

  • BCELoss는 외부에서 시그모이드 활성화를 요구하며, MultiLabelSoftMarginLoss는 내부적으로 시그모이드 활성화를 적용.
  • BCELoss는 주로 일반적인 이진 분류에 사용되고, MultiLabelSoftMarginLoss는 각 인스턴스가 여러 레이블을 가질 수 있는 멀티 레이블 분류 문제에 적합.
  • BCELoss[0, 1] 범위의 확률을 입력으로 받고, MultiLabelSoftMarginLoss는 로짓 값을 입력으로 받음.

 

기능적으로 두 방식 사이에는 큰 차이가 없으며, 두 방식 모두 mulit-label classification를 위해 효과적으로 사용될 수 있습니다


같이 읽어보면 좋은 자료들

2024.05.23 - [Python] - [DL] Classification 을 위한 Activation Func. 와 Loss Func: PyTorch

 

[DL] Classification 을 위한 Activation Func. 와 Loss Func: PyTorch

PyTorch: Classification에서의 Output Func(~Activation Func.)와 Loss Func. 요약PyTorch는 다양한 종류의 손실 함수와 활성화 함수를 제공하는데,이 중에서 classification task를 수행하는 모델에서 사용되는 것들을

ds31x.tistory.com