diff --git a/reduce_model.py b/reduce_model.py index eecd20d..a13bea1 100644 --- a/reduce_model.py +++ b/reduce_model.py @@ -18,16 +18,13 @@ def deprocess(image): return (image + 1) / 2 -def conv(batch_input, out_channels, stride): - with tf.variable_scope('conv'): - in_channels = batch_input.get_shape()[3] - filter = tf.get_variable('filter', [4, 4, in_channels, out_channels], dtype=tf.float32, - initializer=tf.random_normal_initializer(0, 0.02)) - # [batch, in_height, in_width, in_channels], [filter_width, filter_height, in_channels, out_channels] - # => [batch, out_height, out_width, out_channels] - padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT') - conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding='VALID') - return conv +def gen_conv(batch_input, out_channels): + # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels] + initializer = tf.random_normal_initializer(0, 0.02) + # if a.separable_conv: + # return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer) + # else: + return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer) def lrelu(x, a): @@ -42,31 +39,31 @@ def lrelu(x, a): return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x) -def batchnorm(input): - with tf.variable_scope('batchnorm'): - # this block looks like it has 3 inputs on the graph unless we do this - input = tf.identity(input) - - channels = input.get_shape()[3] - offset = tf.get_variable('offset', [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) - scale = tf.get_variable('scale', [channels], dtype=tf.float32, - initializer=tf.random_normal_initializer(1.0, 0.02)) - mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False) - variance_epsilon = 1e-5 - normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon) - return normalized - - -def deconv(batch_input, out_channels): - with tf.variable_scope('deconv'): - batch, in_height, in_width, in_channels = [int(d) for d in batch_input.get_shape()] - filter = tf.get_variable('filter', [4, 4, out_channels, in_channels], dtype=tf.float32, - initializer=tf.random_normal_initializer(0, 0.02)) - # [batch, in_height, in_width, in_channels], [filter_width, filter_height, out_channels, in_channels] - # => [batch, out_height, out_width, out_channels] - conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels], - [1, 2, 2, 1], padding='SAME') - return conv +def batchnorm(inputs): + return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02)) + # with tf.variable_scope('batchnorm'): + # # this block looks like it has 3 inputs on the graph unless we do this + # input = tf.identity(input) + # + # channels = input.get_shape()[3] + # offset = tf.get_variable('offset', [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) + # scale = tf.get_variable('scale', [channels], dtype=tf.float32, + # initializer=tf.random_normal_initializer(1.0, 0.02)) + # mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False) + # variance_epsilon = 1e-5 + # normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon) + # return normalized + + +def gen_deconv(batch_input, out_channels): + # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels] + initializer = tf.random_normal_initializer(0, 0.02) + # if a.separable_conv: + # _b, h, w, _c = batch_input.shape + # resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) + # return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer) + # else: + return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer) def process_image(x): @@ -109,7 +106,7 @@ def create_generator(generator_inputs, generator_outputs_channels): # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf] with tf.variable_scope('encoder_1'): - output = conv(generator_inputs, ngf, stride=2) + output = gen_conv(generator_inputs, ngf) layers.append(output) layer_specs = [ @@ -126,7 +123,7 @@ def create_generator(generator_inputs, generator_outputs_channels): with tf.variable_scope('encoder_%d' % (len(layers) + 1)): rectified = lrelu(layers[-1], 0.2) # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels] - convolved = conv(rectified, out_channels, stride=2) + convolved = gen_conv(rectified, out_channels) output = batchnorm(convolved) layers.append(output) @@ -153,7 +150,7 @@ def create_generator(generator_inputs, generator_outputs_channels): rectified = tf.nn.relu(input) # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels] - output = deconv(rectified, out_channels) + output = gen_deconv(rectified, out_channels) output = batchnorm(output) if dropout > 0.0: @@ -165,7 +162,7 @@ def create_generator(generator_inputs, generator_outputs_channels): with tf.variable_scope('decoder_1'): input = tf.concat([layers[-1], layers[0]], axis=3) rectified = tf.nn.relu(input) - output = deconv(rectified, generator_outputs_channels) + output = gen_deconv(rectified, generator_outputs_channels) output = tf.tanh(output) layers.append(output) @@ -173,7 +170,7 @@ def create_generator(generator_inputs, generator_outputs_channels): def create_model(inputs, targets): - with tf.variable_scope('generator') as scope: + with tf.variable_scope('generator'): # as scope out_channels = int(targets.get_shape()[-1]) outputs = create_generator(inputs, out_channels)