动机
- 基于
AR
的语言模型通过autoregressive model
将估计文本语料库的概率分布,将似然因式分解为前向乘积或后向乘积,并对每个条件分布进行建模。其通常仅通过单向上下文进行预训练,而下游任务通常需要双向的上下文信息; - 基于
AE
的语言模型通过denoising autoencoder
从损坏的输入中重建原始数据,由于不执行显式的密度估计,因此可利用双向上下文进行重建,并消除了预训练和下游任务的双向信息差异。但其引入了mask
造成预训练和微调时输入的差异,且假定对于给定的unmasked tokens
,每个预测的tokens
彼此独立,该假设过度简化了。
贡献
- 提出一种通用的自回归模型
XLNET
,该方法同时利用了AR
和AE
语言模型的优点,一方面利用双向上下文信息,另一方面不会造成预训练和微调差异,且避免了BERT
关于mask
的独立性假设; XLNET
对Transformer-XL
重参数化,将其segment recurrence mechanism
和relative encoding scheme
集成到预训练中;XLNET
在多个任务上实现了最先进的效果,包括语言理解、阅读理解、文本分类、文档排名等任务。
XLNET
背景
给定文本序列,AR
语言模型通过将最大化似然估计前向自回归分解:
其中,表示由神经网络得到的上下文表征,表示的embedding
。BERT
则基于denoising autoencoder
,首先将中的一部分随机设置为mask
构造得到,设被mask
的tokens
为,其训练目标为从重建:
其中,表示被mask
,表示通过Transformer
将长度为的文本序列映射为隐变量序列。AE
和AR
的预训练语言模型主要差异如下:
Independence Assumption
:BERT
基于所有被mask
的token
独立重构的假设,对联合条件概率进行近似的因式分解,而AR
则使用普遍适用的乘积规则对进行分解,而没有这种独立性假设;Input noise
:BERT
输入包含[MASK]
,在下游任务中未出现会导致预训练和微调的差异,原文以一定概率使用原始token
替换[MASK]
并不能解决该问题,而AR
不依赖任何输入损坏,则不会出现该问题;Context dependency
:AR
中仅以单侧的上下文信息为条件,而BERT
中可以访问双向的上下文信息,允许模型更好地捕获双向上下文信息。
目标函数
使用
Permutation Language Modeling
综合AR
和BERT
方法的优点
借鉴orderless NADE
的方法,提出排列语言建模(Permutation Language Modeling
)的目标函数。对于长度为的序列,有种不同顺序执行自回归分解,若模型参数能够在所有因式分解顺序之间共享,则理想情况下,模型可以学习从双向的所有位置收集信息。
对于长度为的序列索引,假设为所有可能排列的集合,则本文使用的目标函数为:
其中,和分别表示一个排列的第个元素和前个元素。即对于序列,每次采样一个因式分解的顺序,并通过该顺序求解似然概率。由于模型参数在所有因式分解顺序中共享,则理想情况下可以看到序列中的每个可能元素,因此能够捕获双向的上下文。并且该模型基于AR
,自然避免了独立性假设和预训练微调差异的影响。
本文在实现时,保持原始的序列排序及对应于原始序列的位置编码,通过Transformer
中的attention mask
实现因式分解的排列顺序,这样不会改变输入方式,从而不会影响微调时的输入。
网络结构
基于
Target-Aware
表征的Two-Stream Self-Attention
Target-Aware Representation
虽然permutation language modeling
能满足目前的目标,但简单地使用标准的Transformer
并不一定有效。假设使用标准softmax
参数化下一个token
的分布:
其中,表示的隐特征,由mask
的输入经过transformer
得到。不依赖于要预测的token
的位置,但可能存在即使目标位置的不同,其因式分解的结果一致,则经transformer
都预测得到相同的分布,由此无法学习到有效的表征。
即假设两个排列和,满足,但是,则:
即对于不同位置和具有相同的模型预测结果,不符合预期。为了解决该问题,文章重参数化下一个token
的分布,使其能感知目标位置:
其中,为增加目标位置作为额外输入的新的表征形式。
Two-Stream Self-Attention
虽然,target-aware representation
能消除目标预测的模糊性,但如何构建$g_{\theta}\left(\mathbf{x}_{\mathbf{z}_{content
,以提供完整的上下文信息。
为了解决此矛盾,本文使用两组隐表征(即每个词具有两个属性,预测自己时不需要自己的内容信息,预测其他词时作为context
需要内容信息):
content
表征,缩写为,其作用等同于传统的Transformer
,由context
和编码得到;query
表征,缩写为,其只接触context
信息和位置,跟content
无关。
实际计算时,第一层网络的query stream
被初始化为可训练的向量,content stream
被设置为对应的word embedding
。对于第层自注意力层,two stream
表征共享参数进行更新:
如图所示,query stream
使用,但不使用,而content stream
同时使用和,最终使用最后一层的query
表征计算目标函数。可知content
表征与传统self-attention
一致,在微调阶段,舍弃掉query stream
而只使用content stream
,故微调阶段与BERT
一致。
Partial Prediction
全排列导致实验收敛缓慢,故仅预测因式分解顺序中最后的部分tokens
,将以为切分点分为非目标子序列和目标序列,故目标函数是最大化非目标子序列条件下目标子序列的对数似然:
选用作为目标序列,是因为在给定当前因式分解顺序时,其拥有最长的context
。对于非目标的tokens
,则不需要计算query
表征,可以加速和减少运行内存。
借鉴自Transformer-XL
的想法
因为本文使用AR
框架的目标函数,故可引入Transformer-XL
作为预训练框架,并集成相对位置编码(relative positional encoding scheme
) 和片段循环机制(segment recurrence mechanism
)。
segment recurrence mechanism
主要用于解决超长序列的依赖问题,对于特别长的序列会导致丢失一些信息,
Transformer-XL
即将序列分割为多个segment
,计算后面的segment
时依赖前一segment
的隐特征。
假设长序列的两个segment
和,令和分别表示和的排列。基于排列处理第一段segment
,缓存层获得的content
表征,则对于下一个segment
,注意力更新变为:
其中,表示concate
,因为位置编码仅取决于原始序列中的实际位置,则获得后,上式跟无关,使得在不知道前面的segment
时可以缓存和使用之前segment
的隐状态。
relative segment encodings
BERT
在每个位置的word embedding
使用绝对segment
编码,与当前内容在原始序列的相对位置没有关系,导致了位置信息的损失。使用relative segment encodings
引入inductive bias
提升泛化,且对于两个以上输入segment
的任务可以进行微调。
给定序列中的一对位置和,如果和来自相同的segment
,则定义segment encoding
,否则为,其中和是每个attention head
学习的模型参数。即只考虑两个位置是否来自同一segment
,而不是是否来自特定的segment
。引入相对编码后,计算注意力的权重,其中是标准attention
的query
向量,是head
的偏置向量,最后将与传统的注意力权重相加。
与BERT
的比较
预训练阶段与BERT
类似,输入形式为,但未使用NSP
任务,只使用了来自同一上下文的序列。
通过比较目标函数,BERT
和XLNET
均使用部分预测。但BERT
的独立性假设,使得BERT
无法建模预测目标之间的依赖关系。例如输入序列为[New, York, is, a, city]
,假设BERT
和XLNET
均选择[New, York]
为预测目标,以及XLNET
的因式分解顺序为[is, a, city, New, York]
,均最大化目标函数:
即XLNET
能捕获[New York]
的依赖对信息,虽然BERT
也能学习一些依赖对,但XLNET
在相同目标情况下能学习更多的依赖对,包含更密集的训练信息。更一般地,给定序列,定义target-context
对,其中是的上下文。给定目标tokens
和非目标tokens
,BERT
和XLNET
均最大化:
其中,表示中顺序在之前的tokens
。如果且,则XLNET
能覆盖更多的依赖项,则XLNET
能获得更多有效的训练信息,使得获得更好的效果。
训练过程
数据集:BooksCorpus
,English Wikipedia
,Giga5
,ClueWeb 2012-B
和Common Crawl
分词: SentencePiece
训练细节:
batchsize
为,token
长度为,使用Adam
优化器和线性学习率衰减。设置partial prediction
的超参数,微调阶段与BERT
大致相同。本文使用span-based prediction
,即先采样,并选择连续的包含个tokens
的span
作为预测目标,上下文则包含个tokens
。
部分实验结果
Ablation Study:
如图所示:
permutation language modeling
优于denoising auto-encoding
;Transformer-XL
作为backbone
的有效性;memory
(即segment recurrence mechanism
)对长文本任务影响较大;span-based prediction
和bidirectional data
也起着重要作用;next sentence prediction
无明显影响;