[DL] PyTorch: Tensor 비교하기
PyTorch에서 nn.Parameter
또는 tensor
객체 두 개가 같은 값을 가지는지 확인하는 방법은 텐서의 모든 요소가 동일한지 확인하는 것임.
이를 위해 제공되는 다음과 같은 함수 2개가 존재함.
torch.equal
: 두 텐서의 모든 요소가 동일한지 확인torch.allclose
: 지정된 허용 오차 내에서 두 텐서가 거의 동일한지 확인.
1. torch.equal
사용
이 방법은 두 텐서가 완전히 동일한지를 확인.
import torch
import torch.nn as nn
# nn.Parameter 객체 생성
param1 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
param2 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
param3 = nn.Parameter(torch.tensor([1.0, 2.0, 4.0]))
# 값 비교
print(torch.equal(param1, param2)) # True
print(torch.equal(param1, param3)) # False
2. torch.allclose
사용
이 방법은 두 텐서가 지정된 허용 오차 내에서 거의 동일한지를 확인.
이는 부동 소수점 연산의 미세한 차이로 인한 불일치를 허용할 수 있음 (권장).
import torch
import torch.nn as nn
# nn.Parameter 객체 생성
param1 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
param2 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0000001]))
param3 = nn.Parameter(torch.tensor([1.0, 2.0, 4.0]))
# 값 비교
print(torch.allclose(param1, param2)) # True, 기본 허용 오차 내
print(torch.allclose(param1, param3)) # False
2-1. torch.allclose
의 허용 오차 설정
허용 오차를 설정하여 두 텐서가 거의 동일한지 확인할 수 있음.
import torch
import torch.nn as nn
# nn.Parameter 객체 생성
param1 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
param2 = nn.Parameter(torch.tensor([1.0, 2.0, 3.0001]))
# 값 비교, 허용 오차 설정
print(torch.allclose(param1, param2, atol=1e-4)) # True, 허용 오차 내
print(torch.allclose(param1, param2, atol=1e-5)) # False, 허용 오차 밖
설명
torch.allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08)
:- 두 텐서가 상대적(
rtol
) 및 절대적(atol
) 허용 오차 내에서 거의 동일한지 확인.
- 두 텐서가 상대적(
2-2. 상대적 오차와 절대적 오차
- 상대적 오차와 절대적 오차는 두 값의 차이를 평가할 때 사용하는 두 가지 기준임.
2-2-1. 절대적 오차 (Absolute Tolerance, atol
)
절대적 오차는 두 값의 차이 자체를 평가.
이는 다음과 같이 정의됩니다:
$$\text{절대적 오차} = |a - b|$$
여기서 a 와 b는 비교하려는 두 값으로 절대적 오차는 값의 크기에 상관없이 일정한 허용 오차를 제공하는 기준임.
예시
만약 a = 1000.0
이고 b = 1000.1
이라면, 절대적 오차는 다음과 같이 계산됩니다:
$$|1000.0 - 1000.1| = 0.1$$
2-2-2. 상대적 오차 (Relative Tolerance, rtol
)
상대적 오차는 두 값의 차이를 값의 크기에 대한 비율로 평가함.
이는 다음과 같이 정의됩니다:
$$\text{상대적 오차} = \frac{|a - b|}{|b|}$$
여기서 a와 b는 비교하려는 두 값임.
상대적 오차는 값의 크기에 비례하여 허용 오차를 제공합니다.
예시
만약 a = 1000.0
이고 b = 1000.1
이라면, 상대적 오차는 다음과 같이 계산:
$$\frac{|1000.0 - 1000.1|}{|1000.1|} \approx 0.0001$$
2-3. torch.allclose
의 결합 기준
torch.allclose
함수는 두 가지 오차 기준을 결합하여 두 텐서가 얼마나 유사한지 평가함.
두 텐서의 각 요소 쌍에 대해 다음 조건이 성립하면 두 텐서가 유사하다고 반환:
$$|a_i - b_i| \leq \text{atol} + \text{rtol} \times |b_i|$$
여기서 a_i
와 b_i
는 비교하려는 두 텐서의 각 element임.
예제
두 텐서 tensor1
과 tensor2
가 있다고 가정합니다:
import torch
tensor1 = torch.tensor([1.0, 2.0, 3.0])
tensor2 = torch.tensor([1.00001, 2.00001, 3.00001])
print(torch.allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08))
여기서 tensor1
의 각 요소 a_i
와 tensor2
의 각 요소 b_i
에 대해 다음 조건을 확인합니다:
$$|a_i - b_i| \leq 1 \times 10^{-8} + 1 \times 10^{-5} \times |b_i|$$
각 요소 쌍에 대해 계산하면:
a_1 = 1.0 , b_1 = 1.00001
:
$$|1.0 - 1.00001| = 0.00001 \leq 1 \times 10^{-8} + 1 \times 10^{-5} \times 1.00001 \approx 1.00011 \times 10^{-5}$$
a_2 = 2.0, b_2 = 2.00001:
$$|2.0 - 2.00001| = 0.00001 \leq 1 \times 10^{-8} + 1 \times 10^{-5} \times 2.00001 \approx 2.00011 \times 10^{-5}$$
a_3 = 3.0, b_3 = 3.00001
$$|3.0 - 3.00001| = 0.00001 \leq 1 \times 10^{-8} + 1 \times 10^{-5} \times 3.00001 \approx 3.00011 \times 10^{-5}$$
모든 요소 쌍에 대해 위 조건이 성립하므로, torch.allclose
는 True
를 반환합니다.
2-4. 오차 요약
torch.allclose는 다음의 두 기준을 결합하여 두 텐서가 거의 동일한지 확인.
- 절대적 오차 (
atol
):- 두 값의 차이가 일정한 허용 오차 이내인지 확인합니다.
- 값의 크기에 상관없이 일정한 기준을 제공.
- 상대적 오차 (
rtol
):- 두 값의 차이가 값의 크기에 비례하여 허용 오차 이내인지 확인.
- 값의 크기에 비례하는 기준을 제공.
결론
이 방법들을 사용하여 두 nn.Parameter
객체가 같은 값을 가지는지 확인할 수 있음torch.equal
은 완전히 동일한 값을 확인할 때, torch.allclose
는 허용 오차 내에서 거의 동일한 값을 확인할 때 유용.