- Published on
forward 方法实现(Embedding)
下面的代码摘自: candl-nn(embedding.rs)
impl crate::Module for Embedding {
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;
let values = self.embeddings.index_select(&indexes, 0)?;
let values = values.reshape(final_dims)?;
Ok(values)
}
}
1. 先看示例
假设我们有以下数据:
- 词表 (
vocab_size = 4
):4 个单词(ID 为 0 到 3) - 嵌入维度 (
hidden_size = 3
):每个单词用 3 维向量表示 - 嵌入矩阵 (
self.embeddings
):[ [0.1, 0.2, 0.3], # ID=0 的向量(比如对应单词"猫") [0.4, 0.5, 0.6], # ID=1(比如"狗") [0.7, 0.8, 0.9], # ID=2(比如"鸟") [1.0, 1.1, 1.2], # ID=3(比如"鱼") ]
- 输入
indexes
:Tensor([[0, 2], [1, 3]])
(2 个句子,每个句子 2 个单词)
2. 逐行代码解析
步骤 1:获取输入张量的形状
let mut final_dims = indexes.dims().to_vec();
indexes.dims()
获取输入张量的形状 →[2, 2]
(2 行 2 列)final_dims
变成[2, 2]
(稍后会用到)
步骤 2:添加嵌入维度
final_dims.push(self.hidden_size);
self.hidden_size = 3
(嵌入维度)final_dims
变为[2, 2, 3]
(目标输出形状)
步骤 3:展平输入索引
let indexes = indexes.flatten_all()?;
- 将
[[0, 2], [1, 3]]
展平为[0, 2, 1, 3]
(变成一维列表)
步骤 4:核心查找操作
let values = self.embeddings.index_select(&indexes, 0)?;
index_select
:根据indexes
从self.embeddings
中选行indexes = [0, 2, 1, 3]
,所以依次选取:- 第 0 行:
[0.1, 0.2, 0.3]
- 第 2 行:
[0.7, 0.8, 0.9]
- 第 1 行:
[0.4, 0.5, 0.6]
- 第 3 行:
[1.0, 1.1, 1.2]
- 第 0 行:
- 结果
values
的形状是(4, 3)
(4 个 token,每个 3 维)
步骤 5:重塑为最终形状
let values = values.reshape(final_dims)?;
- 将
(4, 3)
的张量重塑为(2, 2, 3)
:[ [ [0.1, 0.2, 0.3], [0.7, 0.8, 0.9] ], # 第一个句子 [ [0.4, 0.5, 0.6], [1.0, 1.1, 1.2] ] # 第二个句子 ]
3. 可视化查找过程
输入 indexes (2x2):
[ [0, 2],
[1, 3] ]
查找步骤:
1. 展平为 [0, 2, 1, 3]
2. 从嵌入矩阵按行选取:
- 0 → [0.1, 0.2, 0.3]
- 2 → [0.7, 0.8, 0.9]
- 1 → [0.4, 0.5, 0.6]
- 3 → [1.0, 1.1, 1.2]
3. 组合成形状 (4,3):
[ [0.1, 0.2, 0.3],
[0.7, 0.8, 0.9],
[0.4, 0.5, 0.6],
[1.0, 1.1, 1.2] ]
4. 重塑为 (2,2,3):
[ [ [0.1, 0.2, 0.3], [0.7, 0.8, 0.9] ],
[ [0.4, 0.5, 0.6], [1.0, 1.1, 1.2] ] ]
4. 为什么需要这么多步骤?
- 展平 (
flatten_all
):
为了统一处理任意形状的输入(比如 3D 输入(batch, seq_len, ...)
)。 index_select
:
Rust/Candle 的索引操作需要一维输入,不像 Python 可以直接多维索引。- 重塑 (
reshape
):
恢复原始输入的批量结构和序列长度。
5. 边界情况验证
情况 1:ID 越界
如果 indexes
包含 5
(但 vocab_size=4
):
index_select
会报错:Index out of bounds
(因为嵌入矩阵只有 4 行)
情况 2:空输入
如果 indexes
是空的:
flatten_all
会返回空张量,最终结果也是空的(符合预期)
6. 总结
- 输入:任意形状的 token ID 张量(如
(batch, seq_len)
)。 - 输出:相同形状但多一维(添加
embedding_dim
)。 - 关键操作:
index_select
是核心,相当于 Python 的embedding_matrix[indexes]
。
就像查字典:
- 把要查的单词列表(
indexes
)摊平。 - 按顺序从字典(
embeddings
)里找每个单词的解释。 - 把解释按原来的句子顺序重新排列好。
THE END