Search

202204, Google Research, PaLM: Scaling LM with Pathways

Summary
- blog , paper - GPT-3 style architecture (decoder only) - Training data : En + Multi-lingual (web doc, books, wiki, conversations, GitHub code); PaLM/GLaM/LaMDA은 비슷한 학습데이터 780B tokens (1 epoch 학습) - vocab : SPM “lossless” vocab (whitespace 포함 + OOV fallback Unicode2byte + split numbers into 개별 tokens) - Training resource : 6144 TPU V4 chips (2 cloud TPU v4 Pods; 57.8% FLOPs utilization) - Architecture ◦ SwiGLU ActivationParallel Layers (MLP & Attn): 큰 scale (62B 이상) 에서는 성능 저하 없고 속도 이득만 있음 ◦ Multi-Query Attention : key와 value는 모든 head에 대해서 shared (성능과 학습 속도에서 별 차이가 없지만, autoregressive decoding time에서 매우 큰 cost savings 보여줌) ◦ RoPE Embeddings : Rotary Positional Embeddings (long sequence lengths 에서 더 좋은 성능을 보여줌) ◦ Shared Input-Output EmbeddingsNo Biases : dense kernel 이나 layer norms에서 bias를 제외 (큰 모델에서 학습 안정성 확보) - multi-task 특화 학습 architecture Pathways : Pathways System (Multiple TPU v4 Pods; 6144 chips) 을 적용한 Dense-decoder only Transformer model - 202110 Pathways blog : 구글이 추구하는 모델 방향성 (multi-task 특화 학습 architecture) - Pathways 영상 설명 - 202203 Pathways System paper : Pathway 형태로 학습하기 위한 distributed computation 시스템
작성자 : 김준성 (Linkedin, E-Mail : soulloan@gmail.com)
구글에서 발표한 GPT-3 style architecture (decoder only)의 새로운 LLM PaLM을 소개하는 문서
저자 팀
PaLM builds on top of work by many, many teams at Google and we would especially like to recognize the T5X team, the Pathways infrastructure team, the JAX team, the Flaxformer team, the XLA team, the Plaque team, the Borg team, and the Datacenter networking infrastructure team
Summary
GPT-3 style architecture (decoder only)
Training data : En + Multi-lingual (web doc, books, wiki, conversations, GitHub code)
vocab : “lossless” vocab (whitespace 포함 + OOV Unicode2byte + split numbers into 개별 tokens)
Training resource : 6144 TPU V4 chips (2 cloud TPU v4 Pods; 57.8% FLOPs utilization)
Architecture
SwiGLU Activation
Parallel Layers (MLP & Attn): 큰 scale (62B 이상) 에서는 성능 저하 없고 속도 이득만 있음
Multi-Query Attention : key와 value는 모든 head에 대해서 shared (성능과 학습 속도에서 별 차이가 없지만, autoregressive decoding time에서 매우 큰 cost savings 보여줌)
RoPE Embeddings : Rotary Positional Embeddings (long sequence lengths 에서 더 좋은 성능을 보여줌)
Shared Input-Output Embeddings
No Biases : dense kernel 이나 layer norms에서 bias를 제외 (큰 모델에서 학습 안정성 확보)

소개

[Model Spec.] Training a 540B Parameter LM

Pathways vision
“Enable a single AI system to generalize across thousands or millions of tasks, to understand different types of data, and to do so with remarkable efficiency."
Novel architectural choices and training schemes
Largest TPU-based system configuration (6144 chips) used for training to date
data parallelism at the Pod level across two Cloud TPU v4 Pods
57.8% FLOPs utilization
Transformer bloc의 attention와 feedforward 들이 parallel하게 계산되도록 함
the previous LM models with data parallelism
GLaM, LaMDA : single TPU v3 Pod
Megatron-Turing NLG : pipeline parallelism 2240 A100 GPUs
Gopher : multiple TPU v3 Pods, 4096 TPU v3 chips
Data : English & multilingual datasets
high-quality Web doc.
books
Wikipedia
conversations
Github code (전체의 5%)
Vocab : lossless vocab
whitespace 포함 (코드에서 중요함)
OOV Unicode → bytes로 분리
number → 개별 tokens으로 취급
숫자가 포함된 reasoning task에서 중요한 역할을 하게 됨.

Architecture

SwiGLU Activation
일반적인 MLP (2개의 matrix multiplications) 에 비해 더 많은 3개의 matrix multiplications 보유
하지만, compute-equivalent experiments 에서는 더 좋은 성능 (ReLU variant는 larger dimensions을 가지기 때문)
Parallel Layers (MLP & Attn) in each Transformer block : 큰 scale (62B 이상) 에서는 성능 저하 없고 속도 이득만 있음
기존 : y=x+MLP(LayerNorm(x+Attn(LayerNorm(x))))y = x + MLP(LayerNorm(x+Attn(LayerNorm(x))))
New : y=x+MLP(LayerNorm(x))+Attn(LayerNorm(x)))y = x + MLP(LayerNorm(x))+Attn(LayerNorm(x)))
실험 결과
roughly 15% 더 빠른 학습 속도 at large scales
8B까지는 small degradation 이지만, 62B 부터는 성능 저하 없고, 540B 에서도 성능 저하 없을 것이라고 추론됨. (extrapolated)
Multi-Query Attention : key와 value는 모든 head에 대해서 shared
기존 : input → key, query, value tensors [k,h][k, h] where hh is the attn head size
New : input → query [k,h][k, h], but key, value [k,1][k,1] (shared for each head)
실험 결과
성능과 학습 속도에서 별 차이가 없음.
autoregressive decoding time에서 매우 큰 cost savings 보여줌 (한번에 single token만을 decoding 하기 때문)
RoPE Embeddings : Rotary Positional Embeddings
기존 : absolute or relative position embeddings
New : (Rotary; fixed)
long sequence lengths 에서 더 좋은 성능을 보여줌
Shared Input-Output Embeddings
No Biases : dense kernel 이나 layer norms에서 bias를 제외
큰 모델에서 학습 안정성을 높임
Vocab : SPM with 256k
학습데이터로부터 추출하는 것이 학습 효율성 증가 시킴
lossless & reversible : whitespace 도 encoding (코드 데이터에서 중요)
OOV : UTF-8 bytes로 쪼갠 뒤에 bytes fallback encoding
Number는 독립 token으로 쪼개서 encoding
숫자가 포함된 reasoning task에서 중요한 역할을 하게 됨 : GSM8k 성능 ()

Model Scale Hyperparameters

8B, 62B, 540B
Standard dense Transformer : # of FLOPs per token은 # of parameters와 비슷함

Training Dataset : En + Multi-lingual

LaMDA & GLaM 과 같은 데이터 : a high quality corpus of 780B tokens
→ 세 모델 다 1 epoch 씩만 학습 (shuffled but identically for all models, choose the mixing proportions)
1.
Webpage → quality score from classifier → filtered webpages (score에 비례하게 sampling)
2.
books
3.
Wikipedia
4.
news articles
5.
source code : 196GB
Github으로부터 추출
copyleft licenses는 제외
Levenshtein distance로 file간의 비교를 통해 중복 제거
24 개의 일반적인 프로그래밍 언어에 대한 확장자로 Filtering
6.
social media conversations

Training Resources

2 TPU v4 Pods (model & data parallelism) : 3072 TPU v4 chips in each Pod attached to 768 hosts
6144 TPU V4 chips (2 cloud TPU v4 Pods; 57.8% FLOPs utilization)
pipeline parallelism or DCN
LaMDA & GLaM : only single TPU w/o pipeline parallelism or DCN
Megatron-Turing NLG 530B : 2240 A100 GPUs using model/data/pipeline parallelism
Gopher : 4 DCN-connected TPU v3 Pods (each with 1024 TPU v3 chips)
pipeline parallelism 문제점 → pipeline-free training 을 하고자 함
minibatch를 여러개의 microbatches로 나눈 뒤, pipeline을 형성하여 최대한 utilization을 높이려하지만, 여전히 idle 문제 발생 (bubble 공간; 처음과 끝 부분)
매우 큰 memory bandwidth가 필요 : weights를 메모리로부터 reloading 하는 과정을 매 micro-batch마다 진행함
Parallelism Design
standard within-pod data and model parallelism in 1 TPU v4 Pods (3072 v4 chips)
12-way model parallelism
256-way fully sharded data parallelism
forward : “data parallel axis에 대해 weight gathering” & “one fully sharded activation tensor 만 각각의 layer에 대해 저장”
backward : 각 layer별로 “나머지 activation은 rematerialized 함” (like gradient checkpoint??)
Pathways system (client-server architecture on multi-pods)
[client] 하나의 python client 에서 training batch의 반씩을 각각의 pod server에 보냄
[gradient calculation A] 각 pod server는 JAX/XLA 에 의해 forward & backward 진행 (standard within-pod data and model parallelism)
within-pod gradient reduction 포함
[client-server scheduling cost] 각각의 pod은 asynchronous gang-scheduling을 담당하는 per-pod schedulers를 보유하고 있어서, client → server로의 JAX/XLA work의 dispatch latency를 줄임
[data transfer cost] 각각의 pod은 a sharded-dataflow execution model을 통해 data transfer에 대한 cost를 분할시킴
[gradient transfer] pod server는 다른 pod server로 subgraph (gradient transfer) 를 전달
[transfer cost] 총 1536개의 hosts(1 pod 당 768 hosts) 가 2개의 pods에 걸쳐서 gradient를 주고 받음.
model-sharded parameters에 대해 각 host마다 필요한 parameter만 받으면 되기에, 1:1 transfer만 필요함
계산이 끝나면 한번에 이루어지기에, 매우 큰 (very bursty workload) 작업량이 data-center-network에 부여됨
대략 1:1 transfer (각각의 host pair) 는 1.3 GB 정도의 gradient를 매 training step마다 주고 받음 → 전체 통신량은 81 Tbps임
[optimal DCN link] 이를 극복하기 위해, 하나의 gradient transfer를 여러개의 작은 flow로 분할하고, 이를 다양한 DCN links set을 통해 전달 → 1.95x relative to the throughput (즉, 2배로 batch를 늘렸을 떄 single에 비해 97% perfect weak scaling)
3%는 backward pass 와 cross-pod gradient reduction에서 a lack of overlap으로부터 발생 → 미래에는 이것도 해결할 예정
[optimizer update B] 각 pod server는 자신의 gradient와 전달받은 gradient로 optimizer update 진행
Training Efficiency (46.2% Model FLOPS utilization)

Training Setup

weight initialization : “fan-in variance scaling” WN(0,1/nin)W \sim N(0,1/\sqrt {n_{in}})
embeddings : EN(0,1)E \sim N(0,1)
layer norm이 embedding에는 적용되지 않으므로 초기화를 layernorm 적용된 듯 적용
shared embedding을 활용하므로, pre-softmax output logits 에는 1/n1/\sqrt {n} scale을 적용함 (nn은 emb size) → embedding의 variance가 1/n이 아닌데, 왜???
Optimizer : Adafactor w/o factorization
equivalent to “Adam” with “parameter scaling” (lr을 parameter의 root-mean-square만큼 scaling 함)
weight 초기화를 1/n1/\sqrt {n}으로 하기 때문에, Adam의 manual scaling down 과 비슷함
parameter scaling은 parameter 별로 서로 다른 scale로 학습되게 하기에, 장점이 될 수 있음
the embeddings and layer norm scales은 서로 다른 scale을 가지게 되고, “learning rate scaled down at the same rate” 을 방지함
lr : 0.01 (the first 10,000 steps)
decayed at a rate of 1/k1/\sqrt {k} (kk the step number)
momentum β1=0.9\beta_{1}=0.9
2nd order moment interpolation value β2=1.0k0.8\beta_{2}=1.0-k^{-0.8}
이게 0.99로 고정하는 것보다, 큰 LM을 학습하기에 적합함
초기에는 작은 값 → 나중에는 큰 값 : 초기에는 현재 값에 집중 → 나중에서 moment 평균에 집중
이유 : 희귀한 embedding tokens은 poorly estimated second moments를 가지게 되기 때문에.. (over shorter windows)
gradient clipping : 1.0
weight decay : lr2.0lr^{2.0}
모델 사이즈가 굉장히 크기에 weight decay 값을 매우 작게 잡음
Loss function
label smoothing 활용 안함
softmax normalizer를 위한 loss 추가 : 104log2(Z)10^{-4}log^2(Z)
softmax의 normalizer Z가 1에 가까워지도록
Sequence length : 2048
학습 examples을 모두 concat 시키고, 2048 tokens 단위로 쪼개서 padding 없게 만듦
단점 : examples들이 중간에 짤림
example들은 중간 [eod] token을 활용하여 구분 됨
bs
540B : 512 (1M tokens; step 50k) → 1024 (2M tokens; step 115k) → 2048 (4M tokens; step 255k)
이유
1.
작은 batch가 more sample efficient (초기 학습시에는 본 token에 대해서 더 좋은 loss를 가지게 됨)
2.
큰 batch는 더 좋은 gradient estimates를 보유함
추가적으로, 큰 batch를 쓸수록 TPU 활용효율이 증가
Bitwise determinism (Reproducible)
1.
JAX + XLA + T5X를 통해 random 통제
1~17000 까지 진행하는 것과, 중간 15000~17000까지 진행하는 것의 결과가 정확히 일치
2.
a deterministic dataset pipeline : 오로지 step number의 function으로 데이터 batch 추출 function을 구축
where the shuffled data is written out in a random-access format
Dropout
PT : w/o dropout
fine-tuning : 0.1 dropout

Training Instability

gradient clipping 이 있음에도, 학습 동안 20번 정도 loss spike 발생
작은 모델에서는 발생하지 않는 것이, 큰 모델에서는 highly irregular intervals로 발생
학습 비용이 비싸서 이를 해결하기 위한 a principled strategy를 만들지는 않음
해결책
100 step 전 checkpoint로부터 re-start
roughly 200~500 data batch를 skip
이 연구진들은 “bad data” 때문에 발생한 것이라고 생각하지 않는다고 함.
여러번 실험해봤으나, spike 주변 batch들을 다른 checkpoint로부터 활용해봤으나 문제가 발생하지 않았기에..
결론은 특정한 state에서 특정한 batches 들이 존재할 때 발생하는 것으로 고려됨.

[Results] Language Understanding and Generation

29 English NLP few shot tasks

29개 중에 28개의 task에서 GLaM, GPT-3, Megatron Turing NLG, Gopher, Chinchilla, LaMDa 를 압도
QA (open-domain closed-book variant)
cloze and sentence-completion tasks
Winograd-style tasks
in-context reading comprehension tasks
common-sense reasoning tasks
SuperGLUE tasks
NLI tasks

Multilingual NLP benchmarks

학습 데이터의 22%만 non-English인데도, multilingual NLP benchmarks에서 좋은 성능을 보임

BIG-bench (150 new LM tasks by Google) :

58개의 subset 에 대해 실험 진행
Gopher & Chinchilla 와 비교시, log-linear 형태로 파라미터 수에 따라 성능이 상승하는 것을 확인.
performance improvements from scale have not yet plateaued
5-shot PaLM 540B 모델이 Human Avg. 보다 성능이 좋음

Big-bench Examples

Examples that showcase PaLM 540B 1-shot performance on BIG-bench tasks: labeling cause and effect, conceptual understanding, guessing movies from emoji, and finding synonyms and counterfactuals.
cause and effect 구별
conceptual combinations 이해
emoji로부터 movie를 예측
synonym game
Counterfacuals

Reasoning

Combining model scale with chain-of-thought prompting, it shows breakthrough capabilities on reasoning tasks that require multi-step arithmetic or common-sense reasoning
참고 :
GSM8k 성능 ()
8-shot prompting : 58% 성능
이전 SOTA : 55% fine-tuned GPT-3 175B with 7500 problems & an external calculator/verifier
9-12 year olds 들의 평균 : 60%
특히, “separate encoding of digits in subword” 가 성능향상에 큰 영향력을 가짐
어려운 reasoning 에도 적합 (multi-step logical, world knowledge, deep NLU)
explicit explanations for scenarios that require a complex combination
multi-step logical inference, world knowledge, deep learning understanding 들이 필요한 복잡한 조합의 task에도 좋은 성능을 보임
예시: high quality explanations for novel jokes not found on the web. (a 2-shot prompts example)

Code Generation

few-shot 성능
5% 정도의 데이터만 code를 포함시켰지만, 우수한 성능
에서 큰 모델이 좀 더 sample efficient 하다는 결과를 뒤받침 (NL data와 programming language data로부터 transfer learning 을 효과적으로 진행)
few-shot 성능이 fine-tuned Codex 12B 과 비슷함
50배 적은 training data를 활용했음에도.. 비슷
Examples
Examples of a fine-tuned PaLM 540B model on text-to-code tasks, such as GSM8K- Python and HumanEval, and code-to-code tasks, such as Transcoder
[PaLM-Coder] finetuning PaLM 성능 (on a Python-only code dataset)
code repair task : DeepFix
a complie rate of 82.1%
기존 SoTA : 71.7% (Stanford, Break-It-Fix-It)
compiler diagnostic messages as a graph → GNN → predict a diff (seq. of tokens & pointers to code location)

[Ethical Consideration]

가능한 위험을 알기위해 paper 및 모델을 공개할 때, model cards / datasheets / Responsible AI benchmark results 공개
biases and risks 에 대한 모델 결과 및 데이터셋에 대한 철저한 분석 포함
domain-, task-specific analysis → truly calibrate, contextualize, mitigate possible harms
topics of ongoing research
risks & benefits 에 대한 더 깊은 이해
악용을 막기 위한 scalable solution을 개발하는 것