Skip to content

Commit 96077ee

Browse files
author
Kai Sheng Tai
committed
Allow optimization on CPU, adjust learning rate for VGG so that SGD doesn't diverge
1 parent 589cdba commit 96077ee

File tree

6 files changed

+52
-20
lines changed

6 files changed

+52
-20
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.DS_Store
22
frames
3-
*~
3+
*~
4+
*.th

costs.lua

+5-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ end
3636

3737
local euclidean = nn.MSECriterion()
3838
euclidean.sizeAverage = false
39-
euclidean:cuda()
39+
if opt.cpu then
40+
euclidean:float()
41+
else
42+
euclidean:cuda()
43+
end
4044

4145
function style_grad(gen, orig_gram)
4246
local k = gen:size(2)

images.lua

+12-11
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ function preprocess(img, scale)
2323
end
2424
end
2525

26-
local copy = torch.Tensor(img:size())
26+
-- reverse channels
27+
local copy = torch.FloatTensor(img:size())
2728
copy[1] = img[3]
2829
copy[2] = img[2]
2930
copy[3] = img[1]
@@ -37,17 +38,17 @@ function preprocess(img, scale)
3738
end
3839

3940
function depreprocess(img)
40-
img = img:float():view(3, img:size(3), img:size(4))
41+
local copy = torch.FloatTensor(3, img:size(3), img:size(4)):copy(img)
4142
for i = 1, 3 do
42-
img[i]:add(means[i])
43+
copy[i]:add(means[i])
4344
end
44-
img:div(255)
45+
copy:div(255)
4546

46-
local copy = torch.FloatTensor(img:size())
47-
copy[1] = img[3]
48-
copy[2] = img[2]
49-
copy[3] = img[1]
50-
img = copy
51-
img:clamp(0, 1)
52-
return img
47+
-- reverse channels
48+
local copy2 = torch.FloatTensor(copy:size())
49+
copy2[1] = copy[3]
50+
copy2[2] = copy[2]
51+
copy2[3] = copy[1]
52+
copy2:clamp(0, 1)
53+
return copy2
5354
end

main.lua

+33-5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ cmd:option('--smoothness', 0, 'Total variation norm regularization s
3131
cmd:option('--init', 'image', '{image, random}. Initialization mode for optimized image.')
3232
cmd:option('--backend', 'cunn', '{cunn, cudnn}. Neural network CUDA backend.')
3333
cmd:option('--optimizer', 'lbfgs', '{sgd, lbfgs}. Optimization algorithm.')
34+
cmd:option('--cpu', false, 'Optimize on CPU (only with VGG network).')
3435
opt = cmd:parse(arg)
3536
if opt.size <= 0 then
3637
opt.size = nil
@@ -51,6 +52,10 @@ if opt.model == 'inception' then
5152
print('run download_models.sh to download model weights')
5253
error('')
5354
end
55+
56+
if opt.cpu then
57+
error('CPU optimization only works with VGG model')
58+
end
5459
elseif opt.model == 'vgg' then
5560
if not paths.filep(vgg_path) then
5661
print('ERROR: could not find VGG model weights at ' .. vgg_path)
@@ -97,6 +102,13 @@ elseif opt.model == 'vgg' then
97102

98103
model = create_vgg(vgg_path, opt.backend)
99104
end
105+
106+
-- run on GPU
107+
if opt.cpu then
108+
model:float()
109+
else
110+
model:cuda()
111+
end
100112
collectgarbage()
101113

102114
-- compute normalization factor
@@ -111,14 +123,20 @@ for k, v in pairs(content_weights) do
111123
end
112124

113125
-- load content image
114-
local img = preprocess(image.load(opt.content), opt.size):cuda()
126+
local img = preprocess(image.load(opt.content), opt.size)
127+
if not opt.cpu then
128+
img = img:cuda()
129+
end
115130
model:forward(img)
116131
local img_activations, _ = collect_activations(model, content_weights, {})
117132

118133
-- load style image
119134
local art = preprocess(
120135
image.load(opt.style), math.max(img:size(3), img:size(4))
121-
):cuda()
136+
)
137+
if not opt.cpu then
138+
art = art:cuda()
139+
end
122140
model:forward(art)
123141
local _, art_grams = collect_activations(model, {}, style_weights)
124142
art = nil
@@ -130,7 +148,8 @@ function opfunc(input)
130148

131149
-- backpropagate
132150
local loss = 0
133-
local grad = torch.CudaTensor(model.output:size()):zero()
151+
local grad = opt.cpu and torch.FloatTensor() or torch.CudaTensor()
152+
grad:resize(model.output:size()):zero()
134153
for i = #model.modules, 1, -1 do
135154
local module_input = (i == 1) and input or model.modules[i - 1].output
136155
local module = model.modules[i]
@@ -168,7 +187,11 @@ if opt.init == 'image' then
168187
elseif opt.init == 'random' then
169188
input = preprocess(
170189
torch.randn(3, img:size(3), img:size(4)):mul(0.1):add(0.5):clamp(0, 1)
171-
):cuda()
190+
)
191+
192+
if not opt.cpu then
193+
input = input:cuda()
194+
end
172195
else
173196
error('unrecognized initialization option: ' .. opt.init)
174197
end
@@ -190,10 +213,15 @@ image.save(paths.concat(frames_dir, '0.jpg'), output)
190213
local optim_state
191214
if opt.optimizer == 'sgd' then
192215
optim_state = {
193-
learningRate = 0.1,
194216
momentum = 0.9,
195217
dampening = 0.0,
196218
}
219+
220+
if opt.model == 'inception' then
221+
optim_state.learningRate = 5e-2
222+
else
223+
optim_state.learningRate = 1e-3
224+
end
197225
elseif opt.optimizer == 'lbfgs' then
198226
optim_state = {
199227
maxIter = 3,

models/inception.lua

-1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,5 @@ function create_inception(weights_file, backend)
9797

9898
model = model:subnetwork('inception_4e')
9999
collectgarbage()
100-
model:cuda()
101100
return model
102101
end

models/vgg19.lua

-1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,5 @@ function create_vgg(weights_file, backend)
6363
end
6464

6565
collectgarbage()
66-
model:cuda()
6766
return model
6867
end

0 commit comments

Comments
 (0)