Published on

AI探秘-理解大模型注意力机制

Authors
  • avatar
    Name
    noodles
    每个人的花期不同,不必在乎别人比你提前拥有

目录


长序列建模问题

1. 信息传递衰减

  • 问题描述:在RNN等序列模型中,前序token通过"隐藏状态"向后传递信息
  • 衰减原因:梯度消失问题,信息在长距离传播中逐渐丢失
  • 影响:难以捕获长距离依赖关系,影响模型理解能力

2. 局部视野局限

  • CNN模型限制:通过堆叠卷积层可以扩大感受野,但存在以下问题:
    • 计算成本随层数指数增长
    • 全局关联信息缺失
    • 感受野扩大有限,难以处理超长序列

3. 计算效率瓶颈

  • 时间复杂度问题:
    • RNN:O(n) 顺序处理,无法并行化
    • CNN:需要大量层数才能捕获长距离依赖
    • 传统注意力:O(n²) 复杂度,序列长度翻倍计算量增加4倍
  • 内存消耗:注意力矩阵大小与序列长度平方成正比
  • 实际影响:处理长文档、长对话时速度显著下降

自注意力机制实现

什么是注意力机制

让模型通过大量数据自动学习如何关注重要信息

没有可训练权重的简单自注意力机制

没有训练权重的自注意力机制
  • 计算当前词元与其他词元的注意力分数(通过点积,度量向量相似度的一种方式)
  • 通过归一化得到注意力权重
  • 依次计算所有词元的上下文向量(context vector,可以理解为序列中包含了所有元素信息的嵌入向量)

缩放点积注意力(scaled dot-product attention)

什么是可训练权重(Q K V的作用)

可训练权重是指在模型训练过程中可以自动调整的参数,它们通过梯度下降算法不断优化,使模型能够学习到最优的表示。
Query(Q)查询矩阵: Q矩阵表示当前词元想要关注的内容特征

  • 例如:动词"sat"的Q矩阵学习关注主语"cat"和宾语"mat"的特征

Key(K)键矩阵: K矩阵表示每个词元提供的可被关注的特征标识

  • 例如:名词"cat"的K矩阵学习提供主语、动物等特征标识

Value(V)值矩阵: V矩阵表示每个词元要传递的具体信息内容

  • 例如:动词"sat"的V矩阵学习传递动作、位置等具体信息

缩放点积注意力实现过程

缩放点积注意力实现过程 1. 权重初始化
  • 随机初始化W_Q、W_K、W_V三个可训练权重矩阵

2. 前向传播

  • 使用当前权重计算注意力
  • 计算Q K V矩阵:Q = X @ W_Q, K = X @ W_K, V = X @ W_V
  • 计算注意力分数(attention_scores)
    • 通过查询向量Q与键向量K的转置点积
    • 除以√d_k进行缩放(防止梯度消失)
  • 计算注意力权重:使用softmax对attention_scores进行归一化
  • 计算上下文向量(context vector):attention_weights @ V

3. 损失计算

  • 通过交叉熵损失函数计算损失:loss = cross_entropy_loss(predicted, target)

4. 反向传播

  • 根据损失计算权重梯度:∂Loss/∂W_Q、∂Loss/∂W_K、∂Loss/∂W_V

5. 权重更新

  • 根据梯度更新权重:W_Q、W_K、W_V同时更新,Q、K、V自动同步更新

6. 迭代优化

  • 重复过程直到收敛
  • 检查损失函数是否收敛或达到预设目标

具体例子

输入句子: "The cat sat on the"
目标: 预测下一个词 "mat"

  1. 权重初始化: W_Q、W_K、W_V随机初始化
  2. 前向传播: 计算注意力,得到上下文向量
  3. 损失计算: 预测"dog"但目标是"mat",损失较高
  4. 反向传播: 计算梯度,指导权重调整方向
  5. 权重更新: 调整权重,让模型更关注"mat"
  6. 迭代优化: 重复过程,直到能正确预测"mat"

因果注意力(掩码注意力)

因果注意力

因果注意力机制

  • 只能基于当前词元及其之前的词元计算注意力
  • 通过掩码将未来词元的注意力分数设为0
  • 确保模型在预测时不会"偷看"未来信息

多头注意力机制

多头注意力的主要思想是多次(并行)运行注意力机制,每次使用不同的线性投影。这些投影是通过将输入数据(查询向量、键向量和值向量)乘以权重矩阵得到的。 多头注意力机制
通过多头注意力机制可以捕捉到不同类型的特征信息,提高模型对复杂模式的理解能力

参考

Transformer模型详解(图解最完整版) 李宏毅 Transformer讲解