# PyTorch stack 메서드 완벽 가이드 ## 1. torch.stack()이란? `torch.stack()`은 여러 개의 텐서를 **새로운 차원**으로 쌓아서 하나의 텐서로 만드는 메서드입니다. ### 핵심 개념 - **입력**: 같은 크기의 텐서들의 리스트 - **출력**: 차원이 하나 증가한 텐서 - **특징**: 새로운 차원을 추가하면서 결합 ## 2. 기본 예시로 이해하기 ### 2.1 1차원 텐서 쌓기 ```python import torch # 1차원 텐서 3개 a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6]) c = torch.tensor([7, 8, 9]) # stack으로 쌓기 result = torch.stack([a, b, c]) print(result) # tensor([[1, 2, 3], # [4, 5, 6], # [7, 8, 9]]) print(f"입력 shape: {a.shape}") # torch.Size([3]) print(f"출력 shape: {result.shape}") # torch.Size([3, 3]) ``` ### 2.2 시각적 이해 ``` 원본 텐서들: a = [1, 2, 3] b = [4, 5, 6] c = [7, 8, 9] stack 후: [[1, 2, 3], ← a [4, 5, 6], ← b [7, 8, 9]] ← c ``` ## 3. 배치 생성 함수에서의 stack 상세 분석 ### 3.1 실제 데이터로 단계별 설명 ```python # 가정: dataset은 긴 텍스트의 토큰 ID 리스트 dataset = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, ..., 1000] # 랜덤 인덱스와 block_size idx = [234, 567, 123, 890] block_size = 8 ``` ### 3.2 각 슬라이싱 결과 ```python # 첫 번째 슬라이싱 dataset[234:242] = dataset[234:234+8] # 결과: [2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347] # 두 번째 슬라이싱 dataset[567:575] = dataset[567:567+8] # 결과: [5670, 5671, 5672, 5673, 5674, 5675, 5676, 5677] # 세 번째 슬라이싱 dataset[123:131] = dataset[123:123+8] # 결과: [1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237] # 네 번째 슬라이싱 dataset[890:898] = dataset[890:890+8] # 결과: [8900, 8901, 8902, 8903, 8904, 8905, 8906, 8907] ``` ### 3.3 리스트 컴프리헨션 풀어서 보기 ```python # 원래 코드 x = torch.stack([dataset[index:index+block_size] for index in idx]) # 풀어서 쓰면 슬라이스_리스트 = [] for index in idx: # idx = [234, 567, 123, 890] 슬라이스 = dataset[index:index+block_size] 슬라이스_리스트.append(슬라이스) x = torch.stack(슬라이스_리스트) ``` ### 3.4 stack 동작 과정 ```python # stack 전: 4개의 1차원 텐서 (각각 길이 8) 슬라이스_리스트 = [ [2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347], # shape: (8,) [5670, 5671, 5672, 5673, 5674, 5675, 5676, 5677], # shape: (8,) [1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237], # shape: (8,) [8900, 8901, 8902, 8903, 8904, 8905, 8906, 8907] # shape: (8,) ] # stack 후: 2차원 텐서 x = torch.stack(슬라이스_리스트) # 결과 shape: (4, 8) # [[2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347], # [5670, 5671, 5672, 5673, 5674, 5675, 5676, 5677], # [1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237], # [8900, 8901, 8902, 8903, 8904, 8905, 8906, 8907]] ``` ## 4. stack vs cat 비교 ### 4.1 torch.stack - 새 차원 추가 ```python a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6]) # stack: 새로운 차원 추가 stacked = torch.stack([a, b]) print(stacked) # tensor([[1, 2, 3], # [4, 5, 6]]) print(stacked.shape) # torch.Size([2, 3]) ``` ### 4.2 torch.cat - 기존 차원에서 연결 ```python # cat: 기존 차원에서 이어붙이기 catted = torch.cat([a, b]) print(catted) # tensor([1, 2, 3, 4, 5, 6]) print(catted.shape) # torch.Size([6]) ``` ### 4.3 언제 무엇을 사용? - **stack**: 여러 샘플을 배치로 만들 때 - **cat**: 시퀀스를 이어붙일 때 ## 5. dim 파라미터 활용 ### 5.1 다양한 차원에서 stack ```python a = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([[5, 6], [7, 8]]) # dim=0 (기본값): 첫 번째 차원에 쌓기 stack0 = torch.stack([a, b], dim=0) print(stack0.shape) # torch.Size([2, 2, 2]) # [[[1, 2], [3, 4]], # [[5, 6], [7, 8]]] # dim=1: 두 번째 차원에 쌓기 stack1 = torch.stack([a, b], dim=1) print(stack1.shape) # torch.Size([2, 2, 2]) # [[[1, 2], [5, 6]], # [[3, 4], [7, 8]]] # dim=2: 세 번째 차원에 쌓기 stack2 = torch.stack([a, b], dim=2) print(stack2.shape) # torch.Size([2, 2, 2]) # [[[1, 5], [2, 6]], # [[3, 7], [4, 8]]] ``` ## 6. 배치 처리에서 stack의 중요성 ### 6.1 왜 stack을 사용하는가? ```python # 개별 처리 (비효율적) for i in range(4): output = model(dataset[idx[i]:idx[i]+8]) # 4번의 개별 연산 # 배치 처리 (효율적) batch = torch.stack([dataset[i:i+8] for i in idx]) output = model(batch) # 한 번의 연산으로 4개 처리! ``` ### 6.2 GPU 병렬 처리 활용 ```python # GPU는 행렬 연산에 최적화 # (4, 8) 크기의 배치를 한 번에 처리 # 개별 처리보다 약 4배 빠름 ``` ## 7. 실전 팁과 주의사항 ### 7.1 크기가 같아야 함 ```python # 에러 발생 - 크기가 다름 a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5]) # 크기 다름! # torch.stack([a, b]) # RuntimeError! ``` ### 7.2 메모리 효율적 사용 ```python # 큰 데이터셋에서는 generator 사용 def batch_generator(dataset, indices, block_size): for idx in indices: yield dataset[idx:idx+block_size] # 메모리 효율적 batch = torch.stack(list(batch_generator(dataset, idx, block_size))) ``` ## 8. 요약 `torch.stack()`의 핵심: 1. **새로운 차원 추가**: (8,) → (4, 8) 2. **배치 생성**: 여러 샘플을 하나로 묶기 3. **GPU 효율성**: 병렬 처리로 속도 향상 4. **크기 일치 필요**: 모든 텐서가 같은 shape여야 함 배치 생성 함수에서는 여러 위치의 텍스트 조각을 하나의 배치로 만들어 효율적으로 학습할 수 있게 합니다. ## 관련 참고 자료 - [[GPT 배치 생성 함수 완벽 가이드]] - [[PyTorch 텐서와 torch.tensor() 완전 가이드]] - [[PyTorch 랜덤 시드와 텐서 shape 완벽 이해]]