AI-LAB/논문리뷰
[3] Switch Transformers: Scaling to Trillion Parameter Modelswith Simple and Efficient Sparsity
JS LAB
2025. 3. 12. 13:25
728x90
반응형
GShard
- MoE Tranformer encoder with device placement
- limitations
- tok-k experts lead to load imbalance problem for training
- expensive network routing
- gating function should not choose the same few experts
Switch Transformer
Contributions
- Switch Transformer simplifies and improves over MoE.
- Scalilng properties and a benchmark against T5 model, with 7x + pre-training speedups and same FLOPS over token.
- Successful distillation which preserves 30% of quality gain while mode size reduced by up to 99%.
- Improved pre-training and fine-tuning techniques: selective precision, initialization, selective dropout
Switch(Top-1) Routing
Switch layer : simplify MoE layer, preserves model quality and reduces routing computation and performs better.
- Router computation is reduced
- Batch size of each expert halved
- Routing implementation is simplified and communication costs are reduced
Sparse Routing
Distributed implementation
하이퍼 파라미터를 설정해줘서 experts의 수용 토큰 갯수를 제한함.
Differentiable load balancing loss
load balancing problem
- 특정 전문가에게만 토큰이 몰리는 현상 (Expert Overloading)
- 일부 전문가가 사용되지 않음 (Expert Underutilization)
- 학습 불안정성 증가
-> loss를 추가해줘서 load balancing 문제를 완화함, routing softmax 이후 선택될 experts 의 최대값을 완화시켜 여러개의 experts에 토큰이 잘 분산되게 하는 loss
Mesh tensorflow
자료형에 따른 성능차이
Selective precision
..수도코드...
Routing Computation
Switch Tranformer와 같은 MoE 모델에서 Token을 적절한 Expert network에 할당하는 과정이다.
라우팅
라우팅 연산이 필요한 이유
- 기존 Transformer는 모든 입력 토큰이 동일한 네트워크 파라미터를 거쳐 연산됨.
- MoE 구조에서는 각 토큰을 특정한 Expert에게 동적으로 배정함.
- 하지만 어떤 전문가가 가장 적합한지 결정해야 하는 과정(라우팅)이 필요함.
Switch Transformer의 라우팅 방식
- 각 입력 토큰 xx 에 대해, Router가 xx 를 분석하고 가장 적합한 Expert 하나를 선택함.
- 일반적인 MoE 모델에서는 두 개 이상의 전문가에게 분배하는 경우가 있지만, Switch Transformer에서는 단 하나의 전문가만 선택하여 연산을 수행하도록 단순화.
- 라우터는 Softmax 확률 분포를 계산하여 각 전문가에 대한 확률을 매기고, 가장 높은 확률을 가진 전문가에게 해당 토큰을 보냄.
Switch Transformer에서 bfloat16과 float32가 어떻게 사용되나?
- float32 사용 방식
- 모든 연산을 32비트 부동소수점(float32)으로 수행.
- 높은 정밀도를 유지하여 학습이 안정적.
- 그러나 메모리 사용량이 크고 속도가 상대적으로 느림.
- bfloat16 사용 방식
- 모든 연산을 bfloat16으로 수행하여 속도와 메모리 효율성을 증가시킴.
- 그러나, 라우팅 연산에서 정밀도 손실로 인해 학습 불안정성이 발생함.
- 일부 경우, 학습이 발산(diverge) 하는 문제 발생.
- Selective Precision 기법 (bfloat16 + float32 혼합 사용)
- 대부분의 연산을 bfloat16으로 수행하여 속도와 메모리 절약.
- 라우팅 연산만 float32를 사용하여 안정성을 유지.
- 결과적으로 float32만큼 안정적이면서도 bfloat16과 유사한 속도를 제공.
여기서 말하는 실질적인 연산이란 어떤 것들이 있는가?
1. Routing 연산
- Router 행렬 연산
- softmax 연산
- 라우팅 결정(Argmax or Sampling)
2. Expert 연산
- 1차 행렬 곱
- 활성화 함수 적용
- 2차 행렬 곱
- 출력 결합 및 최종 계산
- 손실 함수 계산
- 그래디언트 계산
728x90
반응형