""" word2vec skip-gram model with NCE loss and
code to visualize the embeddings on TensorBoard
CS 20: "TensorFlow for Deep Learning Research"
cs20.stanford.edu
Chip Huyen ([email protected])
Lecture 04
"""

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
from tensorflow.contrib.tensorboard.plugins import projector
import tensorflow as tf

import utils
import word2vec_utils
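# utils and word2vec_utils are the course's helper modules (directory creation, data
# download, batch generation, vocab files); both are assumed to sit next to this script.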

# Model hyperparameters
VOCAB_SIZE = 50000
BATCH_SIZE = 128
EMBED_SIZE = 128            # dimension of the word embedding vectors
SKIP_WINDOW = 1             # the context window
NUM_SAMPLED = 64            # number of negative examples to sample
LEARNING_RATE = 1.0
NUM_TRAIN_STEPS = 100000
VISUAL_FLD = 'visualization'
SKIP_STEP = 5000

# Parameters for downloading data
DOWNLOAD_URL = 'http://mattmahoney.net/dc/text8.zip'
EXPECTED_BYTES = 31344016
NUM_VISUALIZE = 3000        # number of tokens to visualize


class SkipGramModel:
    """ Build the graph for word2vec model """
    def __init__(self, dataset, vocab_size, embed_size, batch_size, num_sampled, learning_rate):
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.batch_size = batch_size
        self.num_sampled = num_sampled
        self.lr = learning_rate
        # global_step lives in the graph so training can resume from a saved checkpoint
        self.global_step = tf.get_variable('global_step', initializer=tf.constant(0), trainable=False)
        self.skip_step = SKIP_STEP
        self.dataset = dataset

    def _import_data(self):
        """ Step 1: import data """
        with tf.name_scope('data'):
            self.iterator = self.dataset.make_initializable_iterator()
            self.center_words, self.target_words = self.iterator.get_next()

    def _create_embedding(self):
        """ Step 2 + 3: define weights and embedding lookup.
        In word2vec, it's actually the weights that we care about
        """
        with tf.name_scope('embed'):
            self.embed_matrix = tf.get_variable('embed_matrix',
                                                shape=[self.vocab_size, self.embed_size],
                                                initializer=tf.random_uniform_initializer())
            self.embed = tf.nn.embedding_lookup(self.embed_matrix, self.center_words, name='embedding')
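            # embedding_lookup simply gathers the rows of embed_matrix indexed by
            # center_words; it is equivalent to multiplying by one-hot vectors
            # but avoids materializing them.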

    def _create_loss(self):
        """ Step 4: define the loss function """
        with tf.name_scope('loss'):
            # construct variables for NCE loss
            nce_weight = tf.get_variable('nce_weight',
                                         shape=[self.vocab_size, self.embed_size],
                                         initializer=tf.truncated_normal_initializer(stddev=1.0 / (self.embed_size ** 0.5)))
            nce_bias = tf.get_variable('nce_bias', initializer=tf.zeros([self.vocab_size]))

            # define the loss as NCE loss
            self.loss = tf.reduce_mean(tf.nn.nce_loss(weights=nce_weight,
                                                      biases=nce_bias,
                                                      labels=self.target_words,
                                                      inputs=self.embed,
                                                      num_sampled=self.num_sampled,
                                                      num_classes=self.vocab_size), name='loss')
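            # tf.nn.nce_loss turns the softmax over the full vocabulary into a set of
            # binary classifications: for each center word it contrasts the true target
            # against num_sampled sampled "noise" words, so only the sampled rows of
            # nce_weight are touched per step instead of all vocab_size of them.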

    def _create_optimizer(self):
        """ Step 5: define optimizer """
        self.optimizer = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss,
                                                                             global_step=self.global_step)

    def _create_summaries(self):
        with tf.name_scope('summaries'):
            tf.summary.scalar('loss', self.loss)
            tf.summary.histogram('histogram_loss', self.loss)
            # because we have several summaries, we merge them all
            # into one op to make them easier to manage
            self.summary_op = tf.summary.merge_all()

    def build_graph(self):
        """ Build the graph for our model """
        self._import_data()
        self._create_embedding()
        self._create_loss()
        self._create_optimizer()
        self._create_summaries()

    def train(self, num_train_steps):
        saver = tf.train.Saver()  # defaults to saving all variables - in this case embed_matrix, nce_weight, nce_bias and global_step

        initial_step = 0
        utils.safe_mkdir('checkpoints')
        with tf.Session() as sess:
            sess.run(self.iterator.initializer)
            sess.run(tf.global_variables_initializer())
            ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))

            # if that checkpoint exists, restore from checkpoint
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            total_loss = 0.0  # we use this to calculate the average loss over the last SKIP_STEP steps
            writer = tf.summary.FileWriter('graphs/word2vec/lr' + str(self.lr), sess.graph)
            initial_step = self.global_step.eval()

            for index in range(initial_step, initial_step + num_train_steps):
                try:
                    loss_batch, _, summary = sess.run([self.loss, self.optimizer, self.summary_op])
                    writer.add_summary(summary, global_step=index)
                    total_loss += loss_batch
                    if (index + 1) % self.skip_step == 0:
                        print('Average loss at step {}: {:5.1f}'.format(index, total_loss / self.skip_step))
                        total_loss = 0.0
                        saver.save(sess, 'checkpoints/skip-gram', index)
                except tf.errors.OutOfRangeError:
                    # the generator-backed dataset is exhausted; reinitialize the
                    # iterator so training can continue with another pass over the data
                    sess.run(self.iterator.initializer)
            writer.close()

    def visualize(self, visual_fld, num_visualize):
        """ run 'tensorboard --logdir visualization' to see the embeddings """

        # create the list of num_visualize most common words to visualize
        word2vec_utils.most_common_words(visual_fld, num_visualize)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))

            # if that checkpoint exists, restore from checkpoint
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            final_embed_matrix = sess.run(self.embed_matrix)

            # you have to store the embeddings you want to visualize in a new variable
            embedding_var = tf.Variable(final_embed_matrix[:num_visualize], name='embedding')
            sess.run(embedding_var.initializer)

            config = projector.ProjectorConfig()
            summary_writer = tf.summary.FileWriter(visual_fld)

            # add embedding to the config file
            embedding = config.embeddings.add()
            embedding.tensor_name = embedding_var.name

            # link this tensor to its metadata file, in this case the first NUM_VISUALIZE words of vocab
            embedding.metadata_path = 'vocab_' + str(num_visualize) + '.tsv'

            # saves a configuration file that TensorBoard will read during startup
            projector.visualize_embeddings(summary_writer, config)
            saver_embed = tf.train.Saver([embedding_var])
            saver_embed.save(sess, os.path.join(visual_fld, 'model.ckpt'), 1)
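            # TensorBoard's projector plugin reads the projector_config.pbtxt written
            # into visual_fld, loads the tensor recorded there from model.ckpt, and
            # labels its rows with the lines of the metadata TSV written by
            # most_common_words above, so both must use the same num_visualize.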


def gen():
    yield from word2vec_utils.batch_gen(DOWNLOAD_URL, EXPECTED_BYTES, VOCAB_SIZE,
                                        BATCH_SIZE, SKIP_WINDOW, VISUAL_FLD)
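

# A minimal sanity-check sketch (not called anywhere in this script): it builds the same
# generator-backed dataset as main() below and pulls a single batch, just to confirm the
# shapes match what SkipGramModel expects. The helper name _peek_one_batch is ours, not
# part of the course code.
def _peek_one_batch():
    dataset = tf.data.Dataset.from_generator(gen,
                                             (tf.int32, tf.int32),
                                             (tf.TensorShape([BATCH_SIZE]),
                                              tf.TensorShape([BATCH_SIZE, 1])))
    iterator = dataset.make_initializable_iterator()
    center_words, target_words = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        centers, targets = sess.run([center_words, target_words])
        print(centers.shape, targets.shape)  # expected: (128,) and (128, 1)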


def main():
    dataset = tf.data.Dataset.from_generator(gen,
                                             (tf.int32, tf.int32),
                                             (tf.TensorShape([BATCH_SIZE]), tf.TensorShape([BATCH_SIZE, 1])))
    model = SkipGramModel(dataset, VOCAB_SIZE, EMBED_SIZE, BATCH_SIZE, NUM_SAMPLED, LEARNING_RATE)
    model.build_graph()
    model.train(NUM_TRAIN_STEPS)
    model.visualize(VISUAL_FLD, NUM_VISUALIZE)


if __name__ == '__main__':
    main()
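
# After training, a typical way to inspect the results (assuming TensorBoard is installed
# alongside TensorFlow) is:
#   tensorboard --logdir graphs/word2vec     # loss scalar and histogram summaries
#   tensorboard --logdir visualization       # embedding projector with word labels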