
JAX는
- 수학적 함수(pure function)를 개발자가 작성하면,
- 이를 인자로 받아 자동 미분, JIT 컴파일, 벡터화·병렬화를 적용한 새로운 함수를 반환해 주는 function들을 제공.
- 이를 고성능 수치 계산을 가능하게 하는 함수 변환 라이브러리임ㅣ 딥러닝 프레임워크의 기반이 되어주는 라이브러리로 볼 수 있음.
JAX는
- 기존 딥러닝 프레임워크와 사고방식 자체가 다른 도구이며,
- 특히 이론·연구·수식 중심 작업에 매우 잘 맞는 도구로 알려짐.
1. JAX (Just After eXecution) 소개
JAX의 fully qualified name은 다음과 같음:
- Just After eXecution
- 이는 “코드를 실행한 직후에 필요한 변환(미분, 컴파일 등)을 수행한다”는 철학을 반영.
JAX는 Google에서 개발·유지되고 있으며,
특히 Google Research / DeepMind 계열 연구자들을 중심으로 발전하고 있음.
2. JAX는 무엇인가
JAX는 다음과 같이 정의할 수 있습니다.
- JAX는 NumPy와 유사한 문법을 사용하여 수치 계산을 작성하고,
- 그 계산을 자동 미분, JIT 컴파일, 벡터화, 병렬화된 “새로운 함수”로 변환해 주는 라이브러리 임.
중요한 점은 다음과 같음.
- JAX의 기본 입력은 Python 함수
- JAX의 출력 역시 Python 함수
- 계산 그래프는 사용자에게 노출되지 않음
JAX에서
- JIT는 “함수를 컴파일 가능한 형태로 감싸는 도구”이고,
- XLA는 그 함수를 실제 하드웨어에서 빠르게 실행하도록 최적화하는 컴파일러임: 내부적으로 IR로 사용되는 계산 그래프를 컴파일.
2024.03.21 - [Python] - [DL] PyTorch: TPU사용하기 - XLA
[DL] PyTorch: TPU사용하기 - XLA
https://github.com/pytorch/xla GitHub - pytorch/xla: Enabling PyTorch on XLA Devices (e.g. Google TPU)Enabling PyTorch on XLA Devices (e.g. Google TPU). Contribute to pytorch/xla development by creating an account on GitHub.github.com 다음 문서도 참
ds31x.tistory.com
2024.01.18 - [CE] - [CE] Compilation 의 종류
[CE] Compilation 의 종류
기존의 compilation원래 compilation(컴파일)은programming language로 작성된 소스코드(source code, 원시코드)를타겟 하드웨어에서 동작할 수 있는 기계 코드 (machine code, binary code, opcode)로 바꾸어주는 것을 의
ds31x.tistory.com
3. JAX와 계산 그래프의 관계
3.1 계산 그래프란 무엇인가
계산 그래프(computation graph)는
- 연산을 노드(node),
- 데이터 흐름을 엣지(edge)로 표현한 구조(DAG, Directed Acyclic Graph).
graph 와 Computational Graph에 대한 보다 자세한 자료
https://dsaint31.me/mkdocs_site/ML/ch08/datastructure_graph/
BME
Graph Network (object들의 연결관계)를 나타내는 자료구조의 일종으로 node(or vertex, 정점)와 edge(or connection)로 구성된다. Object(node로 표현됨)들의 관계 등을 연결시켜서 Network로 표현해주는 모델. 여러
dsaint31.me
https://dsaint31.me/mkdocs_site/ML/ch08/back_propagation/#computational-graph
BME
Back propagation (역전파, 오차 역전파) 딥러닝 모델을 학습시키기 위한 핵심 알고리즘. Back propagation은 다음 2가지를 조합하여 ANN을 학습시킴. "Reverse-mode AutoDiff" (Reverse-mode automatic differentiation) "Gradien
dsaint31.me
대부분의 딥러닝 프레임워크는
내부적으로 계산 그래프를 사용하고 있음.
3.2 JAX의 관점
JAX에서 계산 그래프는 다음과 같이 취급됩니다.
- 계산 그래프는 사용자가 직접 만드는 대상이 아니라,
- 함수를 변환·실행하는 과정에서 내부적으로 생성되는 중간 표현(IR)임.
- JAX는 계산 그래프를 직접 실행하지 않고,
- 해당 그래프의 중간 표현(IR)을 XLA 컴파일러에 넘겨 하드웨어별로 최적화된 기계 코드를 생성한 뒤 이를 실행.
예를 들어,
from jax import jit
import jax.numpy as jnp
def f(x):
return jnp.sin(x) + x**2
g = jit(f)
jit(f)의 반환값은 새로운 함수g- 이 시점에는 그래프도, 컴파일도 아직 수행되지 않음
y = g(x)
이 첫 호출 시점에 다음 과정이 이루어짐 (JIT).
- 함수
f를 추적(tracing) - 연산을 중간 표현(IR, Jaxpr) 으로 변환
- IR을 XLA를 통해 기계 코드로 컴파일
- 컴파일된 코드를 실행
- 결과 값을 반환
즉,
- 계산 그래프는 “함수 내부 구현”이며,
- 사용자는 오직 함수만을 다룬다.
이런 의미에서
- JAX는 전통적인 의미의 Define by Run (TF 1.x)도, Define and Run (PyTorch, TF 2.x)도 아니다.
- 개념적으로는 “Define by Transformation” 이라고 해야 함.
2024.03.28 - [Python] - [DL] Define and Run vs. Define by Run
[DL] Define and Run vs. Define by Run
Deep Learning (DL) Framework의 동작방식을 비교하는 용어. DL Model의 구축과 실행이 어떻게 이루어지는지로 구분됨. Define and Run DL Model을 구축 (= Computational Graph)이 먼저 이루어지고, 이후 input tesnsor를 정
ds31x.tistory.com
4. 객체 + 상태 중심 vs 함수 + 데이터 중심
4.1 TensorFlow / PyTorch의 전통적 방식
TensorFlow 2.x와 PyTorch는 다음과 같은 사고방식을 따른다.
- 모델은 객체(object)
- 모델의 파라미터와 상태는 객체 내부에 존재
- 연산은 객체의 메서드 호출로 수행
y = model(x)
이때 사용자는 모델 내부에 어떤 파라미터와 상태가 존재하는지를 명시적으로 전달하지 않으며,
연산 과정에서 객체 내부 상태가 암묵적으로 참조되거나 변경될 수 있다
4.2 JAX의 방식
JAX는 모델과 파라미터를 명확히 분리하여 표현한다.
y = f(params, x)
f: 입력과 파라미터를 받아 출력을 계산하는 순수 함수(pure function)params: 명시적으로 분리된 데이터로 모델의 파라미터를 다룸- 상태 변경은 새로운 데이터를 반환하는 방식으로 표현
JAX에서는
- 모델의 구조는 함수로,
- 모델의 상태는 데이터로 분리하여 다룬다.
즉, JAX는 “함수 + 데이터” 중심 시스템이다.
이는 다음의 수학적 표현에 대응됨:
$$
y = f(\theta, x)
$$
5. 자동 미분과 함수 변환
JAX의 자동 미분은 다음과 같이 이해할 수 있음:
from jax import grad
df = grad(f)
grad(f)의 결과는 미분을 계산하는 새로운 함수- 미분 과정에서 필요한 계산 그래프는 함수 호출 시 내부적으로 생성됨
따라서,
- JAX에서 미분은 “그래프를 만든다”가 아니라
- “미분된 함수를 생성”
PyTorch에서의 AutoGrad는 다음을 참고:
6. JAX의 위치: 딥러닝 프레임워크인가?
흔히, DL 프레임워크로 TensorFlow, PyTorch와 함께 JAX가 꼽히는 경우가 많음.
(물론 2개만 고를 때는 빠짐)
엄밀히 말하면 JAX는
- 딥러닝 프레임워크라기보다
- 미분 가능한 수치 계산을 위한 기반 시스템이다.
그래서 실제 신경망 학습에서는 보통 다음과 같은 라이브러리를 함께 사용한다.
- Flax / Haiku : 신경망 구조 정의
- Optax : 옵티마이저
- NumPyro : 확률 프로그래밍
7. 공식 튜토리얼 및 학습 자료
7.1 공식 사이트
JAX 공식 문서
JAX GitHub
https://github.com/google/jax
7.2 튜토리얼들
간략하게만 살펴봤을 뿐, 본인도 다 보진 못한 상태임...
JAX 101 / Quickstart
다음은 비공식 번역 정리된 곳임:
https://rtd-tutorial-ybeen.readthedocs.io/en/latest/JAX101/index.html?utm_source=chatgpt.com
Autodiff Cookbook (자동 미분 예제 모음)
https://kolonist26-jax-kr.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html?utm_source=chatgpt.com
8. JAX 의 주요 개념
하도 자주 나와서 정리만 해보고 어떤 녀석인지만 파악한 상태임...
8.1 NumPy 에서 JAX NumPy 로 확장:
jax.numpy사용- 배열은 immutable이라는 점 이해
8.2 grad : Gradient 함수
- 미분된 함수를 반환하는 함수
- Gradient를 계산할 수 있게 해주는 핵심 기능.
8.3 jit
- 실제 넘겨진 function이 언제 컴파일되는지
- 첫 호출과 이후 호출의 차이
8.4 vmap
- for-loop 없는 벡터화 사고 방식
8.5 간단한 모델 직접 작성
f(params, x)형태 유지- 옵티마이저는
Optax사용
8.6 Flax / Haiku로 확장
- 신경망 구조 추상화 학습
같이 보면 좋은 자료들
2025.02.28 - [CE] - [DL] GPU Acceleration 기술 소개
[DL] GPU Acceleration 기술 소개
1. 대표적 GPU 가속 기술현재 가장 많이 사용되고 있는 GPU기반의 가속기술은 다음과 같음:Compute Unified Device Architecture (CUDA): NVIDIAMetal Performance Shaders (MPS): AppleRadeon Open Compute (ROCm): AMDDirect Machine Learn
ds31x.tistory.com
2024.07.08 - [ML] - [ML] History of ML (작성중)
[ML] History of ML (작성중)
19401943 : ANN (Artificial Neural Network) 시작McCulloch, PittsA logical calculus of ideas immanent in nervous activity1949 : Weighting 변화를 통한 학습Donald Olding HebbThe Organization of Behavior: A Neuropsychological Theory19501956 : AI (Artif
ds31x.tistory.com
https://dsaint31.me/mkdocs_site/ML/ch08/reverse_mode_autodiff/
BME
Reverse-Mode Autodiff (Auto-Differentiation) Reverse-mode autodiff 는 TensorFlow, PyTorch 등에서 gradient를 구하는 back-propagation 수행에서 필요한 미분(differentiation)을 구하기 위해 사용되는 auto differentiation의 한 기법
dsaint31.me
'ML' 카테고리의 다른 글
| Deployment 가능한 HF Custom (Vision) Model 만들기 (0) | 2025.12.18 |
|---|---|
| torchvision.datasets.CocoDetection 간단 소개. (0) | 2025.12.16 |
| Object Detection 태스크에 대한 모델 평가를 COCO API로 하기 (0) | 2025.12.16 |
| pycocotools COCO API 기초 (0) | 2025.12.16 |
| MS COCO (Microsoft의 Common Object in Context) Dataset (1) | 2025.12.16 |