
# Assumes GloVe.6B.300d (400k vocab). TODO: refactor this into a parameterized
# function and move it into tool_box.py when time permits.
import os
import pickle

import bcolz
import numpy as np
import torch
# Derive cache paths next to the raw GloVe text file
# (e.g. "glove.6B.300d.txt" -> ".dat", "_words.pkl", "_idx.pkl").
base = embed_path.rsplit(".", 1)[0]
root_dir = base + ".dat"
out_dir_word = base + "_words.pkl"
out_dir_idx = base + "_idx.pkl"

vocab_size, embed_dim = 400000, 300  # GloVe.6B.300d — TODO: parameterize

if not all([os.path.exists(root_dir),
            os.path.exists(out_dir_word),
            os.path.exists(out_dir_idx)]):
    ## process and cache glove ===========================================
    # Parse the raw GloVe text file once; cache vectors (bcolz) and the
    # vocab / word->index maps (pickle) so later runs skip this pass.
    words = []
    word2idx = {}
    # Seed carray with a single zero so bcolz has an initial shape;
    # it is sliced off below before the final reshape.
    vectors = bcolz.carray(np.zeros(1), rootdir=root_dir, mode='w')
    with open(embed_path, "rb") as f:
        for idx, raw in enumerate(f):
            line = raw.decode().split()
            word = line[0]
            words.append(word)
            word2idx[word] = idx
            # np.float was removed in NumPy 1.24 — use the builtin float.
            vect = np.array(line[1:]).astype(float)
            vectors.append(vect)
    # Drop the seed element, fix the final (vocab, dim) shape, flush to disk.
    vectors = bcolz.carray(vectors[1:].reshape((vocab_size, embed_dim)),
                           rootdir=root_dir, mode='w')
    vectors.flush()
    with open(out_dir_word, 'wb') as fw:
        pickle.dump(words, fw)
    with open(out_dir_idx, 'wb') as fi:
        pickle.dump(word2idx, fi)
    print("dump word/idx at {}".format(embed_path.rsplit("/", 1)[0]))
    ## =======================================================

## load glove from cache (reuse the paths computed above instead of
## re-deriving them)
vectors = bcolz.open(root_dir)[:]
with open(out_dir_word, 'rb') as fw:
    words = pickle.load(fw)
with open(out_dir_idx, 'rb') as fi:
    word2idx = pickle.load(fi)

# Row 0 = <pad> (all zeros), row 1 = <unk> (random); GloVe vectors occupy
# rows 2 .. vocab_size+1.
# NOTE(review): word2idx maps words to 0..vocab_size-1, so lookups into this
# matrix must add 2 to the stored index — confirm the caller does this.
pad_idx, unk_idx = 0, 1
weights_matrix = np.zeros((vocab_size + 2, embed_dim))
weights_matrix[unk_idx] = np.random.normal(scale=0.6, size=(embed_dim,))
weights_matrix[2:, :] = vectors

# load_state_dict requires torch tensors, not NumPy arrays — convert first
# (the original left this conversion commented out, so the call would raise).
self.embed = Embedding(vocab_size + 2, embed_dim, padding_idx=pad_idx)
self.embed.load_state_dict(
    {'weight': torch.from_numpy(weights_matrix).float()})