Custom Model 만들기
1. top-level attribute로 sub-module추가.
가장 간단하게 sub-module (layer or activation)을 추가하는 방법임.
주의할 것은 top-level attribute가 아닌, list나 dictionary 인스턴스 로 싸고 있는 형태로 추가될 경우, parameters를 통해 제대로 찾지 못하며, 이는 optim 등에서 해당 파라메터나 sub-module의 파라메터를 추적하지 못하는 문제점이나 GPU로 이동시킬 때 해당 파라메터를 제대로 이동시키지 못하는 문제를 일으키게 된다.
이에 대한 부분은 다음 URL을 참고할 것.
2024.05.17 - [분류 전체보기] - [DL] PyTorch: nn.ModuleList, nn.ModuleDict, nn.Sequential
다음은 간단한 MLP를 만드는 예제로 각 layer를 Module을 상속한 클래스에서 instance top-level attribute 로 추가하고 있음.
import torch
from torch.nn import Module, init, Linear, Parameter, ReLU
from torch import optim
class DsANN (Module): #custom module
def __init__(self,
n_in_f, # input vector의 차원수.
n_out_f, # output vector의 차원수.
):
super().__init__() # required!
self.linear0 = Linear(n_in_f, 32)
self.relu0 = ReLU()
self.linear1 = Linear(32, 32)
self.relu1 = ReLU()
self.linear2 = Linear(32, n_out_f)
with torch.no_grad():
# 이 블록 내에서의 연산은 자동 미분(autograd)에서 제외
# linear0의 bias를 상수 0. 으로 초기화
init.constant_(self.linear0.bias, 0.)
# Xavier 초기화 방식으로 초기화
init.xavier_uniform_(self.linear0.weight)
def forward(self,x):
x = self.linear0(x)
x = self.relu0(x)
x = self.linear1(x)
x = self.relu1(x)
y = self.linear2(x)
return y
Multi-Layer Perceptron을 간단히 구현함.
- 초기화 부분은 autograd(자동미분) 그래프에 추가될 필요가 없어서 명시적으로 추가되지 않음을 나타냄.
2. add_module
로 sub-module 추가.
앞서, instance variable로 처리한 경우와 결과는 같음.,
for문 등의 loop로 구현할 때 사용되는 경우가 많음.
class DsANN (Module): #custom module
def __init__(self,
n_in_f, # input vector의 차원수.
n_out_f, # output vector의 차원수.
):
super().__init__() # required!
self.add_module('linear0', Linear(n_in_f, 32))
self.add_module('relu0', ReLU())
self.add_module('linear1', Linear(32, 32))
self.add_module('relu1', ReLU())
self.add_module('linear2', Linear(32, n_out_f))
with torch.no_grad():
# Module의 apply는 특정 함수를 자신의 submodules에 모두 적용 (재귀적).
self.apply(self.init_weight)
def forward(self,x):
for c in self.children():
x=c(x)
return x
# x = self.linear0(x)
# x = self.relu0(x)
# x = self.linear1(x)
# x = self.relu1(x)
# x = self.linear2(x)
# return x
@classmethod
def init_weight(cls, module):
if type(module) == torch.nn.Linear:
init.kaiming_uniform_(module.weight, mode='fan_in', nonlinearity='relu')
# init.ones_(module.weight)
init.constant_(module.bias, 0)
model = DsANN(1,1)
print(model)
3. Module의 메서드들.
3-1. Module의 상태 확인하기.
.parameters(recurse=True)
optimizer에게 학습을 통해 갱신되어야하는 model의 parameters를 넘겨줄 때 사용됨.
각, 정보를 볼 때는 .named_parameter()
를 이용하면, 이름을 같이 확인할 수 있음.
.named_buffers()
buffers는 역시 tensor
객체이나,
학습과정에서 갱신이 필요한 parameters와 달리
학습과정에서 변하지 않는 데이터를 저장하는데 사용된다.
.named_buffers()
는 모델 내에 정의된 모든 버퍼를 이름과 함께 dictionary 형식으로 반환함.
.children()
현재 모델이 가지고 있는 직접적인 sub-moduels에 대한 iterator를 반환.
.named_children()
의 경우엔 이름과 함께 dictionary 형식으로 반환한다.
바로 아래의 자식에만 접근함.
.modules()
현재 모델이 가지고 있는 모든 sub-module에 대해 재귀적으로 반환함.
다음의 예를 확인할 것.
class DoubleLinear(Module):
def __init__(self, n_in, n_out):
super().__init__()
tmp = [(n_in, n_out), (n_out, n_out)]
for idx, t in enumerate(tmp):
self.add_module(f'linear{idx}', Linear(*t))
self.add_module(f'relu{idx}', ReLU())
def forward(self,x):
for c in self.children():
x=c(x)
return x
class DsANN (Module): #custom module
def __init__(self,
n_in_f, # input vector의 차원수.
n_out_f, # output vector의 차원수.
):
super().__init__() # required!
self.add_module('module1', DoubleLinear(n_in_f,32))
self.add_module('module2', Linear(32, n_out_f))
with torch.no_grad():
self.apply(self.init_weight)
def forward(self,x):
for c in self.children():
x=c(x)
return x
@classmethod
def init_weight(cls, module):
if type(module) == torch.nn.Linear:
init.kaiming_uniform_(module.weight, mode='fan_in', nonlinearity='relu')
# init.ones_(module.weight)
init.constant_(module.bias, 0)
model = DsANN(1,1)
print(model)
여기서 children의 경우는 다음 코드로 확인 가능함.
# for idx, cl in enumerate(model.named_children()):
for idx, cl in enumerate(model.children()):
print(idx, cl)
결과는 다음과 같음.
0 DoubleLinear(
(linear0): Linear(in_features=1, out_features=32, bias=True)
(relu0): ReLU()
(linear1): Linear(in_features=32, out_features=32, bias=True)
(relu1): ReLU()
)
1 Linear(in_features=32, out_features=1, bias=True)
modules의 경우는 다음 코드로 확인 가능함.
# for idx, modu in enumerate(model.named_modules()):
for idx, modu in enumerate(model.modules()):
print (idx, modu)
결과는 다음과 같음.
0 DsANN(
(module1): DoubleLinear(
(linear0): Linear(in_features=1, out_features=32, bias=True)
(relu0): ReLU()
(linear1): Linear(in_features=32, out_features=32, bias=True)
(relu1): ReLU()
)
(module2): Linear(in_features=32, out_features=1, bias=True)
)
1 DoubleLinear(
(linear0): Linear(in_features=1, out_features=32, bias=True)
(relu0): ReLU()
(linear1): Linear(in_features=32, out_features=32, bias=True)
(relu1): ReLU()
)
2 Linear(in_features=1, out_features=32, bias=True)
3 ReLU()
4 Linear(in_features=32, out_features=32, bias=True)
5 ReLU()
6 Linear(in_features=32, out_features=1, bias=True)
3-2. Model의 상태 저장 및 로딩.
2024.05.16 - [분류 전체보기] - [DL] Torch: Save and Load Model
2024.05.16 - [분류 전체보기] - [DL] PyTorch: state_dict()
.state_dict(desitnation=None, prefix='', keep_vars=False)
현재 모델의 모든 상태 (paramerters and buffers)를 얻어냄.
collections.OrderedDict 로 반환.
keep_vars
는 기본값이 False로 buffers와 parameters 의 값만을 추출할지를 결정keep_vars=True
인 경우, 값 대신 tensor객체로 데이터 버퍼를 가지고 있는 dictionary가 반환됨.keep_vars=True
인 경우, 메모리 사용량이 커지고, 매우 느리고 복잡한 동작이 이루어지지만. 다음의 장점을 가짐.- 모델 디버깅: 모델 상태를 조사하고 특정 매개 변수나 버퍼의 값을 변경해야 하는 경우 유용
- 모델 커스터마이징: 모델을 불러온 후 특정 매개 변수나 버퍼의 값을 변경해야 하는 경우 유용
- 모델 저장 및 불러오기 확장: 모델 저장 및 불러오기 프로세스를 확장하고 추가적인 정보를 저장해야 하는 경우 유용
- 하지만, PyTorch의 버전이 정확히 맞아야만 동작할 수 있는 등의 제한점을 가짐.
.load_state_dict(state_dict, strict=True)
현재 모델을 "모델상태를 저장한 state_dict
"를 이용하여 상태를 설정함.
strict=True
인 경우, 모든 attribute의 이름이 정확히 같아야만 복원됨 (attribute 이름이 key로 사용됨).
3-3. Device 지정.
.cuda(device=None)
현재의 모델을 gpu로 복사. (모델의 input tensor들도 gpu로 이동시키는 처리 필요.)
.cpu()
현재의 모델을 cpu로 복사. (모델의 input tensor들도 cpu로 이동시키는 처리 필요.)
https://gist.github.com/dsaint31x/ce49bfbfadc4f95b01684346bfbea76b
'Python' 카테고리의 다른 글
[Python] class 만들기. (0) | 2024.04.14 |
---|---|
[DL] Pandas 로 csv 읽기: read_csv (0) | 2024.04.13 |
[PyTorch] CustomANN Example: From Celsius to Fahrenheit (0) | 2024.04.12 |
[PyTorch] torch.nn.init (0) | 2024.04.11 |
[PyTorch] Dataset and DataLoader (0) | 2024.04.09 |