양자화된 옵티마이저 상태의 정체와 리셋 효과: LLM 사전학습에서의 동적 분석
본 논문은 저정밀도(예: BF16, FP8, FP4)로 저장된 EMA(지수 이동 평균) 옵티마이저 상태가 업데이트 시 양자화 오차로 인해 동일한 값으로 머무르는 ‘정체(stalling)’ 현상을 수학적으로 모델링한다. 정체 확률을 예측하는 이론을 바탕으로, 정체가 심해지는 시점을 파악하고 적절한 리셋 주기를 설계함으로써 저정밀도에서도 메모리 사용량을 크게 줄이면서 성능 손실을 최소화할 수 있음을 실험적으로 입증한다.
저자: Kristi Topollai, Anna Choromanska
본 논문은 대규모 언어 모델(LLM) 사전학습에서 메모리 사용량을 크게 줄이기 위해 옵티마이저 상태를 저정밀도(예: BF16, FP8, FP4)로 양자화하는 최근 트렌드에 주목한다. 옵티마이저는 AdamW, Lion, Shampoo 등 다양한 변형이 존재하지만, 이들 모두 첫 번째 모멘트(m)와 두 번째 모멘트(v)와 같은 지수 이동 평균(EMA) 상태를 유지한다. 저정밀도 저장은 메모리와 대역폭을 절감하지만, 양자화 과정에서 발생하는 ‘업데이트 정체(state‑update stalling)’ 현상이 옵티마이저의 동역학을 크게 왜곡한다는 점이 아직 충분히 이해되지 못했다.
**1. 정체 현상의 정의와 기하학적 해석**
EMA 재귀식 sₜ = β sₜ₋₁ + (1‑β)Δₜ를 고정밀 연산으로 수행한 뒤, 결과를 다시 저정밀 형식으로 양자화(Q)한다. 양자화는 인접한 표현 가능한 실수 사이의 최소 간격(ulp) uₜ를 만든다. 고정밀 업데이트 Δₜ가 uₜ/2보다 작으면 양자화 후 저장값이 변하지 않는다. 이를 ‘정체’라 정의하고, 정체가 발생하면 해당 스텝은 β = 1(즉, 완전 보존)과 동일하게 동작한다.
**2. 효과적 정밀도 비율(ρ̂) 도입**
각 형식마다 상대 간격 ε가 다르다(예: BF16 ε = 2⁻⁷, FP8(E4M3) ε = 2⁻³, FP4(E2M2) ε = 2⁻²). 평균 맨티사 \(\bar m\)를 고려해 \(\hat\rho = \frac{ε}{2(1‑β)\bar m}\)를 정의한다. 이 단일 스칼라 파라미터가 정체 확률을 완전히 결정한다.
**3. 정체 확률의 정량적 모델**
단일 좌표에 대해 gₜ를 정규분포 N(0,σ²)라 가정하고, χ²₁ 분포를 이용해 zₜ = gₜ²/σ²를 정의한다. 정체 조건은 |zₜ‑1| < \(\hat\rho\) 로 변환된다. 최근접 반올림(NR) 하에서는 정체 확률이
\(P_{\text{stall}}^{NR}(\hat\rho) ≈ F_{χ²₁}(1+ \hat\rho) - F_{χ²₁}(\max(0,1‑\hat\rho))\)
으로 근사된다. \(\hat\rho ≥ 1\)이면 거의 모든 업데이트가 차단돼 정체 확률이 1에 수렴한다. 실제 형식별 파라미터를 대입하면 BF16에서는 약 94 %의 정체, FP8·FP4에서는 99‑100 %에 달한다는 이론값이 실험과 일치한다.
**4. 초기 단계와 정체 성장**
EMA가 초기에는 0에 가깝기 때문에 ulp이 상대적으로 작아 정체가 거의 발생하지 않는다. 상태가 성장하면서 ulp이 커지고, 정체 확률이 급격히 상승한다. 이를 ‘시작 윈도우(startup window)’라 부르고, 목표 정체 확률 P₀에 대해 최소 j* (스텝 수) 를
\(j^* = \left\lceil \frac{\log(1‑\phi^*)}{\log β} \right\rceil\)
with \(\phi^* = F^{-1}_{χ²₁}(P₀) / (1+ \hat\rho)\) 로 계산한다. 실험에서는 BF16에서 수천 스텝, FP8에서 수백 스텝, FP4에서는 수십 스텝 정도가 이 윈도우에 해당한다. 또한 초기 정체 비율 P_init이 존재하는데, 이는 양자화 초기 오프셋과 비대칭 rounding에 기인한다.
**5. 정체가 옵티마이저 동역학에 미치는 영향**
정체가 발생하면 해당 스텝은 β = 1과 동일하게 동작한다. 평균적으로 정체 비율 P_stall이 일정하면 유효 감쇠는
\(\beta_{\text{eff}} ≈ 1‑(1‑β)(1‑P_{\text{stall}})\)
가 된다. 이는 EMA의 메모리 길이가 실질적으로 늘어나 적응 속도가 크게 감소함을 의미한다. 두 번째 모멘트(v) 정체는 학습률 스케일링을 방해해 최종 손실을 악화시키고, 첫 번째 모멘트(m) 정체는 방향 정보 손실로 훈련 불안정을 초래한다.
**6. 정체 완화를 위한 리셋 전략**
정체가 지배적으로 되기 전까지 EMA는 유효하게 작동한다. 따라서 ‘리셋(reset)’은 정체가 시작되는 시점 직후에 수행하면 EMA를 초기화해 다시 유효한 시작 윈도우를 얻을 수 있다. 논문은 정체 모델을 이용해 최적 리셋 주기 T* ≈ j* 를 제안한다. 실험에서는 BF16에서는 2 k‑5 k 스텝, FP8에서는 300‑800 스텝, FP4에서는 100‑200 스텝 간격으로 리셋을 적용했다.
**7. 실험 검증**
- **제어 시뮬레이션**: LLaMA‑60M 모델에 대해 업데이트를 임의로 건너뛰는 확률 p를 조절해 정체 효과를 직접 측정. 두 번째 모멘트 정체 비율이 0.8 이상이면 최종 검증 손실이 크게 상승함을 확인.
- **실제 LLM 사전학습**: GPT‑Neo‑2.7B 규모 모델을 BF16, FP8, FP4로 학습하면서 제안된 리셋 스케줄을 적용. 리셋 없는 경우는 초기 10 % 구간만 빠르게 수렴하고 이후 정체로 인해 손실이 정체, 반면 리셋 적용 시 지속적인 수렴 곡선을 보이며 최종 퍼플렉시티가 FP32 대비 1‑2 % 차이로 유지되었다. 메모리 사용량은 FP32 대비 4‑8배 절감되었다.
**8. 결론 및 시사점**
- 저정밀도 EMA는 ‘양자화 게이트’에 의해 단계별로 차단되며, β값이 클수록(β≈0.999) 정체가 심화된다.
- 정체 확률은 형식의 ε와 β에만 의존하므로 모델 크기·데이터 규모와 무관하게 발생한다.
- 정체가 시작되는 시점을 정확히 예측하면, 그 직후에 리셋을 수행해 EMA를 재활성화할 수 있다.
- 이론적 정체 모델은 리셋 주기 설계에 직접 활용 가능하며, 메모리 절감과 성능 유지 사이의 트레이드오프를 효율적으로 관리한다.
본 연구는 저정밀도 옵티마이저 상태 관리에 대한 체계적인 이론과 실험을 제공함으로써, 차세대 대규모 모델 훈련에서 메모리 효율성을 크게 향상시키는 실용적인 가이드를 제시한다.
원본 논문
고화질 논문을 불러오는 중입니다...
댓글 및 학술 토론
Loading comments...
의견 남기기