-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathnet.py
69 lines (56 loc) · 2.93 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorflow.contrib.legacy_seq2seq as seq2seq
class Net(object):
    """Stacked-LSTM word-level language model built as a TensorFlow 1.x graph.

    The whole graph is created in the constructor (via ``build``). Afterwards
    the instance exposes:

    * placeholders ``inputs`` and ``targets`` (int32, ``[batch_size, None]``),
      ``seq_len`` (int32, ``[batch_size]``), ``keep_prob`` (float32, used as the
      LSTM output dropout keep probability) and ``learning_rate`` (float64);
    * ``cell``, ``init_state`` (zero state) and ``final_state`` of the RNN;
    * ``prob``: softmax probabilities of shape ``[batch_size*T, words_size]``
      reshaped to ``[batch_size, -1]`` (time and vocabulary axes flattened
      together, so it is ``[batch_size, words_size]`` when T == 1);
    * ``pre``: argmax word ids reshaped to ``[batch_size, -1]``;
    * ``loss`` (mean cross-entropy over the first ``seq_len`` steps of each
      row), ``merged_summary`` and ``train_op``.
    """

    def __init__(self, data, num_units, num_layer, batch_size, grad_clip=5.0):
        """Store hyper-parameters and build the graph.

        Args:
            data: object whose ``words_size`` attribute is used as the number of
                embedding rows and of softmax output classes.
            num_units: LSTM hidden size; also the embedding dimension.
            num_layer: number of stacked LSTM layers.
            batch_size: fixed batch dimension of every placeholder.
            grad_clip: global-norm threshold passed to ``tf.clip_by_global_norm``
                (defaults to 5, the value previously hard-coded).
        """
        super(Net, self).__init__()
        self.num_units = num_units
        self.num_layer = num_layer
        self.batch_size = batch_size
        self.grad_clip = grad_clip
        self.data = data
        self.build()

    def build(self):
        """Create placeholders, the LSTM stack, the masked loss and the train op."""
        self.inputs = tf.placeholder(tf.int32, [self.batch_size, None])
        self.targets = tf.placeholder(tf.int32, [self.batch_size, None])
        self.keep_prob = tf.placeholder(tf.float32)
        self.seq_len = tf.placeholder(tf.int32, [self.batch_size])
        self.learning_rate = tf.placeholder(tf.float64)

        with tf.variable_scope('rnn'):
            w = tf.get_variable("softmax_w", [self.num_units, self.data.words_size])
            b = tf.get_variable("softmax_b", [self.data.words_size])
            embedding = tf.get_variable("embedding", [self.data.words_size, self.num_units])
            inputs = tf.nn.embedding_lookup(embedding, self.inputs)

        # unit() reads self.keep_prob, so the placeholders above must exist first.
        self.cell = tf.nn.rnn_cell.MultiRNNCell([self.unit() for _ in range(self.num_layer)])
        self.init_state = self.cell.zero_state(self.batch_size, dtype=tf.float32)
        # NOTE(review): dynamic_rnn is called outside the variable_scope above and
        # re-enters 'rnn' via scope=; this nesting decides the LSTM variable names,
        # so confirm it against any existing checkpoints.
        output, self.final_state = tf.nn.dynamic_rnn(self.cell,
                                                     inputs=inputs,
                                                     sequence_length=self.seq_len,
                                                     initial_state=self.init_state,
                                                     scope='rnn')

        with tf.name_scope('fc'):
            # Flatten [batch, time, units] to [batch*time, units] for one matmul.
            y = tf.reshape(output, [-1, self.num_units])
            logits = tf.matmul(y, w) + b

        with tf.name_scope('softmax'):
            prob = tf.nn.softmax(logits)
            self.prob = tf.reshape(prob, [self.batch_size, -1])
            pre = tf.argmax(prob, 1)
            self.pre = tf.reshape(pre, [self.batch_size, -1])

        targets = tf.reshape(self.targets, [-1])
        with tf.name_scope('loss'):
            # Only the first seq_len steps of each row are real tokens. Past that,
            # dynamic_rnn (per its docs) emits zero outputs, so the logits there are
            # just `b`; scoring those steps would train the bias toward whatever id
            # fills the padding. Weight each step by a 0/1 length mask instead.
            mask = tf.sequence_mask(self.seq_len,
                                    maxlen=tf.shape(self.targets)[1],
                                    dtype=tf.float32)
            weights = tf.reshape(mask, [-1])
            loss = seq2seq.sequence_loss_by_example([logits],
                                                    [targets],
                                                    [weights],
                                                    average_across_timesteps=False)
            # Mean cross-entropy over real tokens; equals the plain mean when no row
            # is padded. The max() guards against a batch whose lengths are all 0.
            self.loss = tf.reduce_sum(loss) / tf.maximum(tf.reduce_sum(weights), 1.0)

        with tf.name_scope('summary'):
            tf.summary.scalar('loss', self.loss)
            # merge_all() collects every summary registered in the default graph.
            self.merged_summary = tf.summary.merge_all()

        with tf.name_scope('optimizer'):
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            # All trainable variables in the default graph, not only this model's.
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), self.grad_clip)
            self.train_op = optimizer.apply_gradients(zip(grads, tvars))

    def unit(self):
        """Return one BasicLSTMCell of ``num_units`` wrapped with output dropout.

        Reads ``self.keep_prob``, so it can only be called once that placeholder
        exists (``build`` creates it before building the cell stack).
        """
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=self.num_units)
        lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.keep_prob)
        return lstm_cell