Published on

forward 方法实现(Embedding)

下面的代码摘自: candl-nn(embedding.rs)

impl crate::Module for Embedding {
    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
        let mut final_dims = indexes.dims().to_vec();
        final_dims.push(self.hidden_size);
        let indexes = indexes.flatten_all()?;
        let values = self.embeddings.index_select(&indexes, 0)?;
        let values = values.reshape(final_dims)?;
        Ok(values)
    }
}

1. 先看示例

假设我们有以下数据:

  • 词表 (vocab_size = 4):4 个单词(ID 为 0 到 3)
  • 嵌入维度 (hidden_size = 3):每个单词用 3 维向量表示
  • 嵌入矩阵 (self.embeddings)
    [
        [0.1, 0.2, 0.3],  # ID=0 的向量(比如对应单词"猫")
        [0.4, 0.5, 0.6],  # ID=1(比如"狗")
        [0.7, 0.8, 0.9],  # ID=2(比如"鸟")
        [1.0, 1.1, 1.2],  # ID=3(比如"鱼")
    ]
    
  • 输入 indexesTensor([[0, 2], [1, 3]])(2 个句子,每个句子 2 个单词)

2. 逐行代码解析

步骤 1:获取输入张量的形状

let mut final_dims = indexes.dims().to_vec();
  • indexes.dims() 获取输入张量的形状 → [2, 2](2 行 2 列)
  • final_dims 变成 [2, 2](稍后会用到)

步骤 2:添加嵌入维度

final_dims.push(self.hidden_size);
  • self.hidden_size = 3(嵌入维度)
  • final_dims 变为 [2, 2, 3](目标输出形状)

步骤 3:展平输入索引

let indexes = indexes.flatten_all()?;
  • [[0, 2], [1, 3]] 展平为 [0, 2, 1, 3](变成一维列表)

步骤 4:核心查找操作

let values = self.embeddings.index_select(&indexes, 0)?;
  • index_select:根据 indexesself.embeddings 中选行
    • indexes = [0, 2, 1, 3],所以依次选取:
      • 第 0 行:[0.1, 0.2, 0.3]
      • 第 2 行:[0.7, 0.8, 0.9]
      • 第 1 行:[0.4, 0.5, 0.6]
      • 第 3 行:[1.0, 1.1, 1.2]
    • 结果 values 的形状是 (4, 3)(4 个 token,每个 3 维)

步骤 5:重塑为最终形状

let values = values.reshape(final_dims)?;
  • (4, 3) 的张量重塑为 (2, 2, 3)
    [
        [ [0.1, 0.2, 0.3], [0.7, 0.8, 0.9] ],  # 第一个句子
        [ [0.4, 0.5, 0.6], [1.0, 1.1, 1.2] ]   # 第二个句子
    ]
    

3. 可视化查找过程

输入 indexes (2x2):
[ [0, 2],
  [1, 3] ]

查找步骤:
1. 展平为 [0, 2, 1, 3]
2. 从嵌入矩阵按行选取:
   - 0[0.1, 0.2, 0.3]
   - 2[0.7, 0.8, 0.9]
   - 1[0.4, 0.5, 0.6]
   - 3[1.0, 1.1, 1.2]
3. 组合成形状 (4,3):
   [ [0.1, 0.2, 0.3],
     [0.7, 0.8, 0.9],
     [0.4, 0.5, 0.6],
     [1.0, 1.1, 1.2] ]
4. 重塑为 (2,2,3):
   [ [ [0.1, 0.2, 0.3], [0.7, 0.8, 0.9] ],
     [ [0.4, 0.5, 0.6], [1.0, 1.1, 1.2] ] ]

4. 为什么需要这么多步骤?

  • 展平 (flatten_all)
    为了统一处理任意形状的输入(比如 3D 输入 (batch, seq_len, ...))。
  • index_select
    Rust/Candle 的索引操作需要一维输入,不像 Python 可以直接多维索引。
  • 重塑 (reshape)
    恢复原始输入的批量结构和序列长度。

5. 边界情况验证

情况 1:ID 越界

如果 indexes 包含 5(但 vocab_size=4):

  • index_select 会报错:Index out of bounds(因为嵌入矩阵只有 4 行)

情况 2:空输入

如果 indexes 是空的:

  • flatten_all 会返回空张量,最终结果也是空的(符合预期)

6. 总结

  • 输入:任意形状的 token ID 张量(如 (batch, seq_len))。
  • 输出:相同形状但多一维(添加 embedding_dim)。
  • 关键操作
    index_select 是核心,相当于 Python 的 embedding_matrix[indexes]

就像查字典:

  1. 把要查的单词列表(indexes)摊平。
  2. 按顺序从字典(embeddings)里找每个单词的解释。
  3. 把解释按原来的句子顺序重新排列好。

THE END