
PyTorch: using pre-trained word vectors (personal notes)


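This assumes GloVe.6B.300d (a vocab_size of 400k and 300-dimensional vectors). When I have time, I will turn the code below into a function that takes parameters and add it to tool_box.py.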

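import os
import pickle

import bcolz
import numpy as np
import torch
from torch.nn import Embedding

# embed_path: path to the GloVe text file; define it before running this snippet.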

root_dir = embed_path.rsplit(".",1)[0]+".dat"
out_dir_word = embed_path.rsplit(".",1)[0]+"_words.pkl"
out_dir_idx = embed_path.rsplit(".",1)[0]+"_idx.pkl"
if not all([os.path.exists(root_dir),os.path.exists(out_dir_word),os.path.exists(out_dir_idx)]):
    ## process and cache GloVe ===========================================
    words = []
    idx = 0
    word2idx = {}
    # Start from a one-element placeholder array; it is dropped again below.
    vectors = bcolz.carray(np.zeros(1), rootdir=root_dir, mode='w')
    with open(embed_path, "rb") as f:
        for l in f:
            line = l.decode().split()
            word = line[0]
            words.append(word)
            word2idx[word] = idx
            idx += 1
            # np.float was removed in NumPy 1.24, so use an explicit dtype.
            vect = np.array(line[1:]).astype(np.float64)
            vectors.append(vect)
    # Drop the placeholder, then reshape the flat array to (vocab_size, dim).
    vectors = bcolz.carray(vectors[1:].reshape((400000, 300)), rootdir=root_dir, mode='w')
    vectors.flush()
    with open(out_dir_word, 'wb') as f:
        pickle.dump(words, f)
    with open(out_dir_idx, 'wb') as f:
        pickle.dump(word2idx, f)
    print("dump word/idx at {}".format(embed_path.rsplit("/",1)[0]))
    ## =======================================================
## load GloVe from the cache
vectors = bcolz.open(root_dir)[:]
with open(out_dir_word, 'rb') as f:
    words = pickle.load(f)
with open(out_dir_idx, 'rb') as f:
    word2idx = pickle.load(f)

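# Two extra rows: row 0 is the pad token (stays all zeros), row 1 is the unk token (random init).
weights_matrix = np.zeros((400002, 300))
weights_matrix[1] = np.random.normal(scale=0.6, size=(300, ))
weights_matrix[2:,:] = vectors  # GloVe word i ends up in row i + 2
# load_state_dict expects a tensor, not a NumPy array.
weights_matrix = torch.FloatTensor(weights_matrix)

pad_idx,unk_idx = 0,1
# The next two lines belong inside the __init__ of an nn.Module.
self.embed = Embedding(400002, 300,padding_idx=pad_idx)  ## padding_idx keeps the pad row from being updated
self.embed.load_state_dict({'weight': weights_matrix})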
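
One thing that is easy to miss: because weights_matrix[2:] holds the GloVe vectors, the indices stored in word2idx are off by two relative to the embedding rows. Below is a minimal sketch of how tokens could be turned into row ids under that layout. The helper name encode, the NUM_SPECIAL constant and the toy vocabulary are my own illustration and not part of the snippet above; the real word2idx would be the pickled one.

```python
from typing import Dict, List

PAD_IDX, UNK_IDX = 0, 1  # same convention as pad_idx / unk_idx above
NUM_SPECIAL = 2          # rows 0 and 1 of weights_matrix are pad and unk


def encode(tokens: List[str], word2idx: Dict[str, int], max_len: int) -> List[int]:
    """Map tokens to embedding-row ids, truncating or right-padding to max_len."""
    ids = [
        # Shift known words past the two special rows; unknown words map to unk.
        word2idx[tok] + NUM_SPECIAL if tok in word2idx else UNK_IDX
        for tok in tokens[:max_len]
    ]
    return ids + [PAD_IDX] * (max_len - len(ids))


# Toy vocabulary standing in for the pickled word2idx (0-based file order).
toy_word2idx = {"the": 0, ",": 1, ".": 2}
example_ids = encode(["the", "zzz", "."], toy_word2idx, max_len=5)
print(example_ids)  # [2, 1, 4, 0, 0]
```

The resulting id lists can be wrapped in torch.LongTensor and passed to self.embed. Since glove.6B is an uncased vocabulary, tokens should normally be lowercased before the lookup.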