Skip to content

Commit 048c6de

Browse files
committed
Added DeepRNN
1 parent 77585b4 commit 048c6de

31 files changed

+2392
-0
lines changed

.gitattributes

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*npy filter=lfs diff=lfs merge=lfs -text
2+
*csv filter=lfs diff=lfs merge=lfs -text
3+
*gz filter=lfs diff=lfs merge=lfs -text

DeepRNN/base_model.py

+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import os
2+
import numpy as np
3+
import pandas as pd
4+
import tensorflow as tf
5+
import matplotlib.pyplot as plt
6+
import cPickle as pickle
7+
import copy
8+
import json
9+
from tqdm import tqdm
10+
11+
from utils.nn import NN
12+
from utils.coco.coco import COCO
13+
from utils.coco.pycocoevalcap.eval import COCOEvalCap
14+
from utils.misc import ImageLoader, CaptionData, TopN
15+
16+
class BaseModel(object):
17+
def __init__(self, config):
18+
self.config = config
19+
self.is_train = True if config.phase == 'train' else False
20+
self.train_cnn = self.is_train and config.train_cnn
21+
self.image_loader = ImageLoader('./DeepRNN/utils/ilsvrc_2012_mean.npy')
22+
self.image_shape = [224, 224, 3]
23+
self.nn = NN(config)
24+
self.global_step = tf.Variable(0,
25+
name = 'global_step',
26+
trainable = False)
27+
self.build()
28+
29+
def build(self):
30+
raise NotImplementedError()
31+
32+
def test(self, sess, test_data, vocabulary):
33+
""" Test the model using any given images. """
34+
config = self.config
35+
36+
# Generate the captions for the images
37+
for k in tqdm(list(range(test_data.num_batches)), desc='path'):
38+
batch = test_data.next_batch()
39+
caption_data = self.beam_search(sess, batch, vocabulary)
40+
41+
fake_cnt = 0 if k<test_data.num_batches-1 \
42+
else test_data.fake_count
43+
for l in range(test_data.batch_size-fake_cnt):
44+
word_idxs = caption_data[l][0].sentence
45+
score = caption_data[l][0].score
46+
caption = vocabulary.get_sentence(word_idxs)
47+
print('**'+caption+'**')
48+
49+
def beam_search(self, sess, image_files, vocabulary):
50+
"""Use beam search to generate the captions for a batch of images."""
51+
# Feed in the images to get the contexts and the initial LSTM states
52+
config = self.config
53+
images = self.image_loader.load_images(image_files)
54+
contexts, initial_memory, initial_output = sess.run(
55+
[self.conv_feats, self.initial_memory, self.initial_output],
56+
feed_dict = {self.images: images})
57+
58+
partial_caption_data = []
59+
complete_caption_data = []
60+
for k in range(config.batch_size):
61+
initial_beam = CaptionData(sentence = [],
62+
memory = initial_memory[k],
63+
output = initial_output[k],
64+
score = 1.0)
65+
partial_caption_data.append(TopN(config.beam_size))
66+
partial_caption_data[-1].push(initial_beam)
67+
complete_caption_data.append(TopN(config.beam_size))
68+
69+
# Run beam search
70+
for idx in range(config.max_caption_length):
71+
partial_caption_data_lists = []
72+
for k in range(config.batch_size):
73+
data = partial_caption_data[k].extract()
74+
partial_caption_data_lists.append(data)
75+
partial_caption_data[k].reset()
76+
77+
num_steps = 1 if idx == 0 else config.beam_size
78+
for b in range(num_steps):
79+
if idx == 0:
80+
last_word = np.zeros((config.batch_size), np.int32)
81+
else:
82+
last_word = np.array([pcl[b].sentence[-1]
83+
for pcl in partial_caption_data_lists],
84+
np.int32)
85+
86+
last_memory = np.array([pcl[b].memory
87+
for pcl in partial_caption_data_lists],
88+
np.float32)
89+
last_output = np.array([pcl[b].output
90+
for pcl in partial_caption_data_lists],
91+
np.float32)
92+
93+
memory, output, scores = sess.run(
94+
[self.memory, self.output, self.probs],
95+
feed_dict = {self.contexts: contexts,
96+
self.last_word: last_word,
97+
self.last_memory: last_memory,
98+
self.last_output: last_output})
99+
100+
# Find the beam_size most probable next words
101+
for k in range(config.batch_size):
102+
caption_data = partial_caption_data_lists[k][b]
103+
words_and_scores = list(enumerate(scores[k]))
104+
words_and_scores.sort(key=lambda x: -x[1])
105+
words_and_scores = words_and_scores[0:config.beam_size+1]
106+
107+
# Append each of these words to the current partial caption
108+
for w, s in words_and_scores:
109+
sentence = caption_data.sentence + [w]
110+
score = caption_data.score * s
111+
beam = CaptionData(sentence,
112+
memory[k],
113+
output[k],
114+
score)
115+
if vocabulary.words[w] == '.':
116+
complete_caption_data[k].push(beam)
117+
else:
118+
partial_caption_data[k].push(beam)
119+
120+
results = []
121+
for k in range(config.batch_size):
122+
if complete_caption_data[k].size() == 0:
123+
complete_caption_data[k] = partial_caption_data[k]
124+
results.append(complete_caption_data[k].extract(sort=True))
125+
126+
return results
127+
128+
def load(self, sess, model_file=None):
129+
""" Load the model. """
130+
config = self.config
131+
if model_file is not None:
132+
save_path = model_file
133+
else:
134+
info_path = os.path.join(config.save_dir, "config.pickle")
135+
info_file = open(info_path, "rb")
136+
config = pickle.load(info_file)
137+
global_step = config.global_step
138+
info_file.close()
139+
save_path = os.path.join(config.save_dir,
140+
str(global_step)+".npy")
141+
142+
data_dict = np.load(save_path).item()
143+
count = 0
144+
for v in tqdm(tf.global_variables()):
145+
if v.name in data_dict.keys():
146+
sess.run(v.assign(data_dict[v.name]))
147+
count += 1
148+
149+
def load_cnn(self, session, data_path, ignore_missing=True):
150+
""" Load a pretrained CNN model. """
151+
data_dict = np.load(data_path).item()
152+
count = 0
153+
for op_name in tqdm(data_dict):
154+
with tf.variable_scope(op_name, reuse = True):
155+
for param_name, data in data_dict[op_name].iteritems():
156+
try:
157+
var = tf.get_variable(param_name)
158+
session.run(var.assign(data))
159+
count += 1
160+
except ValueError:
161+
pass

DeepRNN/config.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
class Config(object):
3+
""" Wrapper class for various (hyper)parameters. """
4+
def __init__(self):
5+
# about the model architecture
6+
self.cnn = 'vgg16' # 'vgg16' or 'resnet50'
7+
self.max_caption_length = 20
8+
self.dim_embedding = 512
9+
self.num_lstm_units = 512
10+
self.num_initalize_layers = 2 # 1 or 2
11+
self.dim_initalize_layer = 512
12+
self.num_attend_layers = 2 # 1 or 2
13+
self.dim_attend_layer = 512
14+
self.num_decode_layers = 2 # 1 or 2
15+
self.dim_decode_layer = 1024
16+
17+
# about the weight initialization and regularization
18+
self.fc_kernel_initializer_scale = 0.08
19+
self.fc_kernel_regularizer_scale = 1e-4
20+
self.fc_activity_regularizer_scale = 0.0
21+
self.conv_kernel_regularizer_scale = 1e-4
22+
self.conv_activity_regularizer_scale = 0.0
23+
self.fc_drop_rate = 0.5
24+
self.lstm_drop_rate = 0.3
25+
self.attention_loss_factor = 0.01
26+
27+
# about the optimization
28+
self.num_epochs = 100
29+
self.batch_size = 32
30+
self.optimizer = 'Adam' # 'Adam', 'RMSProp', 'Momentum' or 'SGD'
31+
self.initial_learning_rate = 0.0001
32+
self.learning_rate_decay_factor = 1.0
33+
self.num_steps_per_decay = 100000
34+
self.momentum = 0.0
35+
self.clip_gradients = 5.0
36+
self.use_nesterov = True
37+
self.decay = 0.9
38+
self.centered = True
39+
self.beta2 = 0.999
40+
self.beta1 = 0.9
41+
self.epsilon = 1e-6
42+
43+
# about the vocabulary
44+
self.vocabulary_file = './data/vocabulary.csv'
45+
self.vocabulary_size = 5000

DeepRNN/dataset.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
import math
3+
import numpy as np
4+
import pandas as pd
5+
from tqdm import tqdm
6+
7+
from utils.coco.coco import COCO
8+
from utils.vocabulary import Vocabulary
9+
10+
class DataSet(object):
11+
def __init__(self,
12+
image_ids,
13+
image_files,
14+
batch_size,
15+
word_idxs=None,
16+
masks=None,
17+
is_train=False,
18+
shuffle=False):
19+
self.image_ids = np.array(image_ids)
20+
self.image_files = np.array(image_files)
21+
self.word_idxs = np.array(word_idxs)
22+
self.masks = np.array(masks)
23+
self.batch_size = batch_size
24+
self.is_train = is_train
25+
self.shuffle = shuffle
26+
self.setup()
27+
28+
def setup(self):
29+
""" Setup the dataset. """
30+
self.count = len(self.image_ids)
31+
self.num_batches = int(np.ceil(self.count * 1.0 / self.batch_size))
32+
self.fake_count = self.num_batches * self.batch_size - self.count
33+
self.idxs = list(range(self.count))
34+
self.reset()
35+
36+
def reset(self):
37+
""" Reset the dataset. """
38+
self.current_idx = 0
39+
if self.shuffle:
40+
np.random.shuffle(self.idxs)
41+
42+
def next_batch(self):
43+
""" Fetch the next batch. """
44+
assert self.has_next_batch()
45+
46+
if self.has_full_next_batch():
47+
start, end = self.current_idx, \
48+
self.current_idx + self.batch_size
49+
current_idxs = self.idxs[start:end]
50+
else:
51+
start, end = self.current_idx, self.count
52+
current_idxs = self.idxs[start:end] + \
53+
list(np.random.choice(self.count, self.fake_count))
54+
55+
image_files = self.image_files[current_idxs]
56+
if self.is_train:
57+
word_idxs = self.word_idxs[current_idxs]
58+
masks = self.masks[current_idxs]
59+
self.current_idx += self.batch_size
60+
return image_files, word_idxs, masks
61+
else:
62+
self.current_idx += self.batch_size
63+
return image_files
64+
65+
def has_next_batch(self):
66+
""" Determine whether there is a batch left. """
67+
return self.current_idx < self.count
68+
69+
def has_full_next_batch(self):
70+
""" Determine whether there is a full batch left. """
71+
return self.current_idx + self.batch_size <= self.count
72+
73+
def prepare_test_data(config):
74+
""" Prepare the data for testing the model. """
75+
image_files = [config.test_file_name]
76+
image_ids = list(range(len(image_files)))
77+
if os.path.exists(config.vocabulary_file):
78+
vocabulary = Vocabulary(config.vocabulary_size,
79+
config.vocabulary_file)
80+
else:
81+
vocabulary = build_vocabulary(config)
82+
dataset = DataSet(image_ids, image_files, config.batch_size)
83+
return dataset, vocabulary

DeepRNN/main.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/usr/bin/python
2+
import tensorflow as tf
3+
from config import Config
4+
from model import CaptionGenerator
5+
from dataset import prepare_test_data
6+
7+
flags = tf.app.flags.FLAGS
8+
9+
tf.flags.DEFINE_string('test_image', 'image.jpg', 'Test image name')
10+
11+
def main(argv):
12+
config = Config()
13+
config.test_file_name = flags.test_image
14+
config.phase = 'test'
15+
config.beam_size = 3
16+
17+
with tf.Session() as sess:
18+
data, vocabulary = prepare_test_data(config)
19+
model = CaptionGenerator(config)
20+
model.load(sess, './data/289999.npy')
21+
tf.get_default_graph().finalize()
22+
model.test(sess, data, vocabulary)
23+
24+
if __name__ == '__main__':
25+
tf.app.run()

0 commit comments

Comments
 (0)