본문 바로가기
카테고리 없음

[DL] PyTorch: nn.ModuleList, nn.ModuleDict, nn.Sequential

by ds31x 2024. 5. 17.

nn.ModuleList, nn.ModuleDict, nn.Sequential 사용법 및 유의사항

nn.ModuleList, nn.ModuleDict, nn.Sequential는 모델의 block (=submodule) 구성에 매우 유용한 클래스들임.


nn.Sequential

nn.Sequential은 여러 모듈을 순차적으로 실행할 수 있도록 하는 PyTorch 클래스임.

간단한 네트워크 구조를 만들 때 유용하며, 모듈을 정의한 순서대로 순차적으로 적용함.


사용법:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 10)
)

print(model)

위 예제에서

nn.Sequentialnn.Linearnn.ReLU 모듈을 순차적으로 포함하고 있으며, 정의한 순서대로 순차적으로 적용함.

nn.Sequential을 사용하면 네트워크 구조를 직관적으로 정의할 수 있어 초보자에게 매우 유용함.

https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html

 

Sequential — PyTorch 2.3 documentation

Shortcuts

pytorch.org

 


nn.ModuleList

nn.ModuleList는 PyTorch의 nn.Module을 리스트처럼 다루기 위한 클래스임.

 

일반적인 Python 리스트와 유사하게 동작하지만,

리스트에 포함된 모든 모듈이 PyTorch의 모델 구성 요소로 인식되어 올바르게 등록되고 관리됨.

이는 모델 파라미터가 자동으로 추적되며, to(), cuda(), cpu()와 같은 메서드를 사용할 수 있게 함.


사용법:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 10)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = MyModel()
print(model)

위 예제에서

nn.ModuleList는 리스트 안에 nn.Linearnn.ReLU 모듈을 포함하고 있음.

forward 메서드에서는 리스트에 포함된 각 모듈을 순차적으로 적용함.

https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html

 

ModuleList — PyTorch 2.3 documentation

Shortcuts

pytorch.org

 


nn.ModuleDict

nn.ModuleDict는 PyTorch의 nn.Module을 딕셔너리처럼 다루기 위한 클래스임.

 

일반적인 Python 딕셔너리와 유사하게 동작하지만,

딕셔너리에 포함된 모든 모듈이 PyTorch의 모델 구성 요소로 인식되어 올바르게 등록되고 관리됨.

이는 모델 파라미터가 자동으로 추적되며, to(), cuda(), cpu()와 같은 메서드를 사용할 수 있게 함.


사용법:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(10, 20),
            'relu': nn.ReLU(),
            'fc2': nn.Linear(20, 10)
        })

    def forward(self, x):
        x = self.layers['fc1'](x)
        x = self.layers['relu'](x)
        x = self.layers['fc2'](x)
        return x

model = MyModel()
print(model)

위 예제에서

nn.ModuleDict는 키-값 쌍으로 nn.Linearnn.ReLU 모듈을 포함하고 있음.

forward 메서드에서는 딕셔너리의 키를 사용하여 각 모듈을 순차적으로 적용함.

https://pytorch.org/docs/stable/generated/torch.nn.ModuleDict.html

 

ModuleDict — PyTorch 2.3 documentation

Shortcuts

pytorch.org

 


일반 list나 dict를 사용할 때 발생하는 문제

일반 listdict

nn.ModuleListnn.ModuleDict 대신 사용하면 여러 가지 문제가 발생할 수 있음.

 

문제점은 list 또는 dict 내의 모듈들의
파라메터 자동 추적 이 이루어지지 않는 다는 점임.

 

PyTorch는

nn.ModuleListnn.ModuleDict에 포함된 모듈의 파라미터를 자동으로 추적함.

이를 통해 모델의 모든 파라미터가 자동으로 등록되고, parameters() 메서드을 통해 optim 등에서 쉽게 접근할 수 있음.

 

앞서의 예제 코드들에서 self.layers에 일반 list를 사용하면, 모델의 파라미터가 자동으로 추적되지 않음.

따라서 model.parameters()를 호출해도 파라미터를 얻을 수 없음.

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = [
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 10)
        ]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = MyModel()
print(model)
print("Model parameters:", list(model.parameters()))  # 파라미터가 추적되지 않음

 

이는 GPU/CPU 이동할 때에도 문제가 됨

nn.ModuleListnn.ModuleDict를 사용하면,

model.to('cuda')model.cuda()와 같은 메서드를 호출할 때 포함된 모든 모듈이 자동으로 GPU로 이동하지만,

listdict를 사용하면 이 동작이 자동으로 이루어지지 않음.


결론

  • nn.Sequential: 여러 모듈을 순차적으로 실행할 수 있도록 함. 간단한 네트워크 구조를 만들 때 유용함.
  • nn.ModuleList: PyTorch 모듈을 리스트 형태로 관리함. 반복문을 통해 순차적으로 모듈을 적용할 때 유용함.
  • nn.ModuleDict: PyTorch 모듈을 딕셔너리 형태로 관리함. 키를 사용하여 모듈을 접근할 때 유용함.

일반 listdict를 사용할 경우 모델의 파라미터가 자동으로 추적되지 않으며, GPU/CPU 이동이 자동으로 이루어지지 않음.

이는 parameters 메서드에서 확인이 안되기 때문임.

이러한 이유로 PyTorch 모델을 구성할 때는 일반 listdict가 아닌 nn.ModuleList, nn.ModuleDict, nn.Sequential를 사용해야 한다.

달리 애기하면 top-level attribute 로 sub-module을 추가하는 경우 외에는 nn.ModuleList, nn.ModuleDict, nn.Sequential를 사용해야 한다.


https://gist.github.com/dsaint31x/f34fcb64f0955a764592edd51628953d

 

dl_module_list_dict_sequential.ipynb

dl_module_list_dict_sequential.ipynb. GitHub Gist: instantly share code, notes, and snippets.

gist.github.com