Contact
CoCalc Logo Icon
Store Features Docs Share Support News About Sign Up Sign In
| Download

📚 The CoCalc Library - books, templates and other resources

Views: 96144
License: OTHER
1
""" A clean, no-frills character-level generative language model.
2
3
CS 20: "TensorFlow for Deep Learning Research"
4
cs20.stanford.edu
5
Danijar Hafner ([email protected])
6
& Chip Huyen ([email protected])
7
Lecture 11
8
"""
9
import os
10
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
11
import random
12
import sys
13
sys.path.append('..')
14
import time
15
16
import tensorflow as tf
17
18
import utils
19
20
def vocab_encode(text, vocab):
    """Convert *text* into a list of 1-based positions in *vocab*.

    Characters that do not occur in *vocab* are dropped. Because positions
    start at 1, the code 0 is never produced by this function.
    """
    encoded = []
    for char in text:
        if char not in vocab:
            continue
        # .index returns the first occurrence, so a duplicated vocabulary
        # character always maps to the same code.
        encoded.append(vocab.index(char) + 1)
    return encoded
22
23
def vocab_decode(array, vocab):
    """Convert a sequence of 1-based codes back into a string.

    Each code ``c`` becomes ``vocab[c - 1]``. Note that a code of 0 therefore
    wraps around to the last character of *vocab*.
    """
    characters = (vocab[code - 1] for code in array)
    return ''.join(characters)
25
26
def read_data(filename, vocab, window, overlap):
    """Yield fixed-length windows of encoded characters from *filename*, forever.

    Every pass over the data shuffles the lines, encodes each line with
    ``vocab_encode`` and slides a window of ``window`` codes over it, advancing
    ``overlap`` codes at a time (``overlap`` is therefore the stride).

    Args:
        filename: path of a UTF-8 text file, one example per line.
        vocab: character vocabulary passed to ``vocab_encode``.
        window: number of codes in each yielded chunk.
        overlap: distance between the start positions of consecutive chunks.

    Yields:
        Lists of exactly ``window`` integer codes.

    Raises:
        ValueError: if no line encodes to more than ``window`` codes, so a
            full pass over the file would produce no chunk at all.
    """
    # The vocabularies contain non-ASCII characters (emoji), so decode
    # explicitly instead of relying on the platform default encoding, and
    # close the file as soon as it has been read.
    with open(filename, 'r', encoding='utf-8') as handle:
        lines = [line.strip() for line in handle]
    while True:
        random.shuffle(lines)
        produced = False
        for text in lines:
            text = vocab_encode(text, vocab)
            # Windows must end strictly before the end of the line, so every
            # chunk is already full length and the padding below never adds
            # anything; it is kept only as a safety net.
            for start in range(0, len(text) - window, overlap):
                chunk = text[start: start + window]
                chunk += [0] * (window - len(chunk))
                produced = True
                yield chunk
        if not produced:
            # Without this check the infinite loop would spin forever.
            raise ValueError('no line in {!r} is longer than window={} '
                             'characters'.format(filename, window))
37
38
def read_batch(stream, batch_size):
    """Group consecutive items of *stream* into lists of ``batch_size`` items.

    If the stream ends part-way through a batch, the shorter final batch is
    yielded. An empty batch is never yielded.
    """
    batch = []
    for element in stream:
        batch.append(element)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Only flush leftovers; an empty list cannot be fed to a
    # [batch, time] placeholder.
    if batch:
        yield batch
46
47
class CharRNN(object):
    """Character-level language model: stacked GRU layers over one-hot characters.

    Text is fed to ``self.seq`` as the 1-based codes produced by
    ``vocab_encode``; the code 0 is reserved for padding.
    """

    def __init__(self, model):
        """Store hyper-parameters and create the input placeholder.

        Args:
            model: dataset name. Training text is read from
                ``data/<model>.txt`` and checkpoints live under
                ``checkpoints/<model>/``. A name containing ``'trump'``
                selects the tweet vocabulary (which adds ``@``, ``#`` and two
                emoji).
        """
        self.model = model
        self.path = 'data/' + model + '.txt'
        if 'trump' in model:
            self.vocab = ("$%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                          " '\"_abcdefghijklmnopqrstuvwxyz{|}@#➡📈")
        else:
            self.vocab = (" $%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                          "\\^_abcdefghijklmnopqrstuvwxyz{|}")

        # (batch, time) matrix of 1-based character codes; 0 marks padding.
        self.seq = tf.placeholder(tf.int32, [None, None])
        self.temp = tf.constant(1.5)  # sampling temperature
        self.hidden_sizes = [128, 256]
        self.batch_size = 64
        self.lr = 0.0003
        self.skip_step = 1  # report / sample / checkpoint every N iterations
        self.num_steps = 50  # for RNN unrolled
        self.len_generated = 200  # characters generated per seed
        self.gstep = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    def create_rnn(self, seq):
        """Run a stack of GRU cells over *seq*.

        Args:
            seq: float one-hot tensor of shape (batch, time, len(vocab)) in
                which padding positions are all-zero rows.

        Sets ``self.in_state`` (one feedable initial state per layer,
        defaulting to zeros), ``self.output`` and ``self.out_state``.
        """
        layers = [tf.nn.rnn_cell.GRUCell(size) for size in self.hidden_sizes]
        cells = tf.nn.rnn_cell.MultiRNNCell(layers)
        batch = tf.shape(seq)[0]
        zero_states = cells.zero_state(batch, dtype=tf.float32)
        # Feedable per-layer states let inference carry the recurrent state
        # from one sess.run call to the next.
        self.in_state = tuple([tf.placeholder_with_default(state, [None, state.shape[1]])
                               for state in zero_states])
        # Real length of each sequence: the number of time steps whose one-hot
        # row is non-zero (padding rows are all zeros).
        length = tf.cast(tf.reduce_sum(tf.reduce_max(tf.sign(seq), 2), 1), tf.int32)
        self.output, self.out_state = tf.nn.dynamic_rnn(cells, seq, length, self.in_state)

    def create_model(self):
        """Build logits, the next-character loss, the sampler and the optimizer."""
        # Codes are 1-based with 0 as padding. Shift them down by one so real
        # characters occupy classes 0..len(vocab)-1 and padding becomes -1,
        # which tf.one_hot maps to an all-zero row. Unshifted, the last
        # vocabulary character (code len(vocab)) would be out of range and
        # look like padding, while padding would look like a real character.
        seq = tf.one_hot(self.seq - 1, len(self.vocab))
        self.create_rnn(seq)
        self.logits = tf.layers.dense(self.output, len(self.vocab), None)
        # Predict the character at t+1 from positions <= t. All-zero
        # (padding) label rows contribute zero cross-entropy.
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits[:, :-1],
                                                       labels=seq[:, 1:])
        self.loss = tf.reduce_sum(loss)
        # Sample the next character from softmax(logits / temp).
        # tf.multinomial takes unnormalized log-probabilities, so the scaled
        # logits are passed as-is (no tf.exp). The +1 turns the class index
        # back into the 1-based code that vocab_decode expects.
        self.sample = tf.multinomial(self.logits[:, -1] / self.temp, 1)[:, 0] + 1
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss, global_step=self.gstep)

    def train(self):
        """Train indefinitely, resuming from the latest checkpoint if one exists.

        Every ``skip_step`` iterations this prints the batch loss and elapsed
        time and generates sample text with ``online_infer``. Whenever that
        batch loss is the lowest reported so far in this run, it saves a
        checkpoint with prefix ``checkpoints/<model>/char-rnn``.
        """
        saver = tf.train.Saver()
        start = time.time()
        min_loss = None
        checkpoint_dir = 'checkpoints/' + self.model
        checkpoint_name = checkpoint_dir + '/char-rnn'
        with tf.Session() as sess:
            writer = tf.summary.FileWriter('graphs/gist', sess.graph)
            try:
                sess.run(tf.global_variables_initializer())

                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)

                iteration = self.gstep.eval()
                stream = read_data(self.path, self.vocab, self.num_steps, overlap=self.num_steps//2)
                data = read_batch(stream, self.batch_size)
                while True:
                    batch = next(data)
                    batch_loss, _ = sess.run([self.loss, self.opt], {self.seq: batch})
                    if (iteration + 1) % self.skip_step == 0:
                        print('Iter {}. \n Loss {}. Time {}'.format(iteration, batch_loss, time.time() - start))
                        self.online_infer(sess)
                        start = time.time()
                        # Record the loss on every save so that later
                        # checkpoints are written only on improvement.
                        if min_loss is None or batch_loss < min_loss:
                            saver.save(sess, checkpoint_name, iteration)
                            min_loss = batch_loss
                    iteration += 1
            finally:
                writer.close()

    def online_infer(self, sess, seeds=('Hillary', 'I', 'R', 'T', '@', 'N', 'M', '.', 'G', 'A', 'W')):
        """Print text generated one character at a time from each seed.

        The whole encoded seed is run through the network first to build up
        the recurrent state. After that, only the newest character is fed,
        together with the carried-over state.

        Args:
            sess: session holding the model variables.
            seeds: starting strings. A seed with no character in the
                vocabulary is skipped.
        """
        for seed in seeds:
            sentence = seed
            batch = [vocab_encode(seed, self.vocab)]
            if not batch[0]:
                continue  # nothing to condition on
            state = None
            for _ in range(self.len_generated):
                feed = {self.seq: batch}
                if state is not None:  # the first step uses the default zero state
                    feed.update(zip(self.in_state, state))
                index, state = sess.run([self.sample, self.out_state], feed)
                sentence += vocab_decode(index, self.vocab)
                batch = [vocab_encode(sentence[-1], self.vocab)]
            print('\t' + sentence)
139
140
def main():
    """Create the checkpoint folders, then build and train the model."""
    model_name = 'trump_tweets'
    for directory in ('checkpoints', 'checkpoints/' + model_name):
        utils.safe_mkdir(directory)

    char_rnn = CharRNN(model_name)
    char_rnn.create_model()
    char_rnn.train()


if __name__ == '__main__':
    main()
151