Injecting Numerical Reasoning Skills into Language Models を読んだ

目次

リンク
論文:[2004.04487] Injecting Numerical Reasoning Skills into Language Models
コード:GitHub - ag1988/injecting_numeracy: The accompanying code for "Injecting Numerical Reasoning Skills into Language Models" (Mor Geva*, Ankit Gupta* and Jonathan Berant, ACL 2020).

比較手法のMTMSN:GitHub - huminghao16/MTMSN: A Multi-Type Multi-Span Network for Reading Comprehension that Requires Discrete Reasoning
比較手法のNABERT+:GitHub - raylin1000/drop-bert: NABERT model for solving the DROP dataset
評価用データセットDROPの論文:[1903.00161] DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs

概要

  • LanguageModelに数値推論を注入する汎用モデル(GenBERT)と、それをpre-trainingするための学習用データを生成するためのフレームワークを提案
  • 既存の数値推論を評価用データセット(NRoT; Numerical Reasoning over Text)にてSOTAのモデルと同程度の性能を発揮することを示した

数値推論の例

f:id:wwacky:20210214145110p:plain:w400
※ Table 1. Injecting Numerical Reasoning Skills into Language Modelsより引用

数値推論の質問は複数種類あり、Passage(Contextともいう)やQuestionから該当する文字列を抽出すれば良いspanや、数値計算が必要なnumberがある。他にも複数のフレーズの抽出が必要なspansや、日付の計算が必要なdateがある。

提案手法の内容

複数のheadを用いて回答を行う。headの種類は3つ。

  • answer type head: 回答方法を決定する(context span, question span, decodeのマルチクラス分類)
  • two span-extraction heads: ContextかQuestionから回答を抽出する(NER)
  • generative head: 回答を生成する
f:id:wwacky:20210215011642p:plain
※ Figure 2. Injecting Numerical Reasoning Skills into Language Modelsより

モデルのLossには以下を用いる。*1
f:id:wwacky:20210215011807p:plain

EncoderのBERTとの差分

Digit Tokenization(DT)

数値計算用に数値を1文字ずつのwordpieceにする
f:id:wwacky:20210215155508p:plain:w500

Random Shift(RS)

“1086.1 - 2.54 + 343.8”の様な短いテキストを学習に用いるため、数値が先頭に来る場合に数値推論を行うようにオーバーフィットしてしまう。これを避けるためにposition IDをランダムにする*2

学習・推論の全体像
pre-train済みのLanguageModel(BERT)をNumerical Data(ND)とTextual Data(TD)を用いて追加pre-trainする。
数値推論を行う場合はspan extraction headsとgenerative headの両方をfine-tuningする。 SQuADの様なテキストのみの質疑応答の場合はspan extraction headsのみを用いる。

f:id:wwacky:20210214220713p:plain:w400
※ Figure 1. Injecting Numerical Reasoning Skills into Language Modelsより引用

pre-trainingの方法

pre-training用のNumerical Data (ND)の生成

テンプレートを用いて数値演算のテキストを生成する。テンプレートはTable 2を参照。

f:id:wwacky:20210213210218p:plain
※ Table 2. Injecting Numerical Reasoning Skills into Language Modelsより引用
pre-training用のTextual Data (TD)の生成

既往研究*3を用いて学習用データを生成する。生成されるデータの例はTable 3を参照*4

f:id:wwacky:20210213212039p:plain:w400
※ Table 3. Injecting Numerical Reasoning Skills into Language Modelsより引用
追加pre-trainingの方法

追加pre-trainingは学習済みのBERTに対して、NDとTDを用いて行う。ただし、追加pre-trainingは元のLMの言語情報の「破局的忘却」*5を避けるために、MLMも含めたマルチタスクで行う。
Figure 1では追加pre-trainingをND→TDの順に行うかのように記載されているが、実際はmini-batches作成時にそれぞれのデータセットからサンプリングして同時に行う。MLMに使うデータセットは論文中のA.3を参照。

f:id:wwacky:20210215010433p:plain

数値実験

ND・TDによる追加pre-trainの効果

各モデルをDROPのデータでfine-tuningした結果で比較。追加pre-trainなしだとEMが46.1と低い。ND+TDで追加pre-trainすることでSOTAのMTMSNと同程度の性能になる*6
Table 4からLM(追加pre-train時のMLM)やRandom Shiftも効果があることが分かる。

f:id:wwacky:20210214232605p:plain
※ Figure 4. Injecting Numerical Reasoning Skills into Language Modelsより引用

Digit Tokenizationの効果

pre-trainingのstep数毎のaccuracy。Digit Tokenizationが無い場合のみAccが低く、精度改善に寄与している事がわかる。

f:id:wwacky:20210215155626p:plain:w500
※ Figure 4. Injecting Numerical Reasoning Skills into Language Modelsより引用

言語理解能力を失っていないかの確認

数値推論が不要な質疑応答(SQuAD)でfine-tuningして評価を行う。
GENBERTは数値推論を行いつつ、言語理解能力も失わずBERTと同程度であることが分かる。

f:id:wwacky:20210215155724p:plain
※ Table 7. Injecting Numerical Reasoning Skills into Language Modelsより引用

GENBERTの重みは、GENBERT以外のアーキテクチャでも利用できるか?

NABERT+とEfratら(2019)のMS-TAGのエンコーダの初期化にGENBERT+ND+TDの重みを使用。DROP上でのfine-tuning有りで評価した結果、両方とも2ポイントEMが向上する。これにより、GENBERTがBERTの代替として利用できることが示されている。
f:id:wwacky:20210215004251p:plain









補足

pre-training、fine-tuning時のパラメータ

f:id:wwacky:20210215010815p:plain
※ Table 10. Injecting Numerical Reasoning Skills into Language Modelsより引用


DROPの質問種別毎の精度
数値計算(number)では予想通りMTMSNより精度が高い。
spanでもMTMSNより精度が高いが、これは内部的に数値計算を実行した上でspanを回答できるためだと考えられる。
spansは答えが連続しないフレーズのリストの質問だが、MTMSNはGENBERTを大幅に上回る。これはMTMSNはspans用の専用ヘッドを持っているが、GENBERTは標準的なReasoning Comprehension用のヘッドしか持たないためである。

f:id:wwacky:20210213224117p:plain
※ Table 5. Injecting Numerical Reasoning Skills into Language Modelsより引用

数値理解能力の検証
数学の単語問題(MWP)データセットのコレクション(MAWPS)を用いて評価する。fine-tuningは行わずにzero-shotで評価を行う。

GENBERT+ND+TDは、GENBERT(BERTのママ)と比較して劇的に性能を向上させる。GENBERT+NDはGENBERT+TDよりもはるかに優れた性能を発揮し、コンテキストが短い場合のNDの有用性を示している。
MTMSNはGENBERT+ND+TDを上回る。MTMSNは足し算と引き算に特化したアーキテクチャを使用しており、モデル外で計算が行われる場合に適している。

f:id:wwacky:20210214094708p:plain
※ Table 6. Injecting Numerical Reasoning Skills into Language Modelsより引用

次に、演算式の項の数で性能を比較する(図5)。図からすべてのモデルがより複雑な問題に一般化するのに苦労しており、項が4以上になると完全に失敗している。GENBERT+ND+TDの2項と3項の間の性能低下は、GENBERT+NDとGENBERT+TDよりも有意に小さい。このことはNDとTDの両方がロバスト性を向上させるのに有効であることを示唆している。

f:id:wwacky:20210214100923p:plain
※ Figure 5. Injecting Numerical Reasoning Skills into Language Modelsより引用

Error analysis
DROPのdevセットのGENBERT+ND+TDの誤答を分析する。モデルがサポートしていないマルチスパンの解答を持つ問題を除外した上で、GENBERT+ND+TDが誤答したものを100例のランダムサンプリングし、誤答の種類を手動で分類した。具体的な例は以下のTable 11を参照。

f:id:wwacky:20210215002614p:plain
※ Table 11. Injecting Numerical Reasoning Skills into Language Modelsより引用

ほぼ半数のケース(43%)では、事前学習課題(ソートなど)でカバーされていないか、数値的ではない推論スキルを必要としているものだった。もう一つの一般的なケース(23%)は抽出したスパンが長すぎたり、予測値とgold answerと数字が部分的に一致する場合であった。これらのエラーの多くは、事前学習課題を拡張して追加の数値スキルやより大きな数値範囲をカバーすることで対処できると思われる。

*1:このLossだとspan extractionで抽出される文字列をdecoderで生成することも可能だが、span extractionで解くべき問題をdecoderで解いた場合にどういう計算になるのかは不明。またdecoderで解くべき問題の時に正しいspanを選択する確率がどうなるのか不明。

*2:Random Shiftで設定する値はwhen the input length n1 + n2 + 3 < 512, we shift all position IDs by a random integer in \(0, 1, . . . , 512 − \(n1 + n2 + 3\)

*3:Hosseini et al. 2014. Learning to Solve Arithmetic Word Problems with Verb Categorization - ACL Anthology

*4:データ生成は品詞に抽象化したテンプレートに単語を当てはめるような手法を用いている。Figure 3を参照。

*5:Kirkpatrick, 2017 [1612.00796] Overcoming catastrophic forgetting in neural networks

*6:SOTAはBERT-largeをベースにしたMTMSNだが、比較のために今回はBERT-baseをベースにしたMTMSNを比較対象としている