# PyTorch stack 메서드 완벽 가이드
## 1. torch.stack()이란?
`torch.stack()`은 여러 개의 텐서를 **새로운 차원**으로 쌓아서 하나의 텐서로 만드는 메서드입니다.
### 핵심 개념
- **입력**: 같은 크기의 텐서들의 리스트
- **출력**: 차원이 하나 증가한 텐서
- **특징**: 새로운 차원을 추가하면서 결합
## 2. 기본 예시로 이해하기
### 2.1 1차원 텐서 쌓기
```python
import torch
# 1차원 텐서 3개
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.tensor([7, 8, 9])
# stack으로 쌓기
result = torch.stack([a, b, c])
print(result)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
print(f"입력 shape: {a.shape}") # torch.Size([3])
print(f"출력 shape: {result.shape}") # torch.Size([3, 3])
```
### 2.2 시각적 이해
```
원본 텐서들:
a = [1, 2, 3]
b = [4, 5, 6]
c = [7, 8, 9]
stack 후:
[[1, 2, 3], ← a
[4, 5, 6], ← b
[7, 8, 9]] ← c
```
## 3. 배치 생성 함수에서의 stack 상세 분석
### 3.1 실제 데이터로 단계별 설명
```python
# 가정: dataset은 긴 텍스트의 토큰 ID 리스트
dataset = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100,
110, 120, 130, 140, 150, ..., 1000]
# 랜덤 인덱스와 block_size
idx = [234, 567, 123, 890]
block_size = 8
```
### 3.2 각 슬라이싱 결과
```python
# 첫 번째 슬라이싱
dataset[234:242] = dataset[234:234+8]
# 결과: [2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347]
# 두 번째 슬라이싱
dataset[567:575] = dataset[567:567+8]
# 결과: [5670, 5671, 5672, 5673, 5674, 5675, 5676, 5677]
# 세 번째 슬라이싱
dataset[123:131] = dataset[123:123+8]
# 결과: [1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237]
# 네 번째 슬라이싱
dataset[890:898] = dataset[890:890+8]
# 결과: [8900, 8901, 8902, 8903, 8904, 8905, 8906, 8907]
```
### 3.3 리스트 컴프리헨션 풀어서 보기
```python
# 원래 코드
x = torch.stack([dataset[index:index+block_size] for index in idx])
# 풀어서 쓰면
슬라이스_리스트 = []
for index in idx: # idx = [234, 567, 123, 890]
슬라이스 = dataset[index:index+block_size]
슬라이스_리스트.append(슬라이스)
x = torch.stack(슬라이스_리스트)
```
### 3.4 stack 동작 과정
```python
# stack 전: 4개의 1차원 텐서 (각각 길이 8)
슬라이스_리스트 = [
[2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347], # shape: (8,)
[5670, 5671, 5672, 5673, 5674, 5675, 5676, 5677], # shape: (8,)
[1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237], # shape: (8,)
[8900, 8901, 8902, 8903, 8904, 8905, 8906, 8907] # shape: (8,)
]
# stack 후: 2차원 텐서
x = torch.stack(슬라이스_리스트)
# 결과 shape: (4, 8)
# [[2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347],
# [5670, 5671, 5672, 5673, 5674, 5675, 5676, 5677],
# [1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237],
# [8900, 8901, 8902, 8903, 8904, 8905, 8906, 8907]]
```
## 4. stack vs cat 비교
### 4.1 torch.stack - 새 차원 추가
```python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# stack: 새로운 차원 추가
stacked = torch.stack([a, b])
print(stacked)
# tensor([[1, 2, 3],
# [4, 5, 6]])
print(stacked.shape) # torch.Size([2, 3])
```
### 4.2 torch.cat - 기존 차원에서 연결
```python
# cat: 기존 차원에서 이어붙이기
catted = torch.cat([a, b])
print(catted)
# tensor([1, 2, 3, 4, 5, 6])
print(catted.shape) # torch.Size([6])
```
### 4.3 언제 무엇을 사용?
- **stack**: 여러 샘플을 배치로 만들 때
- **cat**: 시퀀스를 이어붙일 때
## 5. dim 파라미터 활용
### 5.1 다양한 차원에서 stack
```python
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
# dim=0 (기본값): 첫 번째 차원에 쌓기
stack0 = torch.stack([a, b], dim=0)
print(stack0.shape) # torch.Size([2, 2, 2])
# [[[1, 2], [3, 4]],
# [[5, 6], [7, 8]]]
# dim=1: 두 번째 차원에 쌓기
stack1 = torch.stack([a, b], dim=1)
print(stack1.shape) # torch.Size([2, 2, 2])
# [[[1, 2], [5, 6]],
# [[3, 4], [7, 8]]]
# dim=2: 세 번째 차원에 쌓기
stack2 = torch.stack([a, b], dim=2)
print(stack2.shape) # torch.Size([2, 2, 2])
# [[[1, 5], [2, 6]],
# [[3, 7], [4, 8]]]
```
## 6. 배치 처리에서 stack의 중요성
### 6.1 왜 stack을 사용하는가?
```python
# 개별 처리 (비효율적)
for i in range(4):
output = model(dataset[idx[i]:idx[i]+8])
# 4번의 개별 연산
# 배치 처리 (효율적)
batch = torch.stack([dataset[i:i+8] for i in idx])
output = model(batch) # 한 번의 연산으로 4개 처리!
```
### 6.2 GPU 병렬 처리 활용
```python
# GPU는 행렬 연산에 최적화
# (4, 8) 크기의 배치를 한 번에 처리
# 개별 처리보다 약 4배 빠름
```
## 7. 실전 팁과 주의사항
### 7.1 크기가 같아야 함
```python
# 에러 발생 - 크기가 다름
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5]) # 크기 다름!
# torch.stack([a, b]) # RuntimeError!
```
### 7.2 메모리 효율적 사용
```python
# 큰 데이터셋에서는 generator 사용
def batch_generator(dataset, indices, block_size):
for idx in indices:
yield dataset[idx:idx+block_size]
# 메모리 효율적
batch = torch.stack(list(batch_generator(dataset, idx, block_size)))
```
## 8. 요약
`torch.stack()`의 핵심:
1. **새로운 차원 추가**: (8,) → (4, 8)
2. **배치 생성**: 여러 샘플을 하나로 묶기
3. **GPU 효율성**: 병렬 처리로 속도 향상
4. **크기 일치 필요**: 모든 텐서가 같은 shape여야 함
배치 생성 함수에서는 여러 위치의 텍스트 조각을 하나의 배치로 만들어 효율적으로 학습할 수 있게 합니다.
## 관련 참고 자료
- [[GPT 배치 생성 함수 완벽 가이드]]
- [[PyTorch 텐서와 torch.tensor() 완전 가이드]]
- [[PyTorch 랜덤 시드와 텐서 shape 완벽 이해]]