본문 바로가기
목차
ML

torch_xla 간단 사용법: TPU로 PyTorch 사용하기

by ds31x 2026. 6. 23.
728x90
반응형

PyTorch에서 TPU를 사용하려면 torch_xla를 사용함.

참고로 현재 torch-xla
최신 stable version은 2.9.0임.

 

2024.03.21 - [Python] - [DL] PyTorch: TPU사용하기 - XLA

 

[DL] PyTorch: TPU사용하기 - XLA

https://github.com/pytorch/xla GitHub - pytorch/xla: Enabling PyTorch on XLA Devices (e.g. Google TPU)Enabling PyTorch on XLA Devices (e.g. Google TPU). Contribute to pytorch/xla development by creating an account on GitHub.github.com 다음 문서도 참

ds31x.tistory.com

 

PyTorch/XLA 2.4 이상에서는 TPU device를 다음처럼 얻을 수 있음.

import torch_xla

device = torch_xla.device()
print("XLA device:", device)

 

이후에는 CUDA나 MPS를 사용할 때와 마찬가지로, 모델과 tensor를 해당 device로 보내어 사용하면 됨:

model = model.to(device)

x = x.to(device)
y = y.to(device)

 

단, 차이점이 없는 건 아님.

 

PyTorch/XLA는 CPU/GPU backend와 달리 lazy execution 방식을 사용함.

  • 즉, 연산이 호출되는 즉시 실행되는 것이 아니라,
  • XLA graph로 누적되었다가 synchronization 지점에서 실제 실행되는 방식임.

따라서 PyTorch/XLA 2.4 이상에서는 step boundary에서 다음을 사용함:

torch_xla.sync()

참고: step boundary

step boundary는 하나의 training step이 논리적으로 끝나고 다음 training step으로 넘어가는 경계 지점을 의미함.

예를 들어 일반적인 training loop에서는 다음 흐름이 하나의 step임.

loss = model(x)
loss.backward()
optimizer.step()
optimizer.zero_grad()
 

따라서 step boundary는 보통 optimizer.step()까지 끝난 뒤, 다음 batch 학습으로 넘어가기 직전의 지점임.


Single TPU device 사용 시

단일 TPU device를 사용하는 경우에는

  • torch_xla.device()로 XLA device를 얻고,
  • 모델과 tensor를 해당 device로 이동시킴.

PyTorch/XLA는 lazy execution 방식을 사용하므로,

  • 연산이 호출되는 즉시 실행되는 것이 아니라
  • XLA graph로 누적되었다가 synchronization 지점에서 실제 실행되는 방식임.

따라서 학습 loop의 step boundary에서 torch_xla.sync()를 호출함.

import torch
import torch_xla

def train_fn():
    # 1) XLA device 가져오기
    device = torch_xla.device()

    # 2) model을 XLA device로 이동
    model = build_model().to(device)

    optimizer = build_optimizer(model)
    criterion = build_criterion()

    # 3) 학습 loop
    for x, y in train_loader:
        # batch를 XLA device로 이동
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        pred = model(x)
        loss = criterion(pred, y)

        loss.backward()

        # 단일 TPU device에서는 일반 PyTorch처럼 optimizer.step()을 호출함.
        optimizer.step()

        # PyTorch/XLA는 lazy execution 방식이므로,
        # step boundary에서 누적된 XLA graph가 실행되도록 synchronization을 수행함.
        torch_xla.sync()

train_fn()

 

주의할 점은 다음과 같음:

  • torch_xla.device()는 실제로 사용할 XLA device를 반환함.
  • 모델과 입력 tensor는 같은 XLA device 위에 있어야 함.
  • PyTorch/XLA는 lazy execution 방식이므로 step boundary에서 torch_xla.sync()를 호출함.
  • 단일 TPU device에서는 optimizer.step() 뒤에 torch_xla.sync()를 호출하는 형태로 작성할 수 있음.

Multiple TPU device 사용 시

주의

현재 Colab TPU runtime에서는
torch_xla.launch(train_fn)가 여러 XLA replica를 초기화하려고 하면,
notebook에서 전달된 TPU worker address 수가 기대값과 맞지 않아
train_fn에 들어가기 전에 PJRT (Pluggable JIT Runtime) 초기화 단계에서 실패함 (2026.6).


이 경우 Colab notebook에서는 xmp.spawn(train_fn, nprocs=1)
단일 XLA replica 실행을 명시하거나,
multi-replica 학습은 torch_xla.launch()가 정상 동작하는 실행 환경에서 수행해야 함.

 

여러 TPU device를 사용하는 경우에는

  • 특정 device 하나를 직접 고르는 방식이 아니라,
  • XLA replica별로 학습 함수를 실행해야 함.

PyTorch/XLA 2.4 이상에서는 torch_xla.launch()를 사용하여 XLA replica별 학습 함수를 실행할 수 있음.

import torch_xla
import torch_xla.core.xla_model as xm

def train_fn(index):
    # 각 replica 내부에서 현재 replica에 할당된 XLA device를 얻음.
    # torch_xla.device()를 train_fn 밖에서 한 번 호출해 공유하면 안 됨.
    device = torch_xla.device()

    model = build_model().to(device)
    optimizer = build_optimizer(model)
    criterion = build_criterion()

    for x, y in train_loader:
        # 현재 replica의 XLA device로 batch 이동
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        pred = model(x)
        loss = criterion(pred, y)

        loss.backward()

        # multiple TPU device에서는 replica 간 gradient synchronization이 필요함.
        # 따라서 일반 optimizer.step()만 사용하지 않고,
        # XLA용 optimizer step을 사용함.
        #
        # barrier=True는 step boundary에서 XLA synchronization까지 수행하도록 함.
        xm.optimizer_step(optimizer, barrier=True)

if __name__ == "__main__":
    # torch_xla.launch()가 XLA replica별로 train_fn(index)를 실행함.
    # 이 호출 전에 torch_xla.device(), tensor.to("xla"), torch_xla.sync() 등
    # XLA runtime을 초기화하는 코드를 먼저 실행하면 안 됨.
    torch_xla.launch(train_fn)
  • optimizer.step()
    • 일반 PyTorch optimizer update만 수행함.
  • xm.optimizer_step(optimizer)
    • XLA 환경에서 optimizer update를 수행함.
    • multiple replica 환경에서는 gradient synchronization까지 처리함.
  • xm.optimizer_step(optimizer, barrier=True)
    • XLA optimizer update를 수행한 뒤,
    • step boundary에서 XLA graph가 실행되도록 synchronization까지 수행함.
    • barrier=Truexm.optimizer_step()에서 optimizer update 이후 모든 XLA device가 해당 step을 끝낼 때까지 동기화하여, 다음 연산으로 넘어가기 전에 step boundary를 명확히 만드는 옵션임.
일반적인 synchronization object로서의
barrier는
여러 thread/process가
특정 지점에 모두 도달할 때까지
각 실행 흐름을 대기시킨 뒤,
모두 도달하면 동시에 다음 단계로 진행시키는 동기화 장치임.

 

 

단, barrier=True가 항상 필요한 것은 아님.

더보기


MpDeviceLoaderParallelLoader처럼
batch loading 과정에서 step boundary를 만들어주는 XLA data loader를 사용하는 경우에는xm.optimizer_step(optimizer)처럼 barrier=True 없이 쓰는 예제도 많음.
PyTorch/XLA 문서에서도 multi-device 예제에서는 ParallelLoader가 barrier를 만들어주므로 xm.optimizer_step(optimizer)에 별도 barrier가 필요하지 않다고 설명함.

반대로 직접 for x, y in train_loader: 형태로 일반 PyTorch DataLoader를 사용하고 있다면,

step boundary를 명확히 하기 위해 xm.optimizer_step(optimizer, barrier=True)를 사용하는 편이 안전함.

PyTorch/XLA 문서에서 이를 통해  optimizer step에서 barrier를 넣으면 CPU와 XLA device를 명시적으로 synchronize한다고 설명함.

 

multiple TPU device 사용 시 주의할 점은 다음과 같음.

  • torch_xla.device()를 replica 밖에서 한 번 호출하고 이를 공유하면 안 됨.
  • 각 replica 내부에서 torch_xla.device()를 호출해야 해당 replica에 할당된 XLA device를 사용하게 됨.
  • multiple TPU device에서는 replica 간 gradient synchronization이 필요하므로 일반 optimizer.step()만 사용하면 안 됨.
  • xm.optimizer_step(optimizer, barrier=True)를 사용하여 gradient synchronization, optimizer update, XLA synchronization을 처리함.

다시한번 강조하지만, torch_xla.launch()를 호출하기 전에 XLA device를 미리 생성하거나 XLA tensor를 만들면 안 됨.


요약

torch-xla 2.4 이상에서는 TPU 사용 시 다음 API를 중심으로 기억하면 됨.

  • torch_xla.device()
    • 현재 실행 단위에서 사용할 XLA device를 얻음.
    • 단일 TPU에서는 일반적으로 한 번 얻어서 model과 tensor를 해당 device로 이동시킴.
    • multiple TPU에서는 각 replica 내부에서 호출해야 함.
  • torch_xla.sync()
    • PyTorch/XLA의 lazy execution으로 누적된 XLA graph를 실제 실행시키는 synchronization 지점임.
    • 단일 TPU 학습 loop에서는 보통 optimizer.step() 이후 step boundary에서 호출함.
    • 기존의 xm.mark_step()을 대체하는 최신 방식임.
  • torch_xla.launch()
    • multiple TPU device 사용 시 XLA replica별로 학습 함수를 실행하기 위한 API임.
    • torch_xla.launch() 호출 전에 torch_xla.device(), tensor.to("xla"), torch_xla.sync()처럼 XLA runtime을 초기화하는 코드를 먼저 실행하면 안 됨.
    • 현재 Colab TPU notebook 환경에서는 PJRT 초기화 단계에서 worker address mismatch가 발생할 수 있으므로, 이 경우 단일 replica 실행은 xmp.spawn(train_fn, nprocs=1)로 명시하고, multi-replica 학습은 torch_xla.launch()가 정상 동작하는 실행 환경에서 수행해야 함.

단일 TPU와 multiple TPU의 가장 큰 차이는 optimizer step 처리임 .

 

단일 TPU에서는 다음 구조를 사용.

optimizer.step()
torch_xla.sync()

 

단, multiple TPU에서는 replica 간 gradient synchronization이 필요하므로 다음 구조를 사용함:

xm.optimizer_step(optimizer, barrier=True)

 

정리하면,

  • TPU에서도 model과 tensor를 device로 보내는 방식은 CUDA/MPS와 유사하지만,
  • PyTorch/XLA는 lazy execution 기반이므로 step boundary를 명시해야 함.
  • 또한 multiple TPU에서는 각 replica가 자기 XLA device를 내부에서 얻고,
  • gradient synchronization을 포함한 XLA용 optimizer step을 사용해야 함.

같이 보면 좋은 자료들

https://dev-discuss.pytorch.org/t/pytorch-xla-2-4-dev-update/2356

 

PyTorch/XLA 2.4 dev update

PyTorch/XLA 2.4 Dev Update Hey I am here to give a late update for the PyTorch/XLA 2.3 release. Similar to my previous update, you can check our release note for detailed updates. I am going to highlight some of the new features and share how I think about

dev-discuss.pytorch.org

https://gist.github.com/dsaint31x/8290dab796f1dc2382ef7245abaa32d7#file-dl_torch_tpu_xla-ipynb

 

dl_torch_tpu_xla.ipynb

dl_torch_tpu_xla.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com

2024.03.21 - [Python] - [DL] PyTorch: TPU사용하기 - XLA

 

[DL] PyTorch: TPU사용하기 - XLA

https://github.com/pytorch/xla GitHub - pytorch/xla: Enabling PyTorch on XLA Devices (e.g. Google TPU)Enabling PyTorch on XLA Devices (e.g. Google TPU). Contribute to pytorch/xla development by creating an account on GitHub.github.com 다음 문서도 참

ds31x.tistory.com

https://dsaint31.me/mkdocs_site/CE/colab/gpu/

 

BME

Colab: GPU 사용하기 Colab에서는 주로 CUDA 기반의 GPU 가속을 지원 런타임 → 런타임 유형 변경 → 하드웨어 가속기를 GPU로 변경 유의사항 – GPU는 최대 12시간 실행을 지원 12시간 실행 이후에는 런타

dsaint31.me


 

728x90

'ML' 카테고리의 다른 글

[ML] Dataset: California Housing Dataset  (0) 2026.06.26
Autograd : In-place 연산  (0) 2026.06.23
2026 정리  (0) 2026.06.10
Hugging Face Trainer Callback과 JSONL 기반 Curve Logger  (0) 2026.06.04
[ML] linear_model.SGDRegressor  (0) 2026.04.28