PyTorch에서 Multi-dimensional tensor 를 1차원(1D)으로 변환을 지원하는 여러 방법(주로 method.)이 있음.
- 단순히 view를 제공하는지 아니면 복사본인지에 대한 이해가 필요함.
NumPy의 ndarray의 경우는 다음을 참고:
2024.09.09 - [Python] - [NumPy] ravel() 메서드 with flatten() 메서드
[NumPy] ravel() 메서드 with flatten() 메서드
NumPy의 ndarray.ravel() 메서드는다차원 배열을 1차원 배열로 평탄화(flatten)하는 데 사용됨.기본적으로 이 메서드는 원본 배열의 데이터에 대한 뷰(view)를 반환: 즉 복사본을 생성하지 않고 메모리를
ds31x.tistory.com
1. tensor.view(-1)
- 가장 빠른 방법이지만 제약이 있음
- tensor가 반드시 contiguous(연속적)해야 함
- 원본 tensor와 메모리를 공유하므로 수정 시 원본도 변경됨
tensor.is_contiguous()
로 연속성 확인 후 사용할 것: contiguous가 아니면 예외 발생.
In [167]: a = torch.arange(24).reshape(2,3,4)
In [168]: a_t = a.transpose(1,2)
In [169]: a.is_contiguous()
Out[169]: True
In [170]: a_t.is_contiguous()
Out[170]: False
In [171]: a.view(-1)
Out[171]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23])
In [172]: a_t.view(-1)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[172], line 1
----> 1 a_t.view(-1)
RuntimeError: view size is not compatible with input tensor's size and stride
(at least one dimension spans across two contiguous subspaces).
Use .reshape(...) instead.
2024.03.22 - [Python] - [DL] PyTorch: view, data, and detach
[DL] PyTorch: view, data, and detach
PyTorch에서 tensor.view()와 tensor.data를 제공하며 이들은 다음과 같은 용도로 사용됨. tensor.view() tensor.view(*shape) 메서드는 tensor 인스턴스의 dimension을 수정하는데 사용됨. 이 메서드는 새로운 shape를 가
ds31x.tistory.com
2. tensor.flatten()
- 모든 차원을 1D로 펼침
- 연속적인 텐서에서는 view를 반환(메모리 공유)
- 비연속적 텐서는 자동으로 연속적 버전으로 변환하고 수행 (복사본 생성)
- 부분 평탄화 지원함.:
flatten(start_dim, end_dim)
NumPy의 ndarray에서는 flatten이
항상 복사본을 생성하기 때문에 자주 헷갈린다.tensor.flatten().clone()
의 형태로 사용하면
NumPy의 경우와 항상 같음.
In [178]: a.flatten()[0]=99
In [179]: a
Out[179]:
tensor([[[99, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
- 이 경우,
a
에 대한 view에 불과하므로 메모리를 공유하여 갱신이 이루어짐. a_t
도 해당 요소가 같이 갱신됨.
하지만, a_t
에 대해 flatten을 수행하면 복사가 이루어지므로 갱신이 안된다.
- transpose 는 contiguous가 아니기 때문임.
In [180]: a_t.flatten()[0] = 77
In [181]: a_t
Out[181]:
tensor([[[99, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]],
[[12, 16, 20],
[13, 17, 21],
[14, 18, 22],
[15, 19, 23]]])
부분 평탄화의 예는 다음과 같음
In [184]: a[0,0,0] = 0
In [185]: a
Out[185]:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
In [186]: a.flatten(start_dim=1) #부분 평탄화. index0 차원 이외의 차원이 flattening.
Out[186]:
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])
3. tensor.reshape(-1)
reshape
는 사실 텐서의 shape를 변경하는 데에 많이 이용됨.- 여기서는 바꿀 차원을
-1
로 주어서 1D로 펼치도록 사용함.
- 여기서는 바꿀 차원을
- reshape의 동작방식에 따라,
- 가능하면 view 반환,
- 불가능하면 복사본 생성
- 즉, 비연속적 텐서에서도 사용 가능
- 내부적으로 필요시
contiguous()
호출 (복사본 생성) - 부분 평탄화를 지원하지 않는다는 점을 제외하면
flatten
과 같은 동작임.
2024.03.15 - [Python] - [DL] Tensor: dtype 변경(casting) 및 shape 변경.
[DL] Tensor: dtype 변경(casting) 및 shape 변경.
Tensor를 추상화하고 있는 class로는numpy.array: numpy의 ndarraytorch.tensortensorflow.constant: (or tensorflow.Variable)이 있음. 이들은 Python의 sequence types과 달리 일반적으로 다음과 같은 특징을 지님.데이터들이
ds31x.tistory.com
4. tensor.ravel()
- NumPy 호환성을 위한 메서드
- 비연속 텐서에서도 사용 가능.
- 내부적으로
reshape(-1)
호출 - 항상 완전 평탄화만 가능(부분 평탄화 불가)
5. tensor.contiguous().view(-1)
- 비연속적 텐서를 명시적으로 연속적으로 만든 후 view 적용 (복사본에서 작업)
- 동작이 예측 가능하여 디버깅이 쉬움
선택 가이드
- 성능이 중요하고 텐서가 이미 연속적이면:
view(-1)
- 부분 평탄화가 필요하면:
flatten(start_dim, end_dim)
- 일반적인 상황에서 안전하게:
reshape(-1)
- 항상 독립적인 복사본이 필요하면:
flatten().clone()
참고 - torch.nn.Flatten
Module로 ANN의 layer에서 평활화를 담당하는 torch.nn.Flatten 클래스는
- 첫번째 차원을 batch로 삼기 때문에
- 해당 차원을 남기고 나머지 차원에서 flatten이 수행됨.
다음의 예를 참고할것
In [191]: from torch.nn import Flatten
In [192]: a
Out[192]:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
In [193]: b = Flatten()(a)
In [194]: b
Out[194]:
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])
같이 보면 좋은 자료
2024.03.22 - [Python] - [DL] PyTorch: view, data, and detach
[DL] PyTorch: view, data, and detach
PyTorch에서 tensor.view()와 tensor.data를 제공하며 이들은 다음과 같은 용도로 사용됨. tensor.view() tensor.view(*shape) 메서드는 tensor 인스턴스의 dimension을 수정하는데 사용됨. 이 메서드는 새로운 shape를 가
ds31x.tistory.com
'Python' 카테고리의 다른 글
[PyTorch] dtype 단축메서드로 바꾸기 (0) | 2025.03.14 |
---|---|
[PyTorch] in-place 연산이란? (0) | 2025.03.13 |
[PyTorch] 생성 및 초기화, 기본 조작 (0) | 2025.03.13 |
[Ex] PySide6 (0) | 2025.03.11 |
[Py] dis 모듈 - Python (1) | 2025.03.11 |