torch.nn.Embedding
随机初始化词向量矩阵
:这种方式很容易理解,就是使用self.embedding =
torch.nn.Embedding
(vocab_size, embed_dim)
命令直接
随机生成个初始化的词向量矩阵
,此时的向量值
符合正态分布N(0,1)
,这里的
vocab_size是指词向量矩阵能表征的
词的个数
,这个数值即是词向量文件中词的数量加1(加1的原因是,如果某个词在词向量文件中不存在,则获取不到索引,也就无法在词向量矩阵中获取对应的向量,这时我们默认这个词的索引为0,即将词向量的第一行作为这个词的向量表征。使用预训练的词向量文件时,这个方法同样适用),
embed_dim是指表征每个词时,
向量的维度
(可自定义,如256)。对于
随机
初始化词向量矩阵的方式
,词向量文件的生成方式一般是将当前所有的文本数据(包括训练数据、验证数据、测试数据)进行切词,再对所有词进行聚合统计,保留词的数量大于某个阈值(比如3)的词,并进行索引编号(编号从1开始,0作为上面提到的不在词向量文件中的其他词的索引),进而生成词向量文件。顺便提一句,词向量矩阵的初始化的方式也有很多种,比如Xavier、Kaiming初始化方法。
使用预训练的词向量文件初始化词向量矩阵
:本质上,词向量矩阵的作用是实现文本的向量表征,因此,如何用更合适的向量表示文本,逐渐成为了一个热门研究方向。预训练的词向量文件便是其中的一个研究成果,如
通过word2vec、glove等预训练模型生成的词向量文件,通过大量的训练数据,来生成词的向量表征
。以word2vec为例,训练后生成的词向量文件是以离线配置文件的形式存在,可通过gensim工具包进行加载,具体命令是 wvmodel
=
gensim.models.KeyedVectors.load_word2vec_format
(word2vec_file,
binary=False, encoding='utf-8',
unicode_errors='ignore') ,加载后,可通过 wvmodel.key_to_index 获取词向量文件(要对词向量文件中的词索引进行重新编号,原索引从0开始,调整为从1开始,0作为不在词向量文件中的词的索引),通过 wvmodel.get_vector("xxx") 获取词向量文件中每个词对应的向量,将词向量文件中所有词对应的向量聚合在一起后(聚合的方式是,每个词的向量表征,按照词的索引,填充在词向量矩阵对应的位置),生成预训练词向量矩阵 weight,再通过 self.embedding
= torch.nn.Embedding.from_pretrained(weight,
freeze=False) 完成词向量矩阵的初始化,参数freeze的作用,是指明训练时是否更新词向量矩阵的权重值,True为不更新,默认为True,等同于 self.embedding.weight.requires_grad
= False)。
还有个细节需要介绍下,在获取到预训练的词向量文件后,由于预训练的词向量文件很大,因此在后续的训练过程中,可能会出现内存不足的错误,此时可对词向量文件及预训练词向量矩阵进行调整,具体来说,先对我们本身任务的所有文本数据进行切词统计,保留数量超过一定阈值的词,作为词向量文件(就是随机初始化词向量矩阵时,词向量文件的生成方法),再利用这个词向量文件,配合wvmodel.get_vector("xxx"),获取预训练词向量矩阵weight,最后进行后续的词向量矩阵初始化过程。这样操作之后,由于词向量文件中词的数量减少,词向量矩阵的行数减少,内存占用会随之减少很多。另外,生成词向量的预训练方法还有很多,参见【通俗易懂的词向量】。
个人理解:
nn.embedding就是一个
字典映射表
,比如它的大小是128,0~127每个位置都存储着一个长度为3的数组,那么我们外部输入的值可以通过index (0~127)映射到每个对应的数组上,所以不管外部的值是如何都能在该nn.embedding中找到对应的数组。想想
哈希表
,就很好理解了。
既然是映射表,那么外部的输入的值肯定不能超过最大长度,比如128,同时下限也是。
import
torch.nn as nn
embedding
= nn.Embedding(10, 3)
#
an Embedding module containing 10 tensors of size 3 10个张量,每个张量的维度为3
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
#
a batch of 2 samples of 4 indices each 两个样本,每个样本有四个索引
e =
embedding(input)
print
(e)
输出的结果:
我们一步步理解代码:
首先,
embedding = nn.Embedding(10, 3)
即定义一个embedding模块,包含了一个长度为10的张量,每个张量的大小是3。举个例子,[-1.0556, -0.2404, -0.4578]就是一个tensor,那么如何取该tensor?使用
下标index
去取。
其次,
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
即输入一个我们需要embedding的变量,输入的每个值最终映射到张量空间中。
最后,我们发现输出e变成了[2, 4, 3]的张量。 说说怎么看张量的维度,从最外层的
[]
开始,计算里面的独立个体,发现是2;接着从第二维度的
[]
开始数,发现是4;依次类推就可以得到张量的维度是[2, 4, 3]。
我们看看embedding的weight:
embedding.weight
我们发现embedding.weight是个[10, 3]的向量,那么embedding.weight的值是怎么被我们input取到的呢? 比如index = 1,那么我们取[-1.0556, -0.2404, -0.4578]; index = 2, 取[ 1.3328, 2.5743, -0.7375]; index = 4, 取[-0.0584, -0.6458, 0.8236]。 这不就刚好对应了e的输入为1/2/4的值吗?只是我们把
输入1作为index去embedding.weight
取对应的值去填充新的张量e。
所以说,我们待输入的张量[[1,2,4,5],[4,3,2,9]],在经过nn.embedding后,从[2, 4]维度变换为[2, 4, 3],其实就是[2, 4]中的每个值作为索引去nn.embedding中取对应的权重。
练习1——改变embedding_dim
embedding = nn.Embedding(10, 4) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)
很明显,当embedding是个[10, 4]的张量时,映射出的张量为[2, 4, 4]
练习2——index越界
embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,10]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)
报错:IndexError: index out of range in self
输出会报错,那是因为我们的embedding的维度是[10, 3],所以index的取值从0~9,那么我们取10肯定就出现问题了。如果出现对应的问题时,就可以大致猜到输入的值越界了。
练习3——sequence长度不一致
embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)
报错:ValueError: expected sequence of length 3 at dim 1 (got 4)
将第一维[1, 2, 4, 5]减去5变成[1,2,4],出现ValueError: expected sequence of length 3 at dim 1 (got 4)的问题,所以需要每个维度的长度都一致。
练习4——改变输入
embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[[1,2],[2,3],[4,5],[5,7]],[[4,5],[3,4],[2,3],[8,9]]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)
https://blog.csdn.net/qq_39540454/article/details/115215056
https://zhuanlan.zhihu.com/p/647536930
https://www.cnblogs.com/emanlee/p/17455844.html
https://blog.csdn.net/qq_39439006/article/details/126760701