본문 바로가기
카테고리 없음

[DL] PyTorch: Tensor 비교하기.

by ds31x 2024. 5. 16.

[DL] PyTorch: Tensor 비교하기

PyTorch에서 nn.Parameter 또는 tensor 객체 두 개가 같은 값을 가지는지 확인하는 방법은 텐서의 모든 요소가 동일한지 확인하는 것임.

 

이를 위해 제공되는 다음과 같은 함수 2개가 존재함.

  • torch.equal : 두 텐서의 모든 요소가 동일한지 확인
  • torch.allclose : 지정된 허용 오차 내에서 두 텐서가 거의 동일한지 확인.

1. torch.equal 사용

이 방법은 두 텐서가 완전히 동일한지를 확인.

import torch
import torch.nn as nn

# nn.Parameter 객체 생성
param1 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
param2 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
param3 = nn.Parameter(torch.tensor([1.0, 2.0, 4.0]))

# 값 비교
print(torch.equal(param1, param2))  # True
print(torch.equal(param1, param3))  # False

2. torch.allclose 사용

이 방법은 두 텐서가 지정된 허용 오차 내에서 거의 동일한지를 확인.

이는 부동 소수점 연산의 미세한 차이로 인한 불일치를 허용할 수 있음 (권장).

import torch
import torch.nn as nn

# nn.Parameter 객체 생성
param1 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
param2 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0000001]))
param3 = nn.Parameter(torch.tensor([1.0, 2.0, 4.0]))

# 값 비교
print(torch.allclose(param1, param2))  # True, 기본 허용 오차 내
print(torch.allclose(param1, param3))  # False

2-1. torch.allclose의 허용 오차 설정

허용 오차를 설정하여 두 텐서가 거의 동일한지 확인할 수 있음.

import torch
import torch.nn as nn

# nn.Parameter 객체 생성
param1 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
param2 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0001]))

# 값 비교, 허용 오차 설정
print(torch.allclose(param1, param2, atol=1e-4))  # True, 허용 오차 내
print(torch.allclose(param1, param2, atol=1e-5))  # False, 허용 오차 밖

설명

  • torch.allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08):
    • 두 텐서가 상대적(rtol) 및 절대적(atol) 허용 오차 내에서 거의 동일한지 확인.

2-2. 상대적 오차와 절대적 오차

  • 상대적 오차와 절대적 오차는 두 값의 차이를 평가할 때 사용하는 두 가지 기준임.

2-2-1. 절대적 오차 (Absolute Tolerance, atol)

절대적 오차는 두 값의 차이 자체를 평가.

 

이는 다음과 같이 정의됩니다:

$$\text{절대적 오차} = |a - b|$$

여기서 a 와 b는 비교하려는 두 값으로 절대적 오차는 값의 크기에 상관없이 일정한 허용 오차를 제공하는 기준임.

 

예시

만약 a = 1000.0 이고 b = 1000.1 이라면, 절대적 오차는 다음과 같이 계산됩니다:

$$|1000.0 - 1000.1| = 0.1$$


2-2-2. 상대적 오차 (Relative Tolerance, rtol)

상대적 오차는 두 값의 차이를 값의 크기에 대한 비율로 평가함.

 

이는 다음과 같이 정의됩니다:

$$\text{상대적 오차} = \frac{|a - b|}{|b|}$$

여기서 a와 b는 비교하려는 두 값임.
상대적 오차는 값의 크기에 비례하여 허용 오차를 제공합니다.

 

예시

만약 a = 1000.0 이고 b = 1000.1이라면, 상대적 오차는 다음과 같이 계산:

$$\frac{|1000.0 - 1000.1|}{|1000.1|} \approx 0.0001$$


2-3. torch.allclose의 결합 기준

torch.allclose 함수는 두 가지 오차 기준을 결합하여 두 텐서가 얼마나 유사한지 평가함.

 

두 텐서의 각 요소 쌍에 대해 다음 조건이 성립하면 두 텐서가 유사하다고 반환:

$$|a_i - b_i| \leq \text{atol} + \text{rtol} \times |b_i|$$

 

여기서 a_ib_i는 비교하려는 두 텐서의 각 element임.

 

예제

두 텐서 tensor1tensor2가 있다고 가정합니다:

import torch

tensor1 = torch.tensor([1.0, 2.0, 3.0])
tensor2 = torch.tensor([1.00001, 2.00001, 3.00001])

print(torch.allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08))

 

여기서 tensor1의 각 요소 a_itensor2의 각 요소 b_i에 대해 다음 조건을 확인합니다:

$$|a_i - b_i| \leq 1 \times 10^{-8} + 1 \times 10^{-5} \times |b_i|$$

 

각 요소 쌍에 대해 계산하면:

 

a_1 = 1.0 , b_1 = 1.00001:

$$|1.0 - 1.00001| = 0.00001 \leq 1 \times 10^{-8} + 1 \times 10^{-5} \times 1.00001 \approx 1.00011 \times 10^{-5}$$

 

a_2 = 2.0, b_2 = 2.00001:

$$|2.0 - 2.00001| = 0.00001 \leq 1 \times 10^{-8} + 1 \times 10^{-5} \times 2.00001 \approx 2.00011 \times 10^{-5}$$

 

a_3 = 3.0, b_3 = 3.00001

 

$$|3.0 - 3.00001| = 0.00001 \leq 1 \times 10^{-8} + 1 \times 10^{-5} \times 3.00001 \approx 3.00011 \times 10^{-5}$$

 

모든 요소 쌍에 대해 위 조건이 성립하므로, torch.allcloseTrue를 반환합니다.


2-4. 오차 요약

torch.allclose는 다음의 두 기준을 결합하여 두 텐서가 거의 동일한지 확인.

  • 절대적 오차 (atol):
    • 두 값의 차이가 일정한 허용 오차 이내인지 확인합니다.
    • 값의 크기에 상관없이 일정한 기준을 제공.
  • 상대적 오차 (rtol):
    • 두 값의 차이가 값의 크기에 비례하여 허용 오차 이내인지 확인.
    • 값의 크기에 비례하는 기준을 제공.

결론

이 방법들을 사용하여 두 nn.Parameter 객체가 같은 값을 가지는지 확인할 수 있음
torch.equal은 완전히 동일한 값을 확인할 때, torch.allclose는 허용 오차 내에서 거의 동일한 값을 확인할 때 유용.