update conv&deconv for new pix2pix version
zewenli98 committed Jul 9, 2018
1 parent d7e52d3 commit 90c6bd6
Showing 1 changed file with 37 additions and 40 deletions.
77 changes: 37 additions & 40 deletions reduce_model.py
@@ -18,16 +18,13 @@ def deprocess(image):
     return (image + 1) / 2
 
 
-def conv(batch_input, out_channels, stride):
-    with tf.variable_scope('conv'):
-        in_channels = batch_input.get_shape()[3]
-        filter = tf.get_variable('filter', [4, 4, in_channels, out_channels], dtype=tf.float32,
-                                 initializer=tf.random_normal_initializer(0, 0.02))
-        # [batch, in_height, in_width, in_channels], [filter_width, filter_height, in_channels, out_channels]
-        #     => [batch, out_height, out_width, out_channels]
-        padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT')
-        conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding='VALID')
-        return conv
+def gen_conv(batch_input, out_channels):
+    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
+    initializer = tf.random_normal_initializer(0, 0.02)
+    # if a.separable_conv:
+    #     return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
+    # else:
+    return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
 
 
 def lrelu(x, a):
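Note: the old conv padded the input manually and called tf.nn.conv2d with an explicit stride argument; the new gen_conv gets the same 2x downsampling from tf.layers.conv2d with a fixed stride of 2 and "same" padding. The separable_conv branch from upstream pix2pix is kept only as a comment, presumably because this reduced script does not carry the upstream `a` options object. The sketch below is not part of the commit; it is a minimal shape check assuming TensorFlow 1.x graph mode:

# Minimal shape check mirroring gen_conv (illustrative only, TF 1.x).
import tensorflow as tf

x = tf.placeholder(tf.float32, [1, 256, 256, 3])  # [batch, height, width, channels]
y = tf.layers.conv2d(x, 64, kernel_size=4, strides=(2, 2), padding="same",
                     kernel_initializer=tf.random_normal_initializer(0, 0.02))
print(y.shape)  # (1, 128, 128, 64): stride 2 with "same" padding halves height and width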
@@ -42,31 +42,31 @@ def lrelu(x, a):
         return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
 
 
-def batchnorm(input):
-    with tf.variable_scope('batchnorm'):
-        # this block looks like it has 3 inputs on the graph unless we do this
-        input = tf.identity(input)
-
-        channels = input.get_shape()[3]
-        offset = tf.get_variable('offset', [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
-        scale = tf.get_variable('scale', [channels], dtype=tf.float32,
-                                initializer=tf.random_normal_initializer(1.0, 0.02))
-        mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
-        variance_epsilon = 1e-5
-        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
-        return normalized
-
-
-def deconv(batch_input, out_channels):
-    with tf.variable_scope('deconv'):
-        batch, in_height, in_width, in_channels = [int(d) for d in batch_input.get_shape()]
-        filter = tf.get_variable('filter', [4, 4, out_channels, in_channels], dtype=tf.float32,
-                                 initializer=tf.random_normal_initializer(0, 0.02))
-        # [batch, in_height, in_width, in_channels], [filter_width, filter_height, out_channels, in_channels]
-        #     => [batch, out_height, out_width, out_channels]
-        conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels],
-                                      [1, 2, 2, 1], padding='SAME')
-        return conv
+def batchnorm(inputs):
+    return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))
+    # with tf.variable_scope('batchnorm'):
+    #     # this block looks like it has 3 inputs on the graph unless we do this
+    #     input = tf.identity(input)
+    #
+    #     channels = input.get_shape()[3]
+    #     offset = tf.get_variable('offset', [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
+    #     scale = tf.get_variable('scale', [channels], dtype=tf.float32,
+    #                             initializer=tf.random_normal_initializer(1.0, 0.02))
+    #     mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
+    #     variance_epsilon = 1e-5
+    #     normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
+    #     return normalized
 
 
+def gen_deconv(batch_input, out_channels):
+    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
+    initializer = tf.random_normal_initializer(0, 0.02)
+    # if a.separable_conv:
+    #     _b, h, w, _c = batch_input.shape
+    #     resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+    #     return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
+    # else:
+    return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
 
 
 def process_image(x):
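Note: gen_deconv is the mirror of gen_conv (a 4x4 kernel with stride 2 and "same" padding doubles height and width), and the one-line batchnorm with training=True still normalizes with per-batch statistics, matching the behaviour of the old tf.nn.moments version. A small shape sketch, again assuming TensorFlow 1.x graph mode and not part of the commit:

# Illustrative shape check mirroring gen_deconv and batchnorm (TF 1.x).
import tensorflow as tf

x = tf.placeholder(tf.float32, [1, 64, 64, 128])
up = tf.layers.conv2d_transpose(x, 64, kernel_size=4, strides=(2, 2), padding="same",
                                kernel_initializer=tf.random_normal_initializer(0, 0.02))
print(up.shape)  # (1, 128, 128, 64): stride 2 doubles height and width
norm = tf.layers.batch_normalization(up, axis=3, epsilon=1e-5, momentum=0.1, training=True,
                                     gamma_initializer=tf.random_normal_initializer(1.0, 0.02))
print(norm.shape)  # shape unchanged; normalized over the batch and spatial axes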
@@ -109,7 +106,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
 
     # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
     with tf.variable_scope('encoder_1'):
-        output = conv(generator_inputs, ngf, stride=2)
+        output = gen_conv(generator_inputs, ngf)
         layers.append(output)
 
     layer_specs = [
@@ -126,7 +123,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
         with tf.variable_scope('encoder_%d' % (len(layers) + 1)):
             rectified = lrelu(layers[-1], 0.2)
             # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
-            convolved = conv(rectified, out_channels, stride=2)
+            convolved = gen_conv(rectified, out_channels)
             output = batchnorm(convolved)
             layers.append(output)
 
@@ -153,7 +150,7 @@ def create_generator(generator_inputs, generator_outputs_channels):
 
             rectified = tf.nn.relu(input)
             # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
-            output = deconv(rectified, out_channels)
+            output = gen_deconv(rectified, out_channels)
             output = batchnorm(output)
 
             if dropout > 0.0:
@@ -165,15 +162,15 @@ def create_generator(generator_inputs, generator_outputs_channels):
     with tf.variable_scope('decoder_1'):
         input = tf.concat([layers[-1], layers[0]], axis=3)
         rectified = tf.nn.relu(input)
-        output = deconv(rectified, generator_outputs_channels)
+        output = gen_deconv(rectified, generator_outputs_channels)
         output = tf.tanh(output)
         layers.append(output)
 
     return layers[-1]
 
 
 def create_model(inputs, targets):
-    with tf.variable_scope('generator') as scope:
+    with tf.variable_scope('generator'):  # as scope
         out_channels = int(targets.get_shape()[-1])
         outputs = create_generator(inputs, out_channels)
 
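Note: inside create_generator only the layer calls change (conv -> gen_conv, deconv -> gen_deconv); the encoder/decoder wiring stays the same, and in create_model only the unused `as scope` alias is commented out. A hypothetical smoke test is sketched below, assuming reduce_model.py can be imported as a module without side effects and that ngf and the layers list are defined inside it as in upstream pix2pix; the names and shapes are illustrative, not part of the commit:

# Hypothetical end-to-end shape check for the updated generator (TF 1.x).
import numpy as np
import tensorflow as tf
import reduce_model  # assumed importable; adjust to how the script is actually loaded

inputs = tf.placeholder(tf.float32, [1, 256, 256, 3])
with tf.variable_scope('generator'):
    outputs = reduce_model.create_generator(inputs, generator_outputs_channels=3)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(outputs, feed_dict={inputs: np.zeros((1, 256, 256, 3), np.float32)})
    print(out.shape)  # expected (1, 256, 256, 3): decoder_1 restores the input resolution through tanh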
