Contact
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
| Download

📚 The CoCalc Library - books, templates and other resources

Views: 96145
License: OTHER
""" Solution for simple linear regression example using tf.data
Created by Chip Huyen ([email protected])
CS20: "TensorFlow for Deep Learning Research"
cs20.stanford.edu
Lecture 03
"""
import os
8
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
9
import time
10
11
import numpy as np
12
import matplotlib.pyplot as plt
13
import tensorflow as tf
14
15
import utils
16
17
# Relative path of the input table handed to utils.read_birth_life_data.
DATA_FILE = 'data/birth_life_2010.txt'

# Step 1: load the samples. The helper returns the sample array together
# with a sample count; column 0 of the array serves as the input feature
# and column 1 as the regression target.
data, n_samples = utils.read_birth_life_data(DATA_FILE)

# Step 2: wrap the two columns in a tf.data pipeline. Slicing the pair of
# columns yields one (x, y) element per row. An initializable iterator is
# used so the training loop can rewind it at the start of every epoch.
dataset = tf.data.Dataset.from_tensor_slices(
    (data[:, 0], data[:, 1])
)
iterator = dataset.make_initializable_iterator()
X, Y = iterator.get_next()
# Step 3: trainable slope ('weights') and intercept ('bias'), both
# starting from the scalar 0.0.
w = tf.get_variable(
    'weights',
    initializer=tf.constant(0.0),
)
b = tf.get_variable(
    'bias',
    initializer=tf.constant(0.0),
)

# Step 4: the model is a straight line evaluated on the current sample.
Y_predicted = X * w + b

# Step 5: squared error between target and prediction is the loss.
# To try the Huber loss instead, use: utils.huber_loss(Y, Y_predicted)
loss = tf.square(Y - Y_predicted, name='loss')

# Step 6: plain gradient descent with a fixed step size of 0.001.
# minimize() gives back the update op that the training loop runs
# alongside `loss` on every sample.
gd_optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
optimizer = gd_optimizer.minimize(loss)
# Train the model inside one session, log the graph for TensorBoard, then
# report the learned parameters and the total wall-clock time.
start = time.time()
with tf.Session() as sess:
    # Step 7: initialize the necessary variables, in this case, w and b
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('./graphs/linear_reg', sess.graph)
    try:
        # Step 8: train the model for 100 epochs
        for i in range(100):
            # Rewind the dataset so each epoch sees every sample again.
            sess.run(iterator.initializer)
            total_loss = 0
            try:
                while True:
                    # One optimizer step per element pulled from the iterator.
                    _, l = sess.run([optimizer, loss])
                    total_loss += l
            except tf.errors.OutOfRangeError:
                # get_next() raises this once the dataset is exhausted: it is
                # the normal end-of-epoch signal, not a failure.
                pass

            # Mean per-sample loss for the epoch (assumes n_samples matches
            # the number of dataset rows -- it comes from the data loader).
            print('Epoch {0}: {1}'.format(i, total_loss/n_samples))
    finally:
        # close the writer when you're done using it -- even if training
        # stopped early because of an exception
        writer.close()

    # Step 9: output the values of w and b
    w_out, b_out = sess.run([w, b])
    print('w: %f, b: %f' %(w_out, b_out))
print('Took: %f seconds' %(time.time() - start))
# Visualize the raw samples (blue dots) against the fitted line (red),
# evaluated with NumPy from the learned w_out and b_out.
x_vals = data[:, 0]
plt.plot(x_vals, data[:, 1], 'bo', label='Real data')
plt.plot(x_vals, x_vals * w_out + b_out, 'r',
         label='Predicted data with squared error')
# Optional comparison line. The coefficients below are hard-coded;
# presumably they come from an earlier Huber-loss run -- verify before use.
# plt.plot(data[:,0], data[:,0] * (-5.883589) + 85.124306, 'g', label='Predicted data with Huber loss')
plt.legend()
plt.show()