AI-LAB/논문리뷰

[3] Switch Transformers: Scaling to Trillion Parameter Modelswith Simple and Efficient Sparsity

JS LAB 2025. 3. 12. 13:25
728x90
반응형

GShard

  • MoE Tranformer encoder with device placement
  • limitations
    • tok-k experts lead to load imbalance problem for training
    • expensive network routing
    • gating function should not choose the same few experts

 

Switch Transformer

Contributions

  • Switch Transformer simplifies and improves over MoE.
  • Scalilng properties and a benchmark against T5 model, with 7x + pre-training speedups and same FLOPS over token.
  • Successful distillation which preserves 30% of quality gain while mode size reduced by up to 99%.
  • Improved pre-training and fine-tuning techniques: selective precision, initialization, selective dropout

 

Switch(Top-1) Routing

Switch layer : simplify MoE layer, preserves model quality and reduces routing computation and performs better.

  • Router computation is reduced
  • Batch size of each expert halved
  • Routing implementation is simplified and communication costs are reduced

 

Sparse Routing

Distributed implementation

하이퍼 파라미터를 설정해줘서 experts의 수용 토큰 갯수를 제한함.

 

 

Differentiable load balancing loss

load balancing problem

  • 특정 전문가에게만 토큰이 몰리는 현상 (Expert Overloading)
  • 일부 전문가가 사용되지 않음 (Expert Underutilization)
  • 학습 불안정성 증가

-> loss를 추가해줘서 load balancing 문제를 완화함, routing softmax 이후 선택될 experts 의 최대값을 완화시켜 여러개의 experts에 토큰이 잘 분산되게 하는 loss

 

Mesh tensorflow

 

 

 

자료형에 따른 성능차이

Selective precision

..수도코드...

 

 

 

Routing Computation

Switch Tranformer와 같은 MoE 모델에서 Token을 적절한 Expert network에 할당하는 과정이다.

라우팅 

 

라우팅 연산이 필요한 이유

  • 기존 Transformer는 모든 입력 토큰이 동일한 네트워크 파라미터를 거쳐 연산됨.
  • MoE 구조에서는 각 토큰을 특정한 Expert에게 동적으로 배정함.
  • 하지만 어떤 전문가가 가장 적합한지 결정해야 하는 과정(라우팅)이 필요함.

 

Switch Transformer의 라우팅 방식

  • 각 입력 토큰 xx 에 대해, Router가 xx 를 분석하고 가장 적합한 Expert 하나를 선택함.
  • 일반적인 MoE 모델에서는 두 개 이상의 전문가에게 분배하는 경우가 있지만, Switch Transformer에서는 단 하나의 전문가만 선택하여 연산을 수행하도록 단순화.
  • 라우터는 Softmax 확률 분포를 계산하여 각 전문가에 대한 확률을 매기고, 가장 높은 확률을 가진 전문가에게 해당 토큰을 보냄.

 

Switch Transformer에서 bfloat16과 float32가 어떻게 사용되나?

  1. float32 사용 방식
    • 모든 연산을 32비트 부동소수점(float32)으로 수행.
    • 높은 정밀도를 유지하여 학습이 안정적.
    • 그러나 메모리 사용량이 크고 속도가 상대적으로 느림.
  2. bfloat16 사용 방식
    • 모든 연산을 bfloat16으로 수행하여 속도와 메모리 효율성을 증가시킴.
    • 그러나, 라우팅 연산에서 정밀도 손실로 인해 학습 불안정성이 발생함.
    • 일부 경우, 학습이 발산(diverge) 하는 문제 발생.
  3. Selective Precision 기법 (bfloat16 + float32 혼합 사용)
    • 대부분의 연산을 bfloat16으로 수행하여 속도와 메모리 절약.
    • 라우팅 연산만 float32를 사용하여 안정성을 유지.
    • 결과적으로 float32만큼 안정적이면서도 bfloat16과 유사한 속도를 제공.

 

여기서 말하는 실질적인 연산이란 어떤 것들이 있는가?

1. Routing 연산

  • Router 행렬 연산
  • softmax 연산
  • 라우팅 결정(Argmax or Sampling)

2. Expert 연산

  • 1차 행렬 곱
  • 활성화 함수 적용
  • 2차 행렬 곱
  • 출력 결합 및 최종 계산
  • 손실 함수 계산
  • 그래디언트 계산

 

 

 

 

 

 

 

 
728x90
반응형