본문 바로가기
Python

[Tensor] NaN Safe Aggregation Functions

by ds31x 2024. 3. 20.

NaN (Not a Number) 값을 포함하는 Tensor 인스턴스에서

Aggregation Function을 사용할 때,

NaN을 무시 또는 특정값으로 처리하는 기능을 제공하는 함수.


NumPy

기존의 aggregaton function의 이름에 nan을 앞에 붙인 이름을 가지며, 수행 중 NaN을 무시함.

다음의 함수들이 대표적인 예임.

  • np.nansum, np.nanmean,
  • np.nanmax, np.nanmin, np.nanargmin, np.nanargmax,
  • np.nanmedian,
  • np.nanstd, np.nanvar,
  • np.nanprod,
  • np.nanquantile, np.nanpercentile

PyTorch

역시, 기존의 aggregaton function의 이름에 nan을 앞에 붙인 이름을 가지며, 수행 중 NaN을 무시함.

NumPy에 비해선 종류가 적은 편임.

  • torch.nansum, torch.nanmean,
  • np.nanmedian

TensorFlow

Tf에서의 Aggregation은

NumPy와 PyTorch와 꽤나 다른 형태이고 사용이 까다롭다보니,

최근 NumPy의 Aggregation function과 유사한 API의 함수들을 제공하는 다음의 module이 제공됨.

(아직 실험 단계임.)

tf.experimental.numpy

 

이를 사용하여 mean을 구하는 방식은 다음과 같음.

tf.experimental.numpy.nanmean(
    a, axis=None, dtype=None, keepdims=None
)

참고자료

https://numpy.org/doc/stable/reference/generated/numpy.nanmean.html

 

numpy.nanmean — NumPy v1.26 Manual

If out=None, returns a new array containing the mean values, otherwise a reference to the output array is returned. Nan is returned for slices that contain only NaNs.

numpy.org

https://pytorch.org/docs/stable/generated/torch.nanmean.html

 

torch.nanmean — PyTorch 2.2 documentation

Shortcuts

pytorch.org

https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/nanmean

 

tf.experimental.numpy.nanmean  |  TensorFlow v2.15.0.post1

TensorFlow variant of NumPy's nanmean.

www.tensorflow.org

https://gist.github.com/dsaint31x/ed03d40dc9643c842d603ab43d86fbb5

 

dl_tensor_nan_safe_aggregation.ipynb

dl_tensor_nan_safe_aggregation.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com