
PyTorch에서 TPU를 사용하려면 torch_xla를 사용함.
참고로 현재
torch-xla의
최신 stable version은 2.9.0임.
2024.03.21 - [Python] - [DL] PyTorch: TPU사용하기 - XLA
[DL] PyTorch: TPU사용하기 - XLA
https://github.com/pytorch/xla GitHub - pytorch/xla: Enabling PyTorch on XLA Devices (e.g. Google TPU)Enabling PyTorch on XLA Devices (e.g. Google TPU). Contribute to pytorch/xla development by creating an account on GitHub.github.com 다음 문서도 참
ds31x.tistory.com
PyTorch/XLA 2.4 이상에서는 TPU device를 다음처럼 얻을 수 있음.
import torch_xla
device = torch_xla.device()
print("XLA device:", device)
이후에는 CUDA나 MPS를 사용할 때와 마찬가지로, 모델과 tensor를 해당 device로 보내어 사용하면 됨:
model = model.to(device)
x = x.to(device)
y = y.to(device)
단, 차이점이 없는 건 아님.
PyTorch/XLA는 CPU/GPU backend와 달리 lazy execution 방식을 사용함.
- 즉, 연산이 호출되는 즉시 실행되는 것이 아니라,
- XLA graph로 누적되었다가 synchronization 지점에서 실제 실행되는 방식임.
따라서 PyTorch/XLA 2.4 이상에서는 step boundary에서 다음을 사용함:
torch_xla.sync()
참고: step boundary
step boundary는 하나의 training step이 논리적으로 끝나고 다음 training step으로 넘어가는 경계 지점을 의미함.
예를 들어 일반적인 training loop에서는 다음 흐름이 하나의 step임.
loss = model(x)
loss.backward()
optimizer.step()
optimizer.zero_grad()
따라서 step boundary는 보통 optimizer.step()까지 끝난 뒤, 다음 batch 학습으로 넘어가기 직전의 지점임.
Single TPU device 사용 시
단일 TPU device를 사용하는 경우에는
torch_xla.device()로 XLA device를 얻고,- 모델과 tensor를 해당 device로 이동시킴.
PyTorch/XLA는 lazy execution 방식을 사용하므로,
- 연산이 호출되는 즉시 실행되는 것이 아니라
- XLA graph로 누적되었다가 synchronization 지점에서 실제 실행되는 방식임.
따라서 학습 loop의 step boundary에서 torch_xla.sync()를 호출함.
import torch
import torch_xla
def train_fn():
# 1) XLA device 가져오기
device = torch_xla.device()
# 2) model을 XLA device로 이동
model = build_model().to(device)
optimizer = build_optimizer(model)
criterion = build_criterion()
# 3) 학습 loop
for x, y in train_loader:
# batch를 XLA device로 이동
x = x.to(device)
y = y.to(device)
optimizer.zero_grad()
pred = model(x)
loss = criterion(pred, y)
loss.backward()
# 단일 TPU device에서는 일반 PyTorch처럼 optimizer.step()을 호출함.
optimizer.step()
# PyTorch/XLA는 lazy execution 방식이므로,
# step boundary에서 누적된 XLA graph가 실행되도록 synchronization을 수행함.
torch_xla.sync()
train_fn()
주의할 점은 다음과 같음:
torch_xla.device()는 실제로 사용할 XLA device를 반환함.- 모델과 입력 tensor는 같은 XLA device 위에 있어야 함.
- PyTorch/XLA는 lazy execution 방식이므로 step boundary에서
torch_xla.sync()를 호출함. - 단일 TPU device에서는
optimizer.step()뒤에torch_xla.sync()를 호출하는 형태로 작성할 수 있음.
Multiple TPU device 사용 시
주의
현재 Colab TPU runtime에서는
torch_xla.launch(train_fn)가 여러 XLA replica를 초기화하려고 하면,
notebook에서 전달된 TPU worker address 수가 기대값과 맞지 않아
train_fn에 들어가기 전에 PJRT (Pluggable JIT Runtime) 초기화 단계에서 실패함 (2026.6).
이 경우 Colab notebook에서는xmp.spawn(train_fn, nprocs=1)로
단일 XLA replica 실행을 명시하거나,
multi-replica 학습은 torch_xla.launch()가 정상 동작하는 실행 환경에서 수행해야 함.
여러 TPU device를 사용하는 경우에는
- 특정 device 하나를 직접 고르는 방식이 아니라,
- XLA replica별로 학습 함수를 실행해야 함.
PyTorch/XLA 2.4 이상에서는 torch_xla.launch()를 사용하여 XLA replica별 학습 함수를 실행할 수 있음.
import torch_xla
import torch_xla.core.xla_model as xm
def train_fn(index):
# 각 replica 내부에서 현재 replica에 할당된 XLA device를 얻음.
# torch_xla.device()를 train_fn 밖에서 한 번 호출해 공유하면 안 됨.
device = torch_xla.device()
model = build_model().to(device)
optimizer = build_optimizer(model)
criterion = build_criterion()
for x, y in train_loader:
# 현재 replica의 XLA device로 batch 이동
x = x.to(device)
y = y.to(device)
optimizer.zero_grad()
pred = model(x)
loss = criterion(pred, y)
loss.backward()
# multiple TPU device에서는 replica 간 gradient synchronization이 필요함.
# 따라서 일반 optimizer.step()만 사용하지 않고,
# XLA용 optimizer step을 사용함.
#
# barrier=True는 step boundary에서 XLA synchronization까지 수행하도록 함.
xm.optimizer_step(optimizer, barrier=True)
if __name__ == "__main__":
# torch_xla.launch()가 XLA replica별로 train_fn(index)를 실행함.
# 이 호출 전에 torch_xla.device(), tensor.to("xla"), torch_xla.sync() 등
# XLA runtime을 초기화하는 코드를 먼저 실행하면 안 됨.
torch_xla.launch(train_fn)
optimizer.step()- 일반 PyTorch optimizer update만 수행함.
xm.optimizer_step(optimizer)- XLA 환경에서 optimizer update를 수행함.
- multiple replica 환경에서는 gradient synchronization까지 처리함.
xm.optimizer_step(optimizer, barrier=True)- XLA optimizer update를 수행한 뒤,
- step boundary에서 XLA graph가 실행되도록 synchronization까지 수행함.
- barrier=True는 xm.optimizer_step()에서 optimizer update 이후 모든 XLA device가 해당 step을 끝낼 때까지 동기화하여, 다음 연산으로 넘어가기 전에 step boundary를 명확히 만드는 옵션임.
일반적인 synchronization object로서의
barrier는
여러 thread/process가
특정 지점에 모두 도달할 때까지
각 실행 흐름을 대기시킨 뒤,
모두 도달하면 동시에 다음 단계로 진행시키는 동기화 장치임.
단, barrier=True가 항상 필요한 것은 아님.
MpDeviceLoader나 ParallelLoader처럼
batch loading 과정에서 step boundary를 만들어주는 XLA data loader를 사용하는 경우에는xm.optimizer_step(optimizer)처럼 barrier=True 없이 쓰는 예제도 많음.
PyTorch/XLA 문서에서도 multi-device 예제에서는 ParallelLoader가 barrier를 만들어주므로 xm.optimizer_step(optimizer)에 별도 barrier가 필요하지 않다고 설명함.
반대로 직접 for x, y in train_loader: 형태로 일반 PyTorch DataLoader를 사용하고 있다면,
step boundary를 명확히 하기 위해 xm.optimizer_step(optimizer, barrier=True)를 사용하는 편이 안전함.
PyTorch/XLA 문서에서 이를 통해 optimizer step에서 barrier를 넣으면 CPU와 XLA device를 명시적으로 synchronize한다고 설명함.
multiple TPU device 사용 시 주의할 점은 다음과 같음.
torch_xla.device()를 replica 밖에서 한 번 호출하고 이를 공유하면 안 됨.- 각 replica 내부에서
torch_xla.device()를 호출해야 해당 replica에 할당된 XLA device를 사용하게 됨. - multiple TPU device에서는 replica 간 gradient synchronization이 필요하므로 일반
optimizer.step()만 사용하면 안 됨. xm.optimizer_step(optimizer, barrier=True)를 사용하여 gradient synchronization, optimizer update, XLA synchronization을 처리함.
다시한번 강조하지만, torch_xla.launch()를 호출하기 전에 XLA device를 미리 생성하거나 XLA tensor를 만들면 안 됨.
요약
torch-xla 2.4 이상에서는 TPU 사용 시 다음 API를 중심으로 기억하면 됨.
torch_xla.device()- 현재 실행 단위에서 사용할 XLA device를 얻음.
- 단일 TPU에서는 일반적으로 한 번 얻어서 model과 tensor를 해당 device로 이동시킴.
- multiple TPU에서는 각 replica 내부에서 호출해야 함.
torch_xla.sync()- PyTorch/XLA의 lazy execution으로 누적된 XLA graph를 실제 실행시키는 synchronization 지점임.
- 단일 TPU 학습 loop에서는 보통
optimizer.step()이후 step boundary에서 호출함. - 기존의
xm.mark_step()을 대체하는 최신 방식임.
torch_xla.launch()- multiple TPU device 사용 시 XLA replica별로 학습 함수를 실행하기 위한 API임.
torch_xla.launch()호출 전에torch_xla.device(),tensor.to("xla"),torch_xla.sync()처럼 XLA runtime을 초기화하는 코드를 먼저 실행하면 안 됨.- 현재 Colab TPU notebook 환경에서는 PJRT 초기화 단계에서 worker address mismatch가 발생할 수 있으므로, 이 경우 단일 replica 실행은
xmp.spawn(train_fn, nprocs=1)로 명시하고, multi-replica 학습은torch_xla.launch()가 정상 동작하는 실행 환경에서 수행해야 함.
단일 TPU와 multiple TPU의 가장 큰 차이는 optimizer step 처리임 .
단일 TPU에서는 다음 구조를 사용.
optimizer.step()
torch_xla.sync()
단, multiple TPU에서는 replica 간 gradient synchronization이 필요하므로 다음 구조를 사용함:
xm.optimizer_step(optimizer, barrier=True)
정리하면,
- TPU에서도 model과 tensor를 device로 보내는 방식은 CUDA/MPS와 유사하지만,
- PyTorch/XLA는 lazy execution 기반이므로 step boundary를 명시해야 함.
- 또한 multiple TPU에서는 각 replica가 자기 XLA device를 내부에서 얻고,
- gradient synchronization을 포함한 XLA용 optimizer step을 사용해야 함.
같이 보면 좋은 자료들
https://dev-discuss.pytorch.org/t/pytorch-xla-2-4-dev-update/2356
PyTorch/XLA 2.4 dev update
PyTorch/XLA 2.4 Dev Update Hey I am here to give a late update for the PyTorch/XLA 2.3 release. Similar to my previous update, you can check our release note for detailed updates. I am going to highlight some of the new features and share how I think about
dev-discuss.pytorch.org
https://gist.github.com/dsaint31x/8290dab796f1dc2382ef7245abaa32d7#file-dl_torch_tpu_xla-ipynb
dl_torch_tpu_xla.ipynb
dl_torch_tpu_xla.ipynb. GitHub Gist: instantly share code, notes, and snippets.
gist.github.com
2024.03.21 - [Python] - [DL] PyTorch: TPU사용하기 - XLA
[DL] PyTorch: TPU사용하기 - XLA
https://github.com/pytorch/xla GitHub - pytorch/xla: Enabling PyTorch on XLA Devices (e.g. Google TPU)Enabling PyTorch on XLA Devices (e.g. Google TPU). Contribute to pytorch/xla development by creating an account on GitHub.github.com 다음 문서도 참
ds31x.tistory.com
https://dsaint31.me/mkdocs_site/CE/colab/gpu/
BME
Colab: GPU 사용하기 Colab에서는 주로 CUDA 기반의 GPU 가속을 지원 런타임 → 런타임 유형 변경 → 하드웨어 가속기를 GPU로 변경 유의사항 – GPU는 최대 12시간 실행을 지원 12시간 실행 이후에는 런타
dsaint31.me
'ML' 카테고리의 다른 글
| [ML] Dataset: California Housing Dataset (0) | 2026.06.26 |
|---|---|
| Autograd : In-place 연산 (0) | 2026.06.23 |
| 2026 정리 (0) | 2026.06.10 |
| Hugging Face Trainer Callback과 JSONL 기반 Curve Logger (0) | 2026.06.04 |
| [ML] linear_model.SGDRegressor (0) | 2026.04.28 |