본문 바로가기
목차
ML

JAX (Just After eXecution)소개

by ds31x 2026. 1. 16.
728x90
반응형

JAX는

  • 수학적 함수(pure function)를 개발자가 작성하면,
  • 이를 인자로 받아 자동 미분, JIT 컴파일, 벡터화·병렬화를 적용한 새로운 함수를 반환해 주는 function들을 제공.
  • 이를 고성능 수치 계산을 가능하게 하는 함수 변환 라이브러리임ㅣ 딥러닝 프레임워크의 기반이 되어주는 라이브러리로 볼 수 있음.

JAX는

  • 기존 딥러닝 프레임워크와 사고방식 자체가 다른 도구이며,
  • 특히 이론·연구·수식 중심 작업에 매우 잘 맞는 도구로 알려짐.

1. JAX (Just After eXecution) 소개

JAX의 fully qualified name은 다음과 같음:

  • Just After eXecution
  • 이는 “코드를 실행한 직후에 필요한 변환(미분, 컴파일 등)을 수행한다”는 철학을 반영.

JAX는 Google에서 개발·유지되고 있으며,
특히 Google Research / DeepMind 계열 연구자들을 중심으로 발전하고 있음.


2. JAX는 무엇인가

JAX는 다음과 같이 정의할 수 있습니다.

  • JAX는 NumPy와 유사한 문법을 사용하여 수치 계산을 작성하고,
  • 그 계산을 자동 미분, JIT 컴파일, 벡터화, 병렬화된 “새로운 함수”로 변환해 주는 라이브러리 임.

중요한 점은 다음과 같음.

  • JAX의 기본 입력은 Python 함수
  • JAX의 출력 역시 Python 함수
  • 계산 그래프는 사용자에게 노출되지 않음

 

JAX에서

  • JIT는 “함수를 컴파일 가능한 형태로 감싸는 도구”이고,
  • XLA는 그 함수를 실제 하드웨어에서 빠르게 실행하도록 최적화하는 컴파일러임: 내부적으로 IR로 사용되는 계산 그래프를 컴파일.

3. JAX와 계산 그래프의 관계

3.1 계산 그래프란 무엇인가

계산 그래프(computation graph)는

  • 연산을 노드(node),
  • 데이터 흐름을 엣지(edge)로 표현한 구조(DAG, Directed Acyclic Graph).

graph 와 Computational Graph에 대한 보다 자세한 자료

대부분의 딥러닝 프레임워크는
내부적으로 계산 그래프를 사용하고 있음.

3.2 JAX의 관점

JAX에서 계산 그래프는 다음과 같이 취급됩니다.

  • 계산 그래프는 사용자가 직접 만드는 대상이 아니라,
  • 함수를 변환·실행하는 과정에서 내부적으로 생성되는 중간 표현(IR)임.
  • JAX는 계산 그래프를 직접 실행하지 않고,
  • 해당 그래프의 중간 표현(IR)을 XLA 컴파일러에 넘겨 하드웨어별로 최적화된 기계 코드를 생성한 뒤 이를 실행.

예를 들어,

from jax import jit
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + x**2

g = jit(f)
  • jit(f)의 반환값은 새로운 함수 g
  • 이 시점에는 그래프도, 컴파일도 아직 수행되지 않음
y = g(x)

이 첫 호출 시점에 다음 과정이 이루어짐 (JIT).

  1. 함수 f를 추적(tracing)
  2. 연산을 중간 표현(IR, Jaxpr) 으로 변환
  3. IR을 XLA를 통해 기계 코드로 컴파일
  4. 컴파일된 코드를 실행
  5. 결과 값을 반환

즉,

  • 계산 그래프는 “함수 내부 구현”이며,
  • 사용자는 오직 함수만을 다룬다.

이런 의미에서

  • JAX는 전통적인 의미의 Define by Run (TF 1.x)도, Define and Run (PyTorch, TF 2.x)도 아니다.
  • 개념적으로는 “Define by Transformation” 이라고 해야 함.

2024.03.28 - [Python] - [DL] Define and Run vs. Define by Run

 

[DL] Define and Run vs. Define by Run

Deep Learning (DL) Framework의 동작방식을 비교하는 용어. DL Model의 구축과 실행이 어떻게 이루어지는지로 구분됨. Define and Run DL Model을 구축 (= Computational Graph)이 먼저 이루어지고, 이후 input tesnsor를 정

ds31x.tistory.com


 

4. 객체 + 상태 중심 vs 함수 + 데이터 중심

4.1 TensorFlow / PyTorch의 전통적 방식

TensorFlow 2.x와 PyTorch는 다음과 같은 사고방식을 따른다.

  • 모델은 객체(object)
  • 모델의 파라미터와 상태는 객체 내부에 존재
  • 연산은 객체의 메서드 호출로 수행
y = model(x)

이때 사용자는 모델 내부에 어떤 파라미터와 상태가 존재하는지를 명시적으로 전달하지 않으며,
연산 과정에서 객체 내부 상태가 암묵적으로 참조되거나 변경될 수 있다


4.2 JAX의 방식

JAX는 모델과 파라미터를 명확히 분리하여 표현한다.

y = f(params, x)
  • f : 입력과 파라미터를 받아 출력을 계산하는 순수 함수(pure function)
  • params : 명시적으로 분리된 데이터로 모델의 파라미터를 다룸
  • 상태 변경은 새로운 데이터를 반환하는 방식으로 표현

JAX에서는

  • 모델의 구조는 함수로,
  • 모델의 상태는 데이터로 분리하여 다룬다.

즉, JAX는 “함수 + 데이터” 중심 시스템이다.

이는 다음의 수학적 표현에 대응됨:
$$
y = f(\theta, x)
$$


5. 자동 미분과 함수 변환

JAX의 자동 미분은 다음과 같이 이해할 수 있음:

from jax import grad

df = grad(f)
  • grad(f)의 결과는 미분을 계산하는 새로운 함수
  • 미분 과정에서 필요한 계산 그래프는 함수 호출 시 내부적으로 생성됨

따라서,

  • JAX에서 미분은 “그래프를 만든다”가 아니라
  • “미분된 함수를 생성”

PyTorch에서의 AutoGrad는 다음을 참고:


6. JAX의 위치: 딥러닝 프레임워크인가?

흔히, DL 프레임워크로 TensorFlow, PyTorch와 함께 JAX가 꼽히는 경우가 많음.

(물론 2개만 고를 때는 빠짐)

 

엄밀히 말하면 JAX는

  • 딥러닝 프레임워크라기보다
  • 미분 가능한 수치 계산을 위한 기반 시스템이다.

그래서 실제 신경망 학습에서는 보통 다음과 같은 라이브러리를 함께 사용한다.

  • Flax / Haiku : 신경망 구조 정의
  • Optax : 옵티마이저
  • NumPyro : 확률 프로그래밍

7. 공식 튜토리얼 및 학습 자료

7.1 공식 사이트

JAX 공식 문서

https://jax.readthedocs.io/

 

JAX GitHub
https://github.com/google/jax


7.2 튜토리얼들 

간략하게만 살펴봤을 뿐, 본인도 다 보진 못한 상태임...

 

JAX 101 / Quickstart
다음은 비공식 번역 정리된 곳임:
https://rtd-tutorial-ybeen.readthedocs.io/en/latest/JAX101/index.html?utm_source=chatgpt.com

 

Autodiff Cookbook (자동 미분 예제 모음)
https://kolonist26-jax-kr.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html?utm_source=chatgpt.com


8. JAX 의 주요 개념

하도 자주 나와서 정리만 해보고 어떤 녀석인지만 파악한 상태임...

8.1 NumPy 에서 JAX NumPy 로 확장:

  • jax.numpy 사용
  • 배열은 immutable이라는 점 이해

8.2 grad : Gradient 함수

  • 미분된 함수를 반환하는 함수
  • Gradient를 계산할 수 있게 해주는 핵심 기능.

8.3 jit

  • 실제 넘겨진 function이 언제 컴파일되는지
  • 첫 호출과 이후 호출의 차이

8.4 vmap

  • for-loop 없는 벡터화 사고 방식

8.5 간단한 모델 직접 작성

  • f(params, x) 형태 유지
  • 옵티마이저는 Optax 사용

8.6 Flax / Haiku로 확장

  • 신경망 구조 추상화 학습

같이 보면 좋은 자료들

2025.02.28 - [CE] - [DL] GPU Acceleration 기술 소개

 

[DL] GPU Acceleration 기술 소개

1. 대표적 GPU 가속 기술현재 가장 많이 사용되고 있는 GPU기반의 가속기술은 다음과 같음:Compute Unified Device Architecture (CUDA): NVIDIAMetal Performance Shaders (MPS): AppleRadeon Open Compute (ROCm): AMDDirect Machine Learn

ds31x.tistory.com

2024.07.08 - [ML] - [ML] History of ML (작성중)

 

[ML] History of ML (작성중)

19401943 : ANN (Artificial Neural Network) 시작McCulloch, PittsA logical calculus of ideas immanent in nervous activity1949 : Weighting 변화를 통한 학습Donald Olding HebbThe Organization of Behavior: A Neuropsychological Theory19501956 : AI (Artif

ds31x.tistory.com

https://dsaint31.me/mkdocs_site/ML/ch08/reverse_mode_autodiff/

 

BME

Reverse-Mode Autodiff (Auto-Differentiation) Reverse-mode autodiff 는 TensorFlow, PyTorch 등에서 gradient를 구하는 back-propagation 수행에서 필요한 미분(differentiation)을 구하기 위해 사용되는 auto differentiation의 한 기법

dsaint31.me


 

728x90