# GPT 배치 이중 for 루프 완벽 해부 ## 이중 for 루프 구조 ```python for size in range(batch_size): # 외부 루프: 배치 선택 (0~3) for t in range(block_size): # 내부 루프: 시간 스텝 (0~7) context = example_x[size, :t+1] target = example_y[size, t] print(f"input : {context}, target : {target}") print("-----------------------") print("-----------------------") ``` ## 실제 데이터 설정 ```python # batch_function이 반환한 실제 데이터 example_x = torch.tensor([ [1764, 2555, 0, 1236, 2248, 0, 2017, 1976], # 배치 0 [ 0, 1966, 2157, 0, 1951, 2062, 0, 2548], # 배치 1 [ 0, 1304, 1485, 1586, 0, 1907, 2450, 0], # 배치 2 [ 3, 2, 6, 5, 1, 0, 5, 3] # 배치 3 ]) # shape: (4, 8) example_y = torch.tensor([ [2555, 0, 1236, 2248, 0, 2017, 1976, 2546], # 배치 0 [1966, 2157, 0, 1951, 2062, 0, 2548, 2289], # 배치 1 [1304, 1485, 1586, 0, 1907, 2450, 0, 2480], # 배치 2 [ 2, 6, 5, 1, 0, 5, 3, 5] # 배치 3 ]) # shape: (4, 8) ``` ## 상세 실행 과정 ### 첫 번째 배치 (size=0) 완전 분석 ```python # 외부 루프: size = 0 시작 print("=== 배치 0 처리 시작 ===") # 내부 루프 8번 실행 # t=0: 첫 번째 토큰으로 두 번째 토큰 예측 context = example_x[0, :1] # example_x[0, 0:1] = [1764] target = example_y[0, 0] # 2555 print(f"t=0: input=[1764], target=2555") # t=1: 첫 두 토큰으로 세 번째 토큰 예측 context = example_x[0, :2] # example_x[0, 0:2] = [1764, 2555] target = example_y[0, 1] # 0 print(f"t=1: input=[1764, 2555], target=0") # t=2: 첫 세 토큰으로 네 번째 토큰 예측 context = example_x[0, :3] # example_x[0, 0:3] = [1764, 2555, 0] target = example_y[0, 2] # 1236 print(f"t=2: input=[1764, 2555, 0], target=1236") # t=3: 첫 네 토큰으로 다섯 번째 토큰 예측 context = example_x[0, :4] # example_x[0, 0:4] = [1764, 2555, 0, 1236] target = example_y[0, 3] # 2248 print(f"t=3: input=[1764, 2555, 0, 1236], target=2248") # t=4: 첫 다섯 토큰으로 여섯 번째 토큰 예측 context = example_x[0, :5] # example_x[0, 0:5] = [1764, 2555, 0, 1236, 2248] target = example_y[0, 4] # 0 print(f"t=4: input=[1764, 2555, 0, 1236, 2248], target=0") # t=5: 첫 여섯 토큰으로 일곱 번째 토큰 예측 context = example_x[0, :6] # example_x[0, 0:6] = [1764, 2555, 0, 1236, 2248, 0] target = example_y[0, 5] # 2017 print(f"t=5: input=[1764, 2555, 0, 1236, 2248, 0], target=2017") # t=6: 첫 일곱 토큰으로 여덟 번째 토큰 예측 context = example_x[0, :7] # example_x[0, 0:7] = [1764, 2555, 0, 1236, 2248, 0, 2017] target = example_y[0, 6] # 1976 print(f"t=6: input=[1764, 2555, 0, 1236, 2248, 0, 2017], target=1976") # t=7: 첫 여덟 토큰으로 아홉 번째 토큰 예측 context = example_x[0, :8] # example_x[0, 0:8] = [1764, 2555, 0, 1236, 2248, 0, 2017, 1976] target = example_y[0, 7] # 2546 print(f"t=7: input=[1764, 2555, 0, 1236, 2248, 0, 2017, 1976], target=2546") print("=== 배치 0 처리 완료 ===") # 내부 루프 종료, 외부 루프 다음 iteration으로 ``` ### 두 번째 배치 (size=1) 완전 분석 ```python # 외부 루프: size = 1 시작 print("=== 배치 1 처리 시작 ===") # 내부 루프 8번 실행 # t=0 context = example_x[1, :1] # example_x[1, 0:1] = [0] target = example_y[1, 0] # 1966 print(f"t=0: input=[0], target=1966") # t=1 context = example_x[1, :2] # example_x[1, 0:2] = [0, 1966] target = example_y[1, 1] # 2157 print(f"t=1: input=[0, 1966], target=2157") # t=2 context = example_x[1, :3] # example_x[1, 0:3] = [0, 1966, 2157] target = example_y[1, 2] # 0 print(f"t=2: input=[0, 1966, 2157], target=0") # t=3 context = example_x[1, :4] # example_x[1, 0:4] = [0, 1966, 2157, 0] target = example_y[1, 3] # 1951 print(f"t=3: input=[0, 1966, 2157, 0], target=1951") # t=4 context = example_x[1, :5] # example_x[1, 0:5] = [0, 1966, 2157, 0, 1951] target = example_y[1, 4] # 2062 print(f"t=4: input=[0, 1966, 2157, 0, 1951], target=2062") # t=5 context = example_x[1, :6] # example_x[1, 0:6] = [0, 1966, 2157, 0, 1951, 2062] target = example_y[1, 5] # 0 print(f"t=5: input=[0, 1966, 2157, 0, 1951, 2062], target=0") # t=6 context = example_x[1, :7] # example_x[1, 0:7] = [0, 1966, 2157, 0, 1951, 2062, 0] target = example_y[1, 6] # 2548 print(f"t=6: input=[0, 1966, 2157, 0, 1951, 2062, 0], target=2548") # t=7 context = example_x[1, :8] # example_x[1, 0:8] = [0, 1966, 2157, 0, 1951, 2062, 0, 2548] target = example_y[1, 7] # 2289 print(f"t=7: input=[0, 1966, 2157, 0, 1951, 2062, 0, 2548], target=2289") print("=== 배치 1 처리 완료 ===") ``` ### 세 번째 배치 (size=2) 완전 분석 ```python # 외부 루프: size = 2 시작 print("=== 배치 2 처리 시작 ===") # t=0 context = example_x[2, :1] # [0] target = example_y[2, 0] # 1304 print(f"t=0: input=[0], target=1304") # t=1 context = example_x[2, :2] # [0, 1304] target = example_y[2, 1] # 1485 print(f"t=1: input=[0, 1304], target=1485") # t=2 context = example_x[2, :3] # [0, 1304, 1485] target = example_y[2, 2] # 1586 print(f"t=2: input=[0, 1304, 1485], target=1586") # t=3 context = example_x[2, :4] # [0, 1304, 1485, 1586] target = example_y[2, 3] # 0 print(f"t=3: input=[0, 1304, 1485, 1586], target=0") # t=4 context = example_x[2, :5] # [0, 1304, 1485, 1586, 0] target = example_y[2, 4] # 1907 print(f"t=4: input=[0, 1304, 1485, 1586, 0], target=1907") # t=5 context = example_x[2, :6] # [0, 1304, 1485, 1586, 0, 1907] target = example_y[2, 5] # 2450 print(f"t=5: input=[0, 1304, 1485, 1586, 0, 1907], target=2450") # t=6 context = example_x[2, :7] # [0, 1304, 1485, 1586, 0, 1907, 2450] target = example_y[2, 6] # 0 print(f"t=6: input=[0, 1304, 1485, 1586, 0, 1907, 2450], target=0") # t=7 context = example_x[2, :8] # [0, 1304, 1485, 1586, 0, 1907, 2450, 0] target = example_y[2, 7] # 2480 print(f"t=7: input=[0, 1304, 1485, 1586, 0, 1907, 2450, 0], target=2480") print("=== 배치 2 처리 완료 ===") ``` ### 네 번째 배치 (size=3) 완전 분석 ```python # 외부 루프: size = 3 시작 print("=== 배치 3 처리 시작 ===") # t=0 context = example_x[3, :1] # [3] target = example_y[3, 0] # 2 print(f"t=0: input=[3], target=2") # t=1 context = example_x[3, :2] # [3, 2] target = example_y[3, 1] # 6 print(f"t=1: input=[3, 2], target=6") # t=2 context = example_x[3, :3] # [3, 2, 6] target = example_y[3, 2] # 5 print(f"t=2: input=[3, 2, 6], target=5") # t=3 context = example_x[3, :4] # [3, 2, 6, 5] target = example_y[3, 3] # 1 print(f"t=3: input=[3, 2, 6, 5], target=1") # t=4 context = example_x[3, :5] # [3, 2, 6, 5, 1] target = example_y[3, 4] # 0 print(f"t=4: input=[3, 2, 6, 5, 1], target=0") # t=5 context = example_x[3, :6] # [3, 2, 6, 5, 1, 0] target = example_y[3, 5] # 5 print(f"t=5: input=[3, 2, 6, 5, 1, 0], target=5") # t=6 context = example_x[3, :7] # [3, 2, 6, 5, 1, 0, 5] target = example_y[3, 6] # 3 print(f"t=6: input=[3, 2, 6, 5, 1, 0, 5], target=3") # t=7 context = example_x[3, :8] # [3, 2, 6, 5, 1, 0, 5, 3] target = example_y[3, 7] # 5 print(f"t=7: input=[3, 2, 6, 5, 1, 0, 5, 3], target=5") print("=== 배치 3 처리 완료 ===") print("=== 모든 배치 처리 완료 ===") ``` ## 전체 실행 순서 요약 ### 실행 흐름 ``` 총 32번 실행 (4 배치 × 8 시간스텝) 1. size=0, t=0 → input=[1764], target=2555 2. size=0, t=1 → input=[1764, 2555], target=0 3. size=0, t=2 → input=[1764, 2555, 0], target=1236 4. size=0, t=3 → input=[1764, 2555, 0, 1236], target=2248 5. size=0, t=4 → input=[1764, 2555, 0, 1236, 2248], target=0 6. size=0, t=5 → input=[1764, 2555, 0, 1236, 2248, 0], target=2017 7. size=0, t=6 → input=[1764, 2555, 0, 1236, 2248, 0, 2017], target=1976 8. size=0, t=7 → input=[1764, 2555, 0, 1236, 2248, 0, 2017, 1976], target=2546 ↓ 첫 번째 배치 완료, 두 번째 배치 시작 9. size=1, t=0 → input=[0], target=1966 10. size=1, t=1 → input=[0, 1966], target=2157 11. size=1, t=2 → input=[0, 1966, 2157], target=0 12. size=1, t=3 → input=[0, 1966, 2157, 0], target=1951 13. size=1, t=4 → input=[0, 1966, 2157, 0, 1951], target=2062 14. size=1, t=5 → input=[0, 1966, 2157, 0, 1951, 2062], target=0 15. size=1, t=6 → input=[0, 1966, 2157, 0, 1951, 2062, 0], target=2548 16. size=1, t=7 → input=[0, 1966, 2157, 0, 1951, 2062, 0, 2548], target=2289 ↓ 두 번째 배치 완료, 세 번째 배치 시작 17. size=2, t=0 → input=[0], target=1304 ... (8번 실행) 24. size=2, t=7 → input=[0, 1304, 1485, 1586, 0, 1907, 2450, 0], target=2480 ↓ 세 번째 배치 완료, 네 번째 배치 시작 25. size=3, t=0 → input=[3], target=2 ... (8번 실행) 32. size=3, t=7 → input=[3, 2, 6, 5, 1, 0, 5, 3], target=5 ↓ 모든 배치 완료 ``` ## 패턴 분석 ### 1. 외부 루프 (size) - **역할**: 4개 배치 중 어느 것을 처리할지 선택 - **범위**: 0, 1, 2, 3 - **특징**: 한 번에 하나씩, 순차적으로 처리 ### 2. 내부 루프 (t) - **역할**: 선택된 배치에서 시간 스텝 진행 - **범위**: 0, 1, 2, 3, 4, 5, 6, 7 - **특징**: 각 size마다 0~7까지 완전히 실행 ### 3. Context 길이 변화 ``` t=0: 길이 1 ([토큰1]) t=1: 길이 2 ([토큰1, 토큰2]) t=2: 길이 3 ([토큰1, 토큰2, 토큰3]) ... t=7: 길이 8 ([토큰1, 토큰2, ..., 토큰8]) ``` ### 4. Target 위치 ``` t=0: y[size, 0] (두 번째 토큰) t=1: y[size, 1] (세 번째 토큰) t=2: y[size, 2] (네 번째 토큰) ... t=7: y[size, 7] (아홉 번째 토큰) ``` ## 핵심 포인트 ### 완전 순차 실행 - size=0이 t=0~7까지 **완전히 끝나야** size=1이 시작 - 중간에 다른 배치로 넘어가지 않음 - 각 배치는 독립적으로 8개의 학습 샘플 생성 ### 자기회귀 학습의 핵심 - **점진적 문맥 확장**: 1토큰 → 2토큰 → ... → 8토큰 - **다음 토큰 예측**: 매 시점에서 바로 다음 토큰을 정답으로 사용 - **다양한 길이 학습**: 짧은 문맥부터 긴 문맥까지 모든 경우 학습 ### 최종 결과 - **32개 학습 샘플**: 4배치 × 8시간스텝 - **다양한 패턴**: 각기 다른 텍스트 조각에서 추출 - **효율적 배치 처리**: GPU가 동시에 처리할 수 있는 형태로 준비 ## 관련 참고 자료 - [[GPT 배치 생성 함수 한글 예시로 완벽 이해]] - [[PyTorch stack 메서드 완벽 가이드]]