forked from SullyChen/Autopilot-TensorFlow
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
55 lines (42 loc) · 1.82 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import tensorflow as tf
import driving_data
import model
LOGDIR = './save'
sess = tf.InteractiveSession()
L2NormConst = 0.001
train_vars = tf.trainable_variables()
loss = tf.reduce_mean(tf.square(tf.sub(model.y_, model.y))) + tf.add_n([tf.nn.l2_loss(v) for v in train_vars]) * L2NormConst
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
sess.run(tf.initialize_all_variables())
# create a summary to monitor cost tensor
tf.scalar_summary("loss", loss)
# merge all summaries into a single op
merged_summary_op = tf.merge_all_summaries()
saver = tf.train.Saver()
# op to write logs to Tensorboard
logs_path = './logs'
summary_writer = tf.train.SummaryWriter(logs_path, graph=tf.get_default_graph())
epochs = 30
batch_size = 100
# train over the dataset about 30 times
for epoch in range(epochs):
for i in range(int(driving_data.num_images/batch_size)):
xs, ys = driving_data.LoadTrainBatch(batch_size)
train_step.run(feed_dict={model.x: xs, model.y_: ys, model.keep_prob: 0.8})
if i % 10 == 0:
xs, ys = driving_data.LoadValBatch(batch_size)
loss_value = loss.eval(feed_dict={model.x:xs, model.y_: ys, model.keep_prob: 1.0})
print("Epoch: %d, Step: %d, Loss: %g" % (epoch, epoch * batch_size + i, loss_value))
# write logs at every iteration
summary = merged_summary_op.eval(feed_dict={model.x:xs, model.y_: ys, model.keep_prob: 1.0})
summary_writer.add_summary(summary, epoch * batch_size + i)
if i % batch_size == 0:
if not os.path.exists(LOGDIR):
os.makedirs(LOGDIR)
checkpoint_path = os.path.join(LOGDIR, "model.ckpt")
filename = saver.save(sess, checkpoint_path)
print("Model saved in file: %s" % filename)
print("Run the command line:\n" \
"--> tensorboard --logdir=./logs " \
"\nThen open http://0.0.0.0:6006/ into your web browser")