본문 바로가기
Python

[PyTorch] flattenning - Tensor's methods

by ds31x 2025. 3. 13.

PyTorch에서 Multi-dimensional tensor 를 1차원(1D)으로 변환을 지원하는 여러 방법(주로 method.)이 있음.

  • 단순히 view를 제공하는지 아니면 복사본인지에 대한 이해가 필요함.

NumPy의 ndarray의 경우는 다음을 참고:

2024.09.09 - [Python] - [NumPy] ravel() 메서드 with flatten() 메서드

 

[NumPy] ravel() 메서드 with flatten() 메서드

NumPy의 ndarray.ravel() 메서드는다차원 배열을 1차원 배열로 평탄화(flatten)하는 데 사용됨.기본적으로 이 메서드는 원본 배열의 데이터에 대한 뷰(view)를 반환: 즉 복사본을 생성하지 않고 메모리를

ds31x.tistory.com


1. tensor.view(-1)

  • 가장 빠른 방법이지만 제약이 있음
  • tensor가 반드시 contiguous(연속적)해야 함
  • 원본 tensor와 메모리를 공유하므로 수정 시 원본도 변경됨
  • tensor.is_contiguous()로 연속성 확인 후 사용할 것: contiguous가 아니면 예외 발생.
In [167]: a = torch.arange(24).reshape(2,3,4)

In [168]: a_t = a.transpose(1,2)

In [169]: a.is_contiguous()
Out[169]: True

In [170]: a_t.is_contiguous()
Out[170]: False

In [171]: a.view(-1)
Out[171]:
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23])

In [172]: a_t.view(-1)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[172], line 1
----> 1 a_t.view(-1)

RuntimeError: view size is not compatible with input tensor's size and stride 
(at least one dimension spans across two contiguous subspaces). 
Use .reshape(...) instead.

2024.03.22 - [Python] - [DL] PyTorch: view, data, and detach

 

[DL] PyTorch: view, data, and detach

PyTorch에서 tensor.view()와 tensor.data를 제공하며 이들은 다음과 같은 용도로 사용됨. tensor.view() tensor.view(*shape) 메서드는 tensor 인스턴스의 dimension을 수정하는데 사용됨. 이 메서드는 새로운 shape를 가

ds31x.tistory.com


2. tensor.flatten()

  • 모든 차원을 1D로 펼침
  • 연속적인 텐서에서는 view를 반환(메모리 공유)
  • 비연속적 텐서는 자동으로 연속적 버전으로 변환하고 수행 (복사본 생성)
  • 부분 평탄화 지원함.: flatten(start_dim, end_dim)

NumPy의 ndarray에서는 flatten이
항상 복사본을 생성하기 때문에 자주 헷갈린다.
tensor.flatten().clone() 의 형태로 사용하면
NumPy의 경우와 항상 같음.

In [178]: a.flatten()[0]=99

In [179]: a
Out[179]:
tensor([[[99,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
  • 이 경우, a에 대한 view에 불과하므로 메모리를 공유하여 갱신이 이루어짐.
  • a_t도 해당 요소가 같이 갱신됨.

하지만, a_t에 대해 flatten을 수행하면 복사가 이루어지므로 갱신이 안된다.

  • transpose 는 contiguous가 아니기 때문임.
In [180]: a_t.flatten()[0] = 77

In [181]: a_t
Out[181]:
tensor([[[99,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]],

        [[12, 16, 20],
         [13, 17, 21],
         [14, 18, 22],
         [15, 19, 23]]])

 

부분 평탄화의 예는 다음과 같음

In [184]: a[0,0,0] = 0

In [185]: a
Out[185]:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [186]: a.flatten(start_dim=1) #부분 평탄화. index0 차원 이외의 차원이 flattening.
Out[186]:
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])

3. tensor.reshape(-1)

  • reshape는 사실 텐서의 shape를 변경하는 데에 많이 이용됨.
    • 여기서는 바꿀 차원을 -1 로 주어서 1D로 펼치도록 사용함.
  • reshape의 동작방식에 따라,
    • 가능하면 view 반환,
    • 불가능하면 복사본 생성
  • 즉, 비연속적 텐서에서도 사용 가능
  • 내부적으로 필요시 contiguous() 호출 (복사본 생성)
  • 부분 평탄화를 지원하지 않는다는 점을 제외하면 flatten과 같은 동작임.

2024.03.15 - [Python] - [DL] Tensor: dtype 변경(casting) 및 shape 변경.

 

[DL] Tensor: dtype 변경(casting) 및 shape 변경.

Tensor를 추상화하고 있는 class로는numpy.array: numpy의 ndarraytorch.tensortensorflow.constant: (or tensorflow.Variable)이 있음. 이들은 Python의 sequence types과 달리 일반적으로 다음과 같은 특징을 지님.데이터들이

ds31x.tistory.com

 


4. tensor.ravel()

  • NumPy 호환성을 위한 메서드
  • 비연속 텐서에서도 사용 가능.
  • 내부적으로 reshape(-1) 호출
  • 항상 완전 평탄화만 가능(부분 평탄화 불가)

5. tensor.contiguous().view(-1)

  • 비연속적 텐서를 명시적으로 연속적으로 만든 후 view 적용 (복사본에서 작업)
  • 동작이 예측 가능하여 디버깅이 쉬움

선택 가이드

  • 성능이 중요하고 텐서가 이미 연속적이면: view(-1)
  • 부분 평탄화가 필요하면: flatten(start_dim, end_dim)
  • 일반적인 상황에서 안전하게: reshape(-1)
  • 항상 독립적인 복사본이 필요하면: flatten().clone()

참고 - torch.nn.Flatten

Module로 ANN의 layer에서 평활화를 담당하는 torch.nn.Flatten 클래스는

  • 첫번째 차원을 batch로 삼기 때문에
  • 해당 차원을 남기고 나머지 차원에서 flatten이 수행됨.

다음의 예를 참고할것

In [191]: from torch.nn import Flatten

In [192]: a
Out[192]:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [193]: b = Flatten()(a)

In [194]: b
Out[194]:
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])

같이 보면 좋은 자료

2024.03.22 - [Python] - [DL] PyTorch: view, data, and detach

 

[DL] PyTorch: view, data, and detach

PyTorch에서 tensor.view()와 tensor.data를 제공하며 이들은 다음과 같은 용도로 사용됨. tensor.view() tensor.view(*shape) 메서드는 tensor 인스턴스의 dimension을 수정하는데 사용됨. 이 메서드는 새로운 shape를 가

ds31x.tistory.com

 

'Python' 카테고리의 다른 글

[PyTorch] dtype 단축메서드로 바꾸기  (0) 2025.03.14
[PyTorch] in-place 연산이란?  (0) 2025.03.13
[PyTorch] 생성 및 초기화, 기본 조작  (0) 2025.03.13
[Ex] PySide6  (0) 2025.03.11
[Py] dis 모듈 - Python  (1) 2025.03.11