Search
🔁

[논문리뷰] REALM: Retrieval-Augmented Language Model Pre-Training

Created
3/7/2021, 12:26:24 PM
Tags
Pretraining
💡
논문 리뷰: Embedding 기반의 Retrieval을 위한 Pretrainig 테스크

소개

Language Model에서는 학습된 세상의 지식을 파라미터(parameter) 안에 암시적으로 저장하고 있지만, 각 지식이 어디에 어떻게 저장되어 있는지 구분해 내는 것은 불가능에 가깝습니다. 또한 모델의 사이즈에 의해서 저장 공간이 결정되기 때문에, 단순히 더 많은 지식을 담기 위해서 위해 모델의 크기를 키우는 것은 메모리, 연산량 관점에서 비효율적입니다. 이러한 한계점을 해결하기 위해서 구글 리서치의 연구진들로 부터 제안된 논문에서는 Retrieval-Augmented Language Model (REALM)이라고 하는 새로운 사전학습 방법을 제안합니다.
제안하는 방법은 모델의 파라미터에 지식을 저장하는 방법이 아닌, 별도의 모델이 어떤 지식을 retrieve 할지 결정하고,inference 시에 retrieve 된 지식을 사용하도록 하는 방법을 사용합니다. REALM의 핵심 아이디어는 retrieval된 결과가 language model의 PPL을 향상시키는데 도움을 준다면 보상(reward)를 주고, 또는 부정확한 정보 또는 의미없는 결과를 retrieve 했다면 처벌(penalize)를 시키는 방식으로 retrieval 모델을 학습하는 것입니다. 예를 들어서 “여러 layer의 transformer 를 이용하는 ___ 는 여러 NLU테스크에서 우수한 성능을 보여주었다”라는 문장에서 모델이 빈칸을 추론해야 한다고 할 때, retrieval 모델이 reward를 받기 위해서는 “BERT는 transformer를 기반으로 한 언어 이해 모델이다”와 같은 문서를 retrieve 해야 하는 것입니다. 저자들은 위와 같은 행위가 “retrieve-then-predict” 구조의 모델링으로 달성될 수 있음을 주장하며, 구체적으로 latent language model과 marginal likelihood 를 사용하는 방법을 제안합니다.
하지만 Large-Scale의 neural retrieval module을 pre-training 과정에서 사용하는 것은 연산적으로 보았을 때 어려움이 많은 시도입니다. 왜냐하면 매 스탭마다 수백만개의 후보 문서들을 retrieval 후보로 고려해야 하고 각 decision에 대해서 back-propagation을 수행해야 하기 때문입니다. 이 문제를 해결하기 위해서 본 논문에서는 retrieval 모델이 비동기적(asynchronously)으로 업데이트 될 수 있고, 후보 document 들을 미리 caching 해 놓을 수 있으며, Maximum Inner Product Search(MIPS)으로 바로 retrieval을 수행할 수 있도록 하는 방법을 대안으로 제시합니다. 사실 이러한 접근은 최근에 리뷰를 진행했던 Lee et al., 2020에서도 유사한 제약조건을 볼 수 있었는데, 최근에 발표되는 Large Scale Retrieval 연구들의 필수 제약 조건이라고도 볼 수 있을 것 같습니다.

접근 방법

제안되는 방법은 크게 Retrieval과 Knowledge Augmented Encoder 모듈로 나누어져 있는데요, 이를 수식적으로 표현하면 다음과 같습니다. Z는 Retrieval 될 수 있는 후보군이고, x와 y는 각각 masking 된 입력과 masking 된 부분의 원본 토큰입니다. 즉 주어진 입력에 대해서 관련된 문서를 찾는 p(z|x)와, retrieval된 문서를 고려해서 mask된 토큰을 추론하는 p(y|z,x)가 합쳐진 형태입니다.
p(yx)=zZp(yz,x)p(zx)p(y|x) = \sum_{z \in Z}p(y|z,x)p(z|x)

Knowledge Retrieval

우선 Retrieval 모델을 간략하게 설명하면, document 와 query를 BERT의 입력으로 넣고 [CLS]토큰 위치의 pooled-output representation을 사용합니다. 두 representation간의 유사도는 dot-product를 이용해 계산되며, 전체 retrieval 후보군 중에서 실제 관련이 있는 문서가 가장 높은 score를 갖도록 cross-entropy를 이용해 학습하게 됩니다. 위에 설명드린 내용을 수식적으로 표현한다면 아래와 같습니다. 물론 사전학습시에 대량의 retrieval 후보군들에 대한 score를 계산하는 것은 연산량적으로 불가능하기 때문에, ANN search 를 통해서 retrieve 된 top-k개의 후보군에 대해서만 score 계산을 수행하였습니다.

Knowledge Augmented Encoder

두번째로 masking된 입력과 관련이 있는 문서들을 retrieval 하고, 관련 문서들을 참조해서 mask에 들어갈 적절한 토큰을 예측해야 하는 단계로 넘어갑니다. retrieval 된 문서를 참조하는 방법으로는 단순하게 [SEP]토큰으로 masking 된 입력과 관련 문서를 concat하도록 합니다. 그런후 기존의 BERT와 동일한 방식으로 masked lm을 학습시키는 방식을 사용합니다. 이러한 방법은 위에서 언급했던 retrieval의 성능이 MLM의 성능에 영향을 주기 떄문에, masked LM을 학습시키는 메커니즘이 자연스럽게 retrieval의 성능을 높일 수 있게 됩니다.

What does the retriever learn?

(사실 아직까지 이 부분에 대해서 완벽하게 이해가 되지는 않았습니다.)
하지만 이러한 방법은 직접적으로 retrieval에게 reward와 penalty를 준다고 볼 수 없습니다. 때문에 저자는 다음과 같은 방법을 이용해서 retrieval에 signal을 주는 방법을 제안합니다. 이에 대한 구체적인 수식은 다음과 같습니다. 이 방법은 r(z)에 따라서 gradient가 retrieval의 f(x,z)를 변경할 수 있도록 유도하는 방법입니다. - r(z)가 양수인 경우에 f(x,z)는 향상되며, 음수인 경우에는 f(x,z)가 줄어들게 됩니다.
만약 모델이 정답에 가깝다면 p(y|z,x)>p(y|x)일 것이기 때문에 r(z)는 양수가 될 것이고 이는 f(x,z)가 커지게 만 것입니다. 이때 p(y|x)는 랜덤한 문서를 주었을 때의 MLM이 성공할 확률이며, p(y|z,x)는 관련있는 문서를 주었을 때의 MLM 성공 확률일 것입니다. 따라서 retrieval 모델이 보다 관련있는 문서를 뽑을 수록, retrieval 모델에게 긍정적인 시그널을 보내게 됩니다.
(참고)다른 블로그 리뷰 에서는 다음과 같이 리뷰를 해 주셨습니다.

Implementation Details

Salient span masking: 사전학습 하는 과정에서 REALM이 mask token을 예측하기 위해서 별도의 지식이 필요로 하는 예제들을 주로 고려하기를 희망하였습니다. 하지만 기존의 random masking은 주로 local context를 고려하는 것에 초점이 맞춰져 있었습니다. 논문의 저자들은 이를 위해서 사전에 학습된 NER 모델을 이용해서 Named Entity로 판단된 span을 모두 masking 하는 방식을 선택하였습니다. 만약에 문장에 여러 span이 존재한다면 그 중 하나만을 선택하여 masking을 수행합니다.
Null document: 위와 같은 masking을 진행한다고 해서 모든 span이 외부 지식이 필요로 하는 것은 아닐 것입니다. 때문에 사전학습시에 아무런 document도 주지 않는 예제들을 포함시키었습니다.
Prohibiting trivial retrievals: 만약에 사전 학습에 사용된 질의 코퍼스와 지식을 참조하는 코퍼스가 동일하다면, 너무 분명한(trivial) 후보들이 뽑힐 가능성이 있습니다. Retrieval 모델이 의미적인 정보를 고려하지 않는 exact match형태의 모델이 될 수 있는 문제가 있습니다. 때문에 corpus내에서 너무 문장이 일치하는 knowledge-query 페어를 배제하고 사전학습을 수행하였습니다.
Initialization: 사전학습에 만약 전혀 학습되지 않는 retrieval 모델을 사용한다면 MLM에 있어도 좋은 영향을 끼치지 않을 것입니다. 때문에 본 보델에서는 Inverse Cloze Task(ICT)로 사전 학습된 retrieval 모델을 사용하였습니다. 또한 보다 빠른 학습을 수행하기 위해서 knowledge-augmented encoder 역시 BERT-base를 pretrained 모델로 사용하여 학습을 수행하였습니다.
사전학습을 64개의 TPU를 사용하여 512-batch로 200k step을 수행하였습니다. Retrieval 모델의 index 업데이터는 16개의 별도의 TPU를 사용해 비동기적으로 진행하였습니다. 사전 학습시 top-8개의 retrieval document 를 사용하였고, null document도 추가적으로 사용하였습니다.

Experiment

저자들은 Open-QA 데이터셋에 대해서 기존의 two-stage 기반 베이스라인 모델들과 제안하는 방법의 성능을 비교 분석하였습니다. 위 표를 보면 알 수 있듯이, 기존의 방법들에 비해서 큰 성능 향상을 보여 주었으며 동일한 Dense Retrieval + Transformer 접근을 사용한 ORQA보다도 약 7% 높은 성능 향상을 보여 주었습니다. 재미있는 점은 대량의 파라미터를 보유한 T5모델을 QA에 사용한 성능과도 비교 분석을 수행하였는데, 약 30배나 작은 모델임에도 6%나 되는 높은 성능 향상을 보여주었습니다. 테이블에 명시 되어 있는 각 테스크는 다음과 같습니다. ( NQ: NaturalQuestions-Open, WQ: WebQuestions, CT: CuratedTrec)

결론

MLM 학습시에 Retrieval 모델까지 back-prob이 되도록 REALM모델 구조를 제안함
사전 학습시에 지속적으로 retrieval 모델을 업데이트하고, 관련 후보를 매 스텝마다 뽑음
두 모델을 같이 학습시켰더니 성능을 크게 향상시킬 수 있었음