大模型长文本建模的难点与方案

本文于7月19日在鹅厂内部发表,为了方便引用进行公开。欢迎各位同行指正。1. 引言

去年年底 ChatGPT 的问世掀起了预训练语言大模型(Large Language Model,LLM,下文简称大模型)的热潮并持续至今,包括国内外各大公司在大模型方向的技术探索,以及各行各业的用户基于大模型的应用探索,以及一时间的洛阳"卡"贵。

在诸如金融、司法、科研等特定领域或特定场景,经常需要对几十页甚至几百页的报告进行理解,完成长文档摘要总结、基于长文档的阅读理解、基于长文档的问答等任务,或者需要同时参考几十篇文章,完成多文档的信息对比分析以及关键信息总结。因此,长文本建模能力是大模型在这些领域或场景下顺利应用的前提条件。

除了上述特定领域,在通用领域长文本建模能力也具有极其重要的意义。对大模型而言,具备更长文本的建模能力意味着模型可以观察到更长的上下文,可以避免因观察窗口限制导致重要信息的丢失。大模型的一个重要能力是上下文学习能力(In-Context Learning),模型支持更长的上下文意味着通过上下文学习时可以给模型输入更多 few-shot 样例,或者使得以往因长度限制无法进行 few-shot 甚至 one-shot 的任务可以通过 few-shot 进行上下文学习并取得更好的效果。

目前,部分关注度较高的模型可接受的上下文长度如下表所示:

公司/机构模型上下文 tokens
OpenAIGPT3.54k ~ 16k
OpenAIGPT48k ~ 32k
AnthropicClaude100k
GoogleBard4k?
MetaLLaMA2k
MetaLLaMA24k
2. 长文本建模的难点

2017年出现的 Transformer [1] 结构展现出了强大的语言及多模态的建模能力,因此在预训练语言模型上被迅速采用并成为主流的深度网络结构。本文所讨论的长文本建模,只涉及在Transformer 结构的模型上的难点,不涉及其他深度学习模型结构(如CNN、RNN等)。

按照系统论的观点,结构决定功能,因此任何系统的特性都需要首先从其结构去分析。

Vanilla Transformer 结构如下图:

在Transformer中 Attention 计算是关键。 下图展示了Attention 和 MultiHead Attention:

相应的计算方法分别为:

算法复杂度为:

上面表格中, n 是输入长度, d 是Embedding维度。

简而言之,self-attention 使得同一上下文的每一个token都能够观察到上下文任意位置的 token(decoder 的 self-attention 是每个token只能观察到上下文任意历史位置的token),这一结构特点使得Transformer 较之CNN、RNN等模型结构理论上显著提升了长距离依赖的捕捉能力,无数实验也证实了这种结构可以提升最终的效果。付出的代价就是,与此同时计算复杂度也增长为平方级。计算时间和计算资源的制约是 n 无法迅速增大的直接因素,也就是在平方级复杂度制约下,模型上下文难以随心所欲地增长。

3. 长文本建模的方案

针对长文本建模的难点,目前主要有3种不同的解决方案。

第一种方案是借助模型外部工具辅助处理长文本或者利用外部记忆(external memory)存储过长的上下文向量,可以称为外部召回的方案;

第二种方案是利用模型优化的一般方法;

第三种方案是优化Attention的计算。

下面对这些方案进行总结,主要介绍各类方法的思想,具体效果可以查看相关论文。

3.1. 外部召回的方法

借助模型外部的工具辅助处理长文本或者利用模型外部记忆(external memory)存储过长的上下文向量,是一种性价比较高的方案。主要思路是,将长文本切分为若干份长度适合的短文本片段并放入数据库或检索系统中。这里的适合长度是指模型能够轻松处理的合理长度。模型在处理长文本时,根据具体问题对外部的短文本片段进行检索,得到最相关的一个或多个短文本片段,每次只加载所需要的短文本片段,从而避开了模型无法一次读入整个长文本的问题。

该方案的整体流程如下图所示:

上述流程可以归纳为切分、索引、查询、生成4个阶段。

切分阶段,将长文本切分为短片段时,需要注意切分后的短片段保持意义完整和相对独立。例如,不要把同一个表格内容切分到不同的片段。

索引阶段,在得到长度适合的若干短文本片段后,通过调用预先训练好的语言模型,计算得到短片段的Embedding并存储到向量数据库或向量索引系统中。

查询阶段,当用户输入Query时,利用上述训练好的语言模型计算得到Query的Embedding,在向量数据库或向量索引中召回近似向量,得到与Query相关或相似的短文本片段。

生成阶段,将Query与相关短文本片段同时作为大模型的输入,得到最终的结果。

LangChain 提供了支持上述流程的丰富的辅助工具。

需要注意的是,上述图中的方案采用的是向量召回的方案。除此之外,传统的倒排索引方案也可以作为召回的补充方案或者替代方案。Faiss可以作为向量召回的解决方案,Elasticsearch 支持倒排索引召回,也同时支持向量召回。

外部召回的方法在流程上与搜索增强的方法本质相同。这两者的主要区别在于,外部召回的方法主要是着眼于解决长文本难以一次完整输入大模型难点;搜索增强的方法重要着眼于解决大模型的信息不足或信息不能及时更新的问题,通过搜索引擎或外部知识库检索获取相关资料补充给大模型。

REALM(REtrieval Augmented Language Model)[2] 是一种将检索结果引入大模型,补充模型信息不足的方案,可以有效提升模型的理解效果。Memorizing Transformer [3] 也使用模型外部记忆作为补充。具体为,将长文本按512的长度切分,依次输入Transformer得到向量并存储于模型外部,在Transformer 靠近output的layer,将模型当前处理的向量与通过kNN 从外部获取的若干向量进行拼接。如下图所示:

3.2. 模型优化的一般方法

一般性的模型优化方法致力于降低模型计算时间复杂度或空间复杂度,这一研究方向一般称为 Model Compression/ Acceleration。虽然不是针对长文本建模的专门优化,但是一般性的模型优化节约出来的算力/存储空间可以用于更长的文本建模,对问题解决有一定的帮助。

常用优化方法包括量化(Quantization)、剪枝(Pruning)、蒸馏(Distillation)、参数共享(Weight Sharing)、矩阵分解(Factorization)。

除了算法层面优化,FlashAttention [4]在硬件层面的优化也大幅节约了计算时间,得到广泛应用。

除此之外,考虑 Transformer 之外的其他深度结构,在具体场景下也不失为一种明智选择。

这些方法在一定程度上能够降低计算的时间复杂度和空间复杂度,使得模型可以支持更长的输入数据。

3.3. Attention计算优化

专注于降低Transformer模型中 Attention 的计算复杂度和空间复杂度,能直接提升模型可以处理的上下文长度。这一研究方向属于 Efficient Transformers。

根据 Tay, Yi et al [5] 的综述,Efficient Transformers 可以划分为如下类型:

本文将对长文本建模直接有效的优化方法概括为如下类别:

3.3.1. Recurrence

这类方法以 Transformer-XL [6] 为代表。Vanilla Transformer 对超过固定长度的长文本,会按固定长度将长文本切分为若干片段再分别处理。这种做法没有考虑到切分后的片段之间存在的关联,导致上下文碎片化(context fragmentation)问题。如下图所示:

Transform-XL 借鉴了RNN的思想,提出 Segment-Level Recurrence 机制解决该问题,包括2个重要技术点,一是 Segment-Level Recurrence with State Reuse,二是 Relative Positional Encodings。

Segment-Level Recurrence with State Reuse 是指处理当前片段时,同时使用当前片段信息和缓存的上一片段信息。假设前一片段和当前处理片段为:

计算当前片段的状态时,同时使用当前片段信息 和 。

具体计算过程如下:

公式中的SG 代表不使用梯度,即处理当前片段时,前一片段的参数全部 fixed。 为了避免使用绝对位置时不同 Segment 会有相同的位置编码的问题,Transform-XL 提出了一种使用相对编码的方法(Relative Positional Encodings)。

绝对位置编码的计算公式展开后如下:

相对位置编码的计算如下(彩色部分是与绝对位置编码不同点):

Recurrent Memory Transformer [7] 也采用了 Segment-Level Recurrence 的机制处理长文。与 Transformer-XL 最大的区别有2点,1是memory参与梯度计算,2是memory 向量有 m 个,而Transformer-XL 是 mN 个。

3.3.2. Sparse Patterns

这类方法的核心思想是,通过稀疏模式减少一个token与其他位置token关联,或者说将 tokens之间的关联稀疏化,从而在不显著降低效果的前提下从整体上降低计算复杂度。典型的方法包括 Sparse Transformer、 Longformer、BigBird、LongNet。

Sparse Transformer 的 Attention 示意图(灰色/白色表示不存在关联,无需计算):

Longformer 的示意图:

BigBird 的示意图:

LongNet 示意图:

从上面的图例可以看出,Vanilla Transformer 采用的是 full self-attention,即假设每一个位置的token都可能与任意位置的上下文发生关联并进行Attention计算,这导致了复杂度为 O( ) 。这里N为序列长度,d为向量维度。

作为一种改进方法,Sparse Transformers [8] 放弃任意距离的token之间都会发生关联的假设,转而采用局部范围内 token之间关联密切,远距离的token之间的关联稀疏的假设,从而大幅降低计算复杂度。该方法又被称为 factorized attention,即将 full self-attention 分解成若干个小的 self attention。Sparse Transformer 包括 strided attention 和 fixed attention 这2种改进方案。

strided attention 方案利用 local self-attention 表示局部范围 token之间密切关联,利用 atrous/dilated self-attention 表示远距离 token之间的稀疏关联。一般用于处理图像、音乐等输入结构具有stride这种结构性的信息。

fixed attention 主要用于处理文本类信息,在表征局部范围的密切关联和远距离的稀疏关联方面存在不同的做法以适应文本类信息的结构特点。利用 local self-attention 表示局部范围token之间的关联,在特定位置保存局部信息的汇总并允许后续的token 都能够与这些汇总信息发生关联。也可以认为 fixed attention方案首先将将长序列切分为若干短序列,短序列内部token相互关联,短序列的整体信息保存在特定位置用于远距离的关联。

Sparse Transformer 将序列的切分长度设置为 。最终的计算复杂度为O ( )。这里N为序列长度,d为向量维度。

LongFormer[9] 与 Sparse Transformers 非常相似,可以看做是 strided attention 和 fixed attention 的融合。主要区别点包括:设定slide windows 大小为 w;通过 dilated sliding windows 的不同设置,可以更有效地覆盖局部或者更长距离的依赖关系;在指定位置增加 global attention 学习用于特定任务的表征(例如分类任务中的 [CLS] token)。在模型的浅层使用相对较小的slide windows,在深层使用较大的slide windows,这样随着模型深度增加逐步增加模型的感受野(receptive fields),使得模型浅层专注于局部模式,上层专注于全局模型。 LongFormer 的计算复杂度为 O ( )。 这里N为序列长度,k为设定的窗口大小,d为向量维度。

BigBird [10] 在LongFormer的基础上,增加 random attention。BigBird 的计算复杂度与LongFormer为同一量级。

LongNet [11] 采用不同尺寸的 dilated attention,满足了tokens近距离依赖密切,远距离依赖稀疏的假设。其复杂度为 O ( ) 。LongNet 通过实验证明了可处理的输入长度达到1B tokens。如下图所示:

Reformer [12] 提出了一种 LSH Attention,即利用 locality sensitive hashing (LSH) 方法对 Attention 中的 Query 和 Key 计算 hash值,利用LSH特点可以将相似Query和Key聚合到同一分桶,在Attention计算时只需要对在同一分桶的Query 和Key进行计算。

Reformer 的计算复杂度为 O( )。

3.3.3. Low-Rank / Kernels

Linformer [13] 通过实验验证了Transformer 的 self-attention 矩阵是低秩的(low-rank),并进行了理论证明。

图左是self-attention矩阵的奇异值分解的频谱分析(n=512)。图右是热力图。从图中可以看出呈现出明显的长尾分布,因此可以推断self-attention矩阵信息可以通过少量最大的奇异值恢复得到。 基于上述分析,Linformer 通过低秩矩阵代替原先的计算矩阵,从而使得计算复杂度降低到 O ( ) 。

Low-Rank Transformer [14] 通过 low-rank 降低矩阵计算复杂度,将其应用于 attention 计算,利用Linear Encoder-Decoder 代替单独的线性层计算,以及应用于Feed-Forward 层。

Performer [15] [16] 提出 FAVOR 和 FAVOR+ 算法。核心做法是调整传统的 attention计算过程,将Q 和 K 直接相乘的注意力矩阵进行分解,通过核函数(kernel function)得到Query和Key的近似表示 Q' 和 K',先将 K' 和 V 相乘,再与Q' 相乘。Performer 对上述变换过程作了严格证明。 如下图所示:

Performer 的计算复杂度降低到 O ( ) 。

3.3.4. Memory / Downsampling

Set Transformer [17] 提出利用 Transformer 解决输入顺序无关的问题,即输入数据为set-structured data。Set Transformer 提出了 Set Attention Block (SAB),与 Vanilla Transformer 的 attention 结构显著不同的地方是 SAB 丢弃了位置编码以及 dropout。SAB 的计算复杂度仍然是 O ( ) ,为了降低复杂度,Set Transformer 进一步提出Induced Set Attention Block (ISAB),新增可训练的 m d-维度的 inducing points 向量 。 inducing points 向量又被解释为 memory 。

与低秩投影或自编码器模拟类似,计算时将I转换为H后再进行计算。由于 ,因此 ISAB 计算复杂度降低到 O ( ) 。

Perceiver[18] [19] 利用 cross-attention 将较长的输入向量映射到较短的向量上,如下图所示,Byte array 是原始输入向量,M代表序列长度,C代表通道数;Latent array 是映射后的向量,N是设定的长度,D代表通道数;通过cross-attention 映射后的向量作为 Query,原始向量作为 Key 和 Value,由于 ,计算量显著降低。

Perceiver 的计算复杂度为 O ( ) 。

3.4. 未涉及的点

对长文本建模方案的分析未涉及到位置编码,此外,Multi query attention/Group query attention 也可以有效减少attention计算。这部分内容后续文章再作分析。

4. 总结分析:上下文越长越好吗

大模型可以支持更长的上下文输入与模型效果更好之间并不能直接画上等号。Nelson F. Liu etc [20] 和 Szymon Tworkowski etc [21] 均表示,上下文输入过长会导致注意力分散问题(distraction issue)。

Nelson F. Liu etc (2023) 通过多文档问答、Key-Value 检索2个任务说明了,模型能够处理的上下文长度不是真正的关键点,更重要的是模型对上下文内容的使用,即 how well the language models use longer context。

Szymon Tworkowski etc (2023) 说明过长的上下文会使得相关信息的占比显著下降,加剧注意力分散。针对这一问题的解决方案称为 Focused Transformer ,即通过对比学习提升 memory attention layer 精准定位到相关信息的能力,解决方案应用于 Open LLaMA 得到了支持更长上下文的 LongLLaMA。LongLLaMA 支持 256k 上下文。

通过上述问题和方案的分析可知,大模型长文本建模目前还没有一个统一的解决方案,造成困扰的原因正是源于 Vanilla Transformer 自身的结构。现有解决方案引入更多偏置,解决思路上与早期DNN、CNN的各种变体也有着千丝万缕的联系,或者说有着异曲同工之妙。

本文为个人总结的观点,如有谬误,敬请指正。

参考文献

[1] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).

[2] Guu, Kelvin, et al. "Retrieval augmented language model pre-training." International conference on machine learning . PMLR, 2020.

[3] Wu, Yuhuai, et al. "Memorizing transformers." arXiv preprint arXiv :2203.08913 (2022).

[4] Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." Advances in Neural Information Processing Systems 35 (2022): 16344-16359.

[5] Tay, Yi et al. “Efficient Transformers: A Survey.” ACM Computing Surveys 55 (2020): 1 - 28.

[6] Dai, Zihang, et al. "Transformer-xl: Attentive language models beyond a fixed-length context." arXiv preprint arXiv :1901.02860 (2019).

[7] Bulatov, Aydar, Yury Kuratov, and Mikhail Burtsev. "Recurrent memory transformer." Advances in Neural Information Processing Systems 35 (2022): 11079-11091.

[8] Child, Rewon, et al. "Generating long sequences with sparse transformers." arXiv preprint arXiv :1904.10509 (2019).

[9] Beltagy, Iz, Matthew E. Peters, and Arman Cohan. "Longformer: The long-document transformer." arXiv preprint arXiv :2004.05150 (2020).

[10] Singh, Arjun, et al. "Bigbird: A large-scale 3d database of object instances." 2014 IEEE international conference on robotics and automation (ICRA) . IEEE, 2014.

[11] Ding, Jiayu, et al. "LongNet: Scaling Transformers to 1,000,000,000 Tokens." arXiv preprint arXiv :2307.02486 (2023).

[12] Kitaev, Nikita, Łukasz Kaiser, and Anselm Levskaya. "Reformer: The efficient transformer." arXiv preprint arXiv :2001.04451 (2020).

[13] Wang, Sinong, et al. "Linformer: Self-attention with linear complexity." arXiv preprint arXiv :2006.04768 (2020).

[14] Winata, Genta Indra, et al. "Lightweight and efficient end-to-end speech recognition using low-rank transformer." ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) . IEEE, 2020.

[15] Choromanski, Krzysztof Marcin, et al. "Rethinking Attention with Performers." International Conference on Learning Representations . 2020.

[16] Choromanski, Krzysztof, et al. "Masked language modeling for proteins via linearly scalable long-context transformers." arXiv preprint arXiv :2006.03555 (2020).

[17] Lee, Juho, et al. "Set transformer: A framework for attention-based permutation-invariant neural networks." International conference on machine learning . PMLR, 2019.

[18] Jaegle, Andrew, et al. "Perceiver: General perception with iterative attention." International conference on machine learning . PMLR, 2021.

[19] Jaegle, Andrew, et al. "Perceiver IO: A General Architecture for Structured Inputs Outputs." International Conference on Learning Representations . 2021.

[20] Liu, Nelson F., et al. "Lost in the Middle: How Language Models Use Long Contexts." arXiv preprint arXiv :2307.03172 (2023).

[21] Tworkowski, Szymon, et al. "Focused Transformer: Contrastive Training for Context Scaling." arXiv preprint arXiv :2307.03170 (2023).

免责声明:本文章如果文章侵权,请联系我们处理,本站仅提供信息存储空间服务如因作品内容、版权和其他问题请于本站联系