import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from collections import Counter
import numpy as np
import random
USE_CUDA = torch.cuda.is_available()
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
if USE_CUDA:
    torch.cuda.manual_seed(1)
K = 100                    # number of negative samples drawn per positive (context) word
C = 3                      # context window size: C words on each side of the center word
NUM_EPOCHS = 2
MAX_VOCAB_SIZE = 30000
BATCH_SIZE = 128
LEARNING_RATE = 0.2
EMBEDDING_SIZE = 100
LOG_FILE = "word-embedding.log"
def word_tokenize(text):
    return text.split()
# read the training corpus (text8) and tokenize it
with open("/Users/xx/Desktop/text8/text8.train.txt", "r") as fin:
    text = fin.read()
text = [w for w in word_tokenize(text.lower())]
# keep the MAX_VOCAB_SIZE - 1 most common words; everything else is folded into <unk>
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
vocab['<unk>'] = len(text) - np.sum(list(vocab.values()))
# word <--> index mappings
idx_to_word = [word for word in vocab.keys()]
word_to_idx = {word:i for i, word in enumerate(idx_to_word)}
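A quick consistency check (added here for illustration, not part of the original script): the two mappings should be inverses of each other, and <unk> should hold the last index, which the dataset's word_to_idx.get(t, VOCAB_SIZE - 1) fallback below relies on.

# illustrative sanity check
assert word_to_idx['<unk>'] == len(idx_to_word) - 1
assert all(word_to_idx[w] == i for i, w in enumerate(idx_to_word))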
print("idx_to_word = ", idx_to_word[:10])
# word counts and the negative-sampling distribution: unigram frequency raised to the 3/4 power
word_count = np.array([count for count in vocab.values()], dtype=np.float32)
word_freq = word_count / np.sum(word_count)
word_freq = word_freq ** (3. / 4.)
word_freq = word_freq / np.sum(word_freq)
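The 3/4 power is the word2vec trick that flattens the unigram distribution, so rare words are drawn as negative samples somewhat more often. A toy illustration with made-up counts (not from the corpus):

# two words with raw counts 90 and 10
toy_freq = np.array([90., 10.])
toy_freq = toy_freq / toy_freq.sum()     # [0.9, 0.1]
toy_freq = toy_freq ** (3. / 4.)
toy_freq = toy_freq / toy_freq.sum()     # roughly [0.84, 0.16] -- the rare word's share grows
print(toy_freq)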
VOCAB_SIZE = len(idx_to_word)
class Embedding(tud.Dataset):
    def __init__(self, text, word_to_idx, idx_to_word, word_freq, word_count):
        super(Embedding, self).__init__()
        # encode the corpus as word indices; out-of-vocabulary words map to <unk> (the last index)
        self.text_encoded = [word_to_idx.get(t, VOCAB_SIZE - 1) for t in text]
        self.text_encoded = torch.Tensor(self.text_encoded).long()
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.Tensor(word_freq)
        self.word_counts = torch.Tensor(word_count)

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, idx):
        center_word = self.text_encoded[idx]
        # C words on each side of the center word; wrap around at the corpus boundaries
        pos_indices = list(range(idx - C, idx)) + list(range(idx + 1, idx + 1 + C))
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        # K negatives per positive word, sampled with replacement from the 3/4-power distribution
        neg_words = torch.multinomial(self.word_freqs, K * pos_words.shape[0], True)
        return center_word, pos_words, neg_words
print('dataset...')
dataset = Embedding(text, word_to_idx, idx_to_word, word_freq, word_count)
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
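A quick shape check on one batch (illustrative, not in the original script): each center word should come with 2*C positive words and K negatives per positive.

input_labels, pos_labels, neg_labels = next(iter(dataloader))
print(input_labels.shape)   # torch.Size([128])
print(pos_labels.shape)     # torch.Size([128, 6])    2 * C positives
print(neg_labels.shape)     # torch.Size([128, 600])  K negatives per positive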
class Model(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(Model, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        initrange = 0.5 / self.embed_size
        # "output" (context) embeddings
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.out_embed.weight.data.uniform_(-initrange, initrange)
        # "input" (center-word) embeddings -- these are the vectors kept after training
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.in_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, input_labels, pos_labels, neg_labels):
        # input_labels: [batch], pos_labels: [batch, 2*C], neg_labels: [batch, 2*C*K]
        input_embedding = self.in_embed(input_labels)    # [batch, embed_size]
        pos_embedding = self.out_embed(pos_labels)       # [batch, 2*C, embed_size]
        neg_embedding = self.out_embed(neg_labels)       # [batch, 2*C*K, embed_size]
        # dot products of the center word with its positive and negative context words
        log_pos = torch.bmm(pos_embedding, input_embedding.unsqueeze(2)).squeeze(2)
        log_neg = torch.bmm(neg_embedding, -input_embedding.unsqueeze(2)).squeeze(2)
        log_pos = F.logsigmoid(log_pos).sum(1)
        log_neg = F.logsigmoid(log_neg).sum(1)
        loss = log_pos + log_neg
        return -loss

    def input_embedding(self):
        return self.in_embed.weight.data.cpu().numpy()
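For reference, forward returns the negated skip-gram negative-sampling objective described in the paper cited at the end (word2vec Parameter Learning Explained): for a center word $c$ with context words $o$ and sampled negatives $n_k$,

$$\mathcal{L} = -\sum_{o \in \mathrm{context}(c)} \log \sigma(u_o^\top v_c) \;-\; \sum_{k} \log \sigma(-u_{n_k}^\top v_c)$$

where $v$ are rows of in_embed, $u$ are rows of out_embed, and log_pos / log_neg in the code compute the two sums.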
model = Model(VOCAB_SIZE, EMBEDDING_SIZE)
if USE_CUDA:
    model = model.cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for e in range(NUM_EPOCHS):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        input_labels = input_labels.long()
        pos_labels = pos_labels.long()
        neg_labels = neg_labels.long()
        if USE_CUDA:
            input_labels = input_labels.cuda()
            pos_labels = pos_labels.cuda()
            neg_labels = neg_labels.cuda()
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()
        if i % 20 == 0:
            with open(LOG_FILE, "a") as fout:
                fout.write("epoch: {}, iter: {}, loss: {}\n".format(e, i, loss.item()))
            print("epoch: {}, iter: {}, loss: {}".format(e, i, loss.item()))
Training output:
epoch: 0, iter: 0, loss: 420.0480041503906
epoch: 0, iter: 20, loss: 329.0500793457031
epoch: 0, iter: 40, loss: 283.3576965332031
epoch: 0, iter: 60, loss: 247.8422088623047
epoch: 0, iter: 80, loss: 230.79153442382812
epoch: 0, iter: 100, loss: 217.08267211914062
epoch: 0, iter: 120, loss: 192.9191131591797
epoch: 0, iter: 140, loss: 170.9046630859375
epoch: 0, iter: 160, loss: 181.2309112548828
epoch: 0, iter: 180, loss: 166.06568908691406
epoch: 0, iter: 200, loss: 162.77157592773438
epoch: 0, iter: 220, loss: 188.29168701171875
epoch: 0, iter: 240, loss: 171.34072875976562
epoch: 0, iter: 260, loss: 157.67115783691406
epoch: 0, iter: 280, loss: 132.0842742919922
epoch: 0, iter: 300, loss: 146.44769287109375
epoch: 0, iter: 320, loss: 133.27023315429688
epoch: 0, iter: 340, loss: 142.2010955810547
epoch: 0, iter: 360, loss: 121.07867431640625
epoch: 0, iter: 380, loss: 142.3241424560547
epoch: 0, iter: 400, loss: 116.6792221069336
epoch: 0, iter: 420, loss: 107.60456085205078
epoch: 0, iter: 440, loss: 114.59528350830078
epoch: 0, iter: 460, loss: 116.83323669433594
epoch: 0, iter: 480, loss: 110.5595474243164
epoch: 0, iter: 500, loss: 101.20355224609375
epoch: 0, iter: 520, loss: 100.99398040771484
epoch: 0, iter: 540, loss: 100.62711334228516
epoch: 0, iter: 560, loss: 98.46331787109375
epoch: 0, iter: 580, loss: 101.14088439941406
epoch: 0, iter: 600, loss: 98.90962982177734
epoch: 0, iter: 620, loss: 89.2414321899414
epoch: 0, iter: 640, loss: 89.59358978271484
epoch: 0, iter: 660, loss: 101.16934204101562
epoch: 0, iter: 680, loss: 90.46540069580078
epoch: 0, iter: 700, loss: 110.49874114990234
epoch: 0, iter: 720, loss: 91.24562072753906
epoch: 0, iter: 740, loss: 83.98448181152344
epoch: 0, iter: 760, loss: 83.93696594238281
epoch: 0, iter: 780, loss: 91.41669464111328
epoch: 0, iter: 800, loss: 70.84886169433594
epoch: 0, iter: 820, loss: 78.18423461914062
epoch: 0, iter: 840, loss: 77.8626937866211
epoch: 0, iter: 860, loss: 80.74642944335938
epoch: 0, iter: 880, loss: 74.10748291015625
epoch: 0, iter: 900, loss: 87.81755065917969
epoch: 0, iter: 920, loss: 74.84403991699219
epoch: 0, iter: 940, loss: 87.2343978881836
epoch: 0, iter: 960, loss: 67.5108871459961
epoch: 0, iter: 980, loss: 85.40575408935547
epoch: 0, iter: 1000, loss: 70.75318908691406
epoch: 0, iter: 1020, loss: 74.62250518798828
epoch: 0, iter: 1040, loss: 69.09252166748047
epoch: 0, iter: 1060, loss: 77.4391860961914
epoch: 0, iter: 1080, loss: 73.82855987548828
epoch: 0, iter: 1100, loss: 79.26576232910156
epoch: 0, iter: 1120, loss: 88.65275573730469
epoch: 0, iter: 1140, loss: 69.38176727294922
epoch: 0, iter: 1160, loss: 67.59727478027344
epoch: 0, iter: 1180, loss: 73.68869018554688
epoch: 0, iter: 1200, loss: 69.18522644042969
epoch: 0, iter: 1220, loss: 79.35118865966797
epoch: 0, iter: 1240, loss: 75.84447479248047
epoch: 0, iter: 1260, loss: 78.92072296142578
epoch: 0, iter: 1280, loss: 63.9658203125
epoch: 0, iter: 1300, loss: 73.59236145019531
Reference paper: word2vec Parameter Learning Explained
Further references (some only partially read):