Search
🗼

[논문리뷰] Large Batch Optimization for Deep Learning: Training BERT in 76 minutes

Created
3/7/2021, 12:26:24 PM
Tags
Pretraining
💡
논문 리뷰: Large Batch 를 안정적으로 학습시켜주는 LAMB Optimizer

Abstract

요약

매우 큰 모델을 많은 양의 데이터로 학습시키는 것은 계산적(computationally)으로 어렵고 활발하게 연구가 되고 있다. 특히 큰 배치 단위로 모델을 업데이트하는 것이 효율적인 학습 방법으로 제시 되고 있다. GPU 의 개수(=배치를 키우는 것)에 따라 학습 속도도 linear 하게 빨라 지는 것이 이상적인 목표.
최근 연구에서는 *LARS(Layer-wise Adaptive Learning Rate Scaling) 기법이 큰 배치 환경에서 좋은 결과(ResNet)를 보여주었다. 하지만 BERT 같은 Attention 모델에서는 안정적인 성능을 보여주지 못했다. (* LARS는 모델의 레이어 별로 별도의 learning rate scaling 값을 갖음) 우린 이거 해결 할 거다
Layer-wise Adaptive 기법의 근본적 원리를 탐구했고, 공부한 거 바탕으로 새로운 adaptation 기법을 제안함. = LAMB optmizer 를 맹들었음.
저자가 만든 LAMB로 BERT, RESNET-50 학습을 굉장히 큰 large-batch 로 해 봤는데 하이퍼파리미터(lr, warmup-rate 등) 튜닝 거의 없이 짱짱 잘댐. 특히 BERT를 성능 저하 없이 TPUv3Pod 에서 32k 배치 크기로 76분만에 학습시킬 수 있었음 (*512 토큰 기준 32k, 128 토큰 기준 131k)

왜 Large-Batch 에서 학습이 안되는가?

(Large Batch 와 Converge 간의 Tradeoff)This is particularly problematic with larger mini-batch sizes, because they require higher learning rates to compensate for fewer training updates. But, training frequently diverges when the learning rate is too high, thereby limiting the maximum mini-batch size we can scale up to. → 큰 배치는 몇번 업데이트를 못하니깐, 큰 lr를 가져야 하지만 lr 이 너무 크면 수렴을 안함. 그래서 최대로 늘릴 수 있는 배치 크기가 제한될 수 밖에 없음
For 64K/32K mixed-batch training, even after extensive tuning of the hyperparameters, we fail to get any reasonable result with ADAMW optimizer. We conclude that ADAMW does not work well in large-batch BERT training or is at least hard to tune
You et al., that some layers may cause instability before others, and the “weakest” of these layers limits the overall learning rate that may be applied to the model, thereby limiting model convergence and maximum mini-batch size. → 몇몇의 레이어에서 gradient * lr 크기가 실제 weight 의 크기보다 커지는 불안정한 모습을 확인함.

그래서 어떻게 한다고?

(step 2) layer 별 gradient 를 전체 gradient 의 크기 (L2-norm) 로 nomalization
(step 3-5) Adam의 기반 업데이트 계산
(step 6) Layer 마다 learning rate 를 scaling 하는 term 추가 = LARS

LAMB Optimizer 구조

LAMB 로 학습 시키면?

동일한 512 batch size 에서도 약 1% 가량의 성능 향상, 배치 키우면 3일 → 1시간 10분까지 축소
32k 로 batch 를 늘렸을 때, step 이 15k까지 작아짐에도 불구하고 모델의 성능은 거의 보존됨
TPU (GPU) 늘리면 속도도 어느정도 linear 하게 가져갈 수 있다.

네 줄 요약

Pre-Gradient Normalization (전체 gradient 크기로 각 layer 별 gradient 를 normalizing)
LARS(레이어별로 learning rate scaling) +
AdamW(momemtum, weight decay) +
= LAMB (Layer-wise Adaptive Moments optimizer for Batch training)