본문 바로가기
목차
ML

[DL] MultiLabelSoftMarginLoss: Multi-label classification

by ds31x 2024. 5. 30.
728x90
반응형

Vectorized Binary Cross-Entropy : MultiLabelSoftMarginLoss

MultiLabelSoftMarginLoss

PyTorch 의 MultiLabelSoftMarginLoss는 Multi-label Classification 을 위한 loss function임.

 

Multi-Label Classification은 하나의 sample이 여러 개의 class 에 동시에 속할 수 있는 classification 문제임.

https://dsaint31.me/mkdocs_site/ML/ch02/ml_cls_types/#multilabel-classification

 

BME

OvO OvR binary-classification multiclass multilabel multioutput Types of Classification Binary Classification Binary Classification (이진분류)는 각 sample에 2개의 label 중 하나를 할당하는 task임. 일반적으로 특정 class에 속하는

dsaint31.me


Loss function 의 역할

Loss function 은

  • model의 prediction이 training dataset과 얼마나 잘 일치하는지 (good fit)를
  • scalar 수치로 나타내는 function.

이 값을 최소화하는 방향으로 model의 parameters를 조정함으로써, model의 performance을 개선하게 됨.


MultiLabelSoftMarginLoss의 작동 원리

MultiLabelSoftMarginLoss각 클래스 레이블을 각각 독립적인 binary classification probelme으로 취급.

  • 이 function은 model의 logit vector 출력(클래스 각각에 대한 예측 값)을 입력으로 함.
  • 내부적으로 logistic function 를 적용하여 이를 [0, 1] 범위의 확률로 변환함.
  • 이 변환된 확률과 실제 레이블(각 클래스에 대한 0 또는 1의 값)을 사용하여 Binary Cross Entropy Loss를 각각 계산하여 합침.

수식

$$L(y, \hat{y}) = -\frac{1}{C} \sum_{i=1}^C [y_i \cdot \log(\sigma(\hat{y}_i)) + (1 - y_i) \cdot \log(1 - \sigma(\hat{y}_i))]$$

  • \( \sigma \) : logistic function (~sigmoid)
  • \( \hat{y}_i \) : predicted logit score로 logistic function을 거쳐 확률값이 됨 (sigmoid 이전의 raw input).
  • \( y_i \) : label (or target)임. 0 또는 1 을 가짐.
  • \( C \) : number of classes

binary cross-entropy의 vectorized version이라고 봐도 된다: multi-class classification을 OvR로 구현할 때나 multi-label classifcation에서 loss function으로 사용 가능함.


class weight 적용: for imbalanced class.

MultiLabelSoftMarginLoss는 다른 loss functions처럼 선택적으로 각 클래스에 다른 가중치를 부여할 수 있는 기능을 제공함.

  • 이는 특정 클래스가 다른 클래스보다 더 중요하거나,
  • 데이터셋 내에서 클래스 불균형을 해소할 필요가 있을 때 유용.

코드 예제

import torch
import torch.nn as nn

# 모델 예측 로짓과 실제 레이블
logits = torch.tensor([[0.5, -1.0, 3.0], [1.5, -2.0, 0.0]], requires_grad=True)
labels = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float)

# 클래스별 가중치 설정
weights = torch.tensor([0.5, 2.0, 1.5], dtype=torch.float)

# 손실 함수 초기화 및 가중치 적용
loss_function = nn.MultiLabelSoftMarginLoss(weight=weights)

# 손실 계산
loss = loss_function(logits, labels)
print("Loss:", loss.item())

다른 구현물 과 비교.

PyTorch에서 sigmoidBCELoss 의 조합을 통해서도 mulit-label classification를 수행할 수 있음.

이 방법은 각 클래스에 대한 독립적인 이진 분류 문제로 접근하는 것으로, 모델이 각 클래스에 대한 확률을 직접 출력하도록 함.

Binary classification 과의 차이점은 출력이 vector라는 점일 뿐, 거의 유사한 형태임.

 

차이점

  • BCELoss는 외부에서 시그모이드 활성화를 요구하며, MultiLabelSoftMarginLoss내부적으로 시그모이드 활성화를 적용.
  • BCELoss는 주로 일반적인 이진 분류에 사용되고, MultiLabelSoftMarginLoss는 각 인스턴스가 여러 레이블을 가질 수 있는 multi-label 분류 문제 또는 OvR 기반의 multi-class 분류 문제에 적합.
  • BCELoss[0, 1] 범위의 확률을 입력으로 받고, MultiLabelSoftMarginLoss는 로짓 값을 입력으로 받음.

 

기능적으로 두 방식 사이에는 큰 차이가 없으며, 두 방식 모두 mulit-label classification를 위해 효과적으로 사용될 수 있습니다


같이 읽어보면 좋은 자료들

2024.05.23 - [Python] - [DL] Classification 을 위한 Activation Func. 와 Loss Func: PyTorch

 

[DL] Classification 을 위한 Activation Func. 와 Loss Func: PyTorch

PyTorch: Classification에서의 Output Func(~Activation Func.)와 Loss Func. 요약PyTorch는 다양한 종류의 loss function(손실 함수)와 activation function(활성화 함수)를 제공하는데,이 중에서 classification task를 수행하는

ds31x.tistory.com


728x90