본문 바로가기
목차
ML

pytorch-torchinfo 란

by ds31x 2026. 4. 9.
728x90
반응형

1. torchinfo 란?

torchinfo는 PyTorch 모델의 구조를 표 형태로 요약해서 보여주는 도구 (formerly torch-summary).

 

주로 다음을 확인할 때 사용됨:

  • 각 layer의 출력 shape
  • parameter 수
  • trainable 여부
  • nested module 구조
  • 실제 forward()를 따라가며 형상이 어떻게 변하는지

공식 URL은 다음임:
https://github.com/tyleryep/torchinfo

 

GitHub - TylerYep/torchinfo: View model summaries in PyTorch!

View model summaries in PyTorch! Contribute to TylerYep/torchinfo development by creating an account on GitHub.

github.com


2. 기본 사용법

colab 등에서 사용하려면 설치가 필요함:

pip install torchinfo

 

설치 후 다음과 같이 사용함:

from torchinfo import summary
# summary(model, input_size=(배치크기, 채널수, 높이, 너비))
t = summary(model, input_size=(1, 3, 224, 224))
print(t)
  • input_size는 실제 forward()에 들어갈 텐서의 shape를 의미.
  • torchinfo는 이 정보를 바탕으로 내부적으로 모델을 한 번 forward()를 실행하여 각 계층의 출력 shape와 parameter 수를 계산함.
  • 제대로 된 각 계층의 shape와 parameter 수를 계산하기 위해선 반드시 input_size를 넘겨줘야 함.

torchinfo
이 모델이 실제 입력을 받았을 때
각 module이 어떤 출력 shape를 만드는가
를 빠르게 점검하는 도구임.


3. 기본 예제

간단한 MLP의 구조를 확인하는 예제임:

import torch
import torch.nn as nn
from torchinfo import summary

class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 4)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = TinyMLP()

# jupyter note의 경우 마지막 expression 은 출력되니
# 그냥 summary호출이 마지막에 놓아도 출력이 보임.
summary(
    model,
    input_size=(8, 16),
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

 

결과는 다음과 같음:

  • 입력 (8, 16)은 batch size가 8, feature 수가 16이라는 뜻
  • fc1 출력은 (8, 32)
  • fc2 출력은 (8, 4)
  • Linear의 parameter 수가 예상과 맞는지 확인 가능

torchinfo가 단순히 모델 정의를 보여주는 것이 아니라, 입력이 통과한 뒤의 shape 변화를 보여준다 는 점을 유의할 것.


다음은 간단한 CNN 의 예제임:

import torch
import torch.nn as nn
from torchinfo import summary

class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),   # (B, 8, 32, 32)
            nn.ReLU(),
            nn.MaxPool2d(2),                             # (B, 8, 16, 16)
            nn.Conv2d(8, 16, kernel_size=3, padding=1),  # (B, 16, 16, 16)
            nn.ReLU(),
            nn.MaxPool2d(2)                              # (B, 16, 8, 8)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 8 * 8, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

model = TinyCNN()

t = summary(
    model,
    input_size=(4, 3, 32, 32),
    depth=3,
    col_names=["output_size", "num_params","kernel_size"],
)
print(t)

  • Conv2d는 주로 채널 수를 바꾸는 역할
  • MaxPool2d는 공간 크기 32 -> 16 -> 8로 줄이는 역할
  • Flatten 전 최종 크기가 16 x 8 x 8인지 확인 가능
  • 마지막 Linear 입력 차원이 맞는지 확인 가능

CNN에서는
특히 pooling 횟수나 padding 설정을 잘못해서 마지막 Linear 입력 차원을 틀리는 경우가 많은데
이를 체크하는데 torchinfo가 매우 유용함.


4. summary 함수에서 자주 사용되는 옵션

4-1. depth

nested module을 얼마나 깊게 펼칠지 결정.

summary(model, input_size=(2, 3, 32, 32), depth=5)
  • 작게 주면 상위 block만 표시
  • 크게 주면 block 내부의 모듈(=nested module)까지 표시

4-2. col_names

출력할 column들을 지정.

summary(
    model,
    input_size=(2, 3, 32, 32),
    col_names=["output_size", "num_params", "kernel_size", "mult_adds"]
)

 

현재 지원되는 column은 다음과 같음:

  • "input_size" : 해당 layer로 들어가는 입력 tensor shape 표시.
  • "output_size" : 해당 layer의 출력 tensor shape 표시. 기본적으로 가장 자주 보는 항목.
  • "num_params" : 해당 layer의 parameter 개수 표시.
  • "params_percent" : 전체 parameter 중 해당 layer가 차지하는 비율 표시.
  • "kernel_size" : convolution 계열 layer 등의 kernel 크기 표시. 해당 정보가 없는 layer는 --처럼 보일 수 있음.
  • "groups" : grouped convolution 등에서의 group 수 표시. 일반 layer에서는 의미가 없을 수 있음.
  • "mult_adds" : 대략적인 multiply-add 연산량 표시. 연산량 비교할 때 유용함.
  • "trainable" : 해당 layer의 parameter가 학습 대상인지 표시. freeze 여부 확인에 유용함.

4-3. row_settings

변수명이나 depth 표현 방식을 더 잘 보이게 할 때 사용.

summary(
    model,
    input_size=(2, 3, 32, 32),
    depth=5,
    row_settings=["depth", "var_names"],
)
  • var_namesrow_settings에 추가하면
  • 해당 module 을 가리키는 변수명이 같이 출력되어 가독성이 향상가능함(변수명을 잘 지은 경우 한정임)

현재 지원되는 항목은 다음과 같음:

  • "ascii_only"
    • 트리 표시 문자를 유니코드 대신 ASCII 문자로 출력함.
    • 터미널이나 폰트 환경에 따라 트리 선이 깨질 때 유용함.
  • "depth"
    • 각 row에 depth 기반 계층 표시를 포함함.
    • 어떤 module이 상위 module 아래에 속하는지 트리 형태로 읽기 쉽게 해 줌.
    • 기본값.
  • "var_names"
    • module의 변수명까지 함께 보여줌.
    • 예를 들어 block1.conv1, block2.bn2 처럼 코드에서 붙인 이름을 읽기 쉽게 확인할 때 유용함.


5. torchinfo 에서 nested layer와 branch 구조 확인.

앞서 예제에서 확인했듯이
torchinfo는 model 안에 들어 있는 nested module 구조를 펼쳐서 보여주는 기능을 제공함.

 

예를 들어 다음과 같은 구조가 있으면

  • SmallResNet
    • stem
    • block1
      • conv1
      • bn1
      • relu
      • conv2
      • bn2
    • block2
      • conv1
      • bn1
      • relu
      • conv2
      • bn2

이 내부 구조를 depth 옵션에 따라 단계적으로 펼쳐서 볼 수 있음.

 

즉, 다음과 같은 확인이 가능합니다.

  • 어떤 상위 block 안에 어떤 하위 module이 들어 있는지
  • 각 하위 module이 어떤 output shape를 만드는지
  • custom block 내부가 의도대로 구성되어 있는지

다음은 torchinfo가 nested module 구조를 어떻게 펼쳐서 보여주는지 확인하는 예제임.

import torch
import torch.nn as nn
from torchinfo import summary

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out = out + identity
        out = self.relu(out)
        return out

class SmallResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.block1 = ResidualBlock(16)
        self.block2 = ResidualBlock(16)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.stem(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

model = SmallResNet()

print("=== depth=1 ===")
t = summary(
    model,
    input_size=(2, 3, 32, 32),
    depth=1,
    col_names=["output_size", "num_params"],
    row_settings=["depth","var_names"],
)
print(t)

print("\n=== depth=2 ===")
t = summary(
    model,
    input_size=(2, 3, 32, 32),
    depth=2,
    col_names=["output_size", "num_params"],
    row_settings=["depth","var_names"],
)
print(t)
)

위의 예제에서는 depth=1depth=2의 출력 차이를 통해 nested 구조 확인을 depth로 어떻게 조절할지를 보여줌:

  • depth=1에서는 stem, block1, block2, pool, fc 같은 상위 module 위주로 보임
  • depth=2에서는 block1, block2 내부의 conv1, bn1, relu, conv2, bn2까지 펼쳐져 보임

주의-torchinfo 는 계산 그래프 자체를 그리진 못함

주의할 점은
torchinfo는 계산 그래프를 화살표로 직접 그려주는 도구가 아니라는 점임.

 

다음은 torchinfo가 수행할 수 없음:

  • x + identity의 merge 지점을 그래프 화살표로 보여주기
  • torch.cat, +, * 같은 텐서 연산을 노드 수준으로 시각화하기
  • 전체 dataflow를 DAG처럼 그림으로 그리기

 

torchinfo가 제공하는 기능:

  • module 계층 구조 표시
  • nested block 내부 표시
  • 각 module의 output shape 표시
  • residual addition이 가능하도록 두 경로의 shape가 맞는지 점검

 

torchinfo가 직접 제공하지 않는 기능

  • 실제 계산 그래프 시각화
  • merge/add/cat 연산의 선 연결 그림
  • branch 간 텐서 흐름의 화살표 표시

때문에 앞서의 ResidualBlock 내의 residual connection 등은 확인이 어려움.

https://dsaint31.me/mkdocs_site/ML/ch14_cnn/resnet/#residual-block

 

BME

ResNet: Deep Residual Learning for Image Recognition (2015) ref. : ori. Deep Residual Learning for Image Recognition ILSVRC 2015년 우승 모델 (deep learning의 deep의 개념을 바꿈.) Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun top-5 error가

dsaint31.me

 

실제로 summary(...) 결과를 보면

  • stem, block1, block2, pool, fc 같은 nn.Module 단위 구조는 보이지만,
  • ResidualBlock.forward() 안에서 수행되는 다음 연산은 표에 직접 드러나지 않음.
identity = x
...
out = out + identity

이는 torchinfo가 다음과 같은 정보를 보여주는 도구이기 때문임:

  • 어떤 nn.Module이 호출되었는지
  • nn.Module의 출력 shape가 무엇인지
  • nn.Module의 parameter 수가 얼마인지

때문에 다음과 같은 정보는 직접 표시되지 않음:

  • out + identity 같은 텐서 간 덧셈 연산
  • torch.cat(...) 같은 병합 연산
  • branch가 갈라졌다가 다시 합쳐지는 계산 그래프의 화살표 연결

torchinfo
"연산 그래프 시각화 도구"라기보다
"module 구조 요약 도구" 임.


같이 보면 좋은 자료들

https://gist.github.com/ds31x/0ac5d0383eeca2218ac82e0ea2a884c6

 

dl_torchinfo_tutorial.ipynb

dl_torchinfo_tutorial.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com

https://dsaint31.me/mkdocs_site/ML/ch14_cnn/resnet/#residual-block

 

BME

ResNet: Deep Residual Learning for Image Recognition (2015) ref. : ori. Deep Residual Learning for Image Recognition ILSVRC 2015년 우승 모델 (deep learning의 deep의 개념을 바꿈.) Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun top-5 error가

dsaint31.me


728x90