MultiLabelSoftMarginLoss
PyTorch 의 MultiLabelSoftMarginLoss
는 Multi-label Classification 을 위한 loss function임.
Multi-Label Classification은 하나의 sample이 여러 개의 class 에 동시에 속할 수 있는 classification 문제임.
https://dsaint31.me/mkdocs_site/ML/ch02/ml_cls_types/#multilabel-classification
Loss function 의 역할
Loss function 은
- model의 prediction이 training dataset과 얼마나 잘 일치하는지 (good fit)를
- scalar 수치로 나타내는 function.
이 값을 최소화하는 방향으로 model의 parameters를 조정함으로써, model의 performance을 개선하게 됨.
MultiLabelSoftMarginLoss의 작동 원리
MultiLabelSoftMarginLoss
는 각 클래스 레이블을 각각 독립적인 binary classification probelme으로 취급.
- 이 function은 model의 logit vector 출력(클래스 각각에 대한 예측 값)을 입력으로 함.
- 내부적으로 logistic function 를 적용하여 이를
[0, 1]
범위의 확률로 변환함. - 이 변환된 확률과 실제 레이블(각 클래스에 대한
0
또는1
의 값)을 사용하여 Binary Cross Entropy Loss를 각각 계산하여 합침.
수식
$$L(y, \hat{y}) = -\frac{1}{C} \sum_{i=1}^C [y_i \cdot \log(\sigma(\hat{y}_i)) + (1 - y_i) \cdot \log(1 - \sigma(\hat{y}_i))]$$
- \( \sigma \) : logistic function (~sigmoid)
- \( \hat{y}_i \) : predicted logit score.
- \( y_i \) : label (or target)임.
class weight 적용: for imbalanced class.
MultiLabelSoftMarginLoss
는 다른 loss functions처럼 선택적으로 각 클래스에 다른 가중치를 부여할 수 있는 기능을 제공함.
- 이는 특정 클래스가 다른 클래스보다 더 중요하거나,
- 데이터셋 내에서 클래스 불균형을 해소할 필요가 있을 때 유용.
코드 예제
import torch
import torch.nn as nn
# 모델 예측 로짓과 실제 레이블
logits = torch.tensor([[0.5, -1.0, 3.0], [1.5, -2.0, 0.0]], requires_grad=True)
labels = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float)
# 클래스별 가중치 설정
weights = torch.tensor([0.5, 2.0, 1.5], dtype=torch.float)
# 손실 함수 초기화 및 가중치 적용
loss_function = nn.MultiLabelSoftMarginLoss(weight=weights)
# 손실 계산
loss = loss_function(logits, labels)
print("Loss:", loss.item())
다른 구현물 과 비교.
PyTorch에서 sigmoid
와 BCELoss
의 조합을 통해서도 mulit-label classification를 수행할 수 있음.
이 방법은 각 클래스에 대한 독립적인 이진 분류 문제로 접근하는 것으로, 모델이 각 클래스에 대한 확률을 직접 출력하도록 함.
Binary classification 과의 차이점은 출력이 vector라는 점일 뿐, 거의 유사한 형태임.
차이점
BCELoss
는 외부에서 시그모이드 활성화를 요구하며,MultiLabelSoftMarginLoss
는 내부적으로 시그모이드 활성화를 적용.BCELoss
는 주로 일반적인 이진 분류에 사용되고,MultiLabelSoftMarginLoss
는 각 인스턴스가 여러 레이블을 가질 수 있는 멀티 레이블 분류 문제에 적합.BCELoss
는[0, 1]
범위의 확률을 입력으로 받고,MultiLabelSoftMarginLoss
는 로짓 값을 입력으로 받음.
기능적으로 두 방식 사이에는 큰 차이가 없으며, 두 방식 모두 mulit-label classification를 위해 효과적으로 사용될 수 있습니다
같이 읽어보면 좋은 자료들
2024.05.23 - [Python] - [DL] Classification 을 위한 Activation Func. 와 Loss Func: PyTorch