Skip to content

Commit dc91326

Browse files
authored
Merge pull request #43 from CortexFoundation/wlt
Wlt
2 parents 3470799 + 394b85a commit dc91326

File tree

5 files changed

+406
-13
lines changed

5 files changed

+406
-13
lines changed

docs/mrt/im2rec.py

+393
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,393 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
20+
from __future__ import print_function
21+
import os
22+
import sys
23+
24+
curr_path = os.path.abspath(os.path.dirname(__file__))
25+
sys.path.append(os.path.join(curr_path, "../python"))
26+
import mxnet as mx
27+
import random
28+
import argparse
29+
import cv2
30+
import time
31+
import traceback
32+
33+
try:
34+
import multiprocessing
35+
except ImportError:
36+
multiprocessing = None
37+
38+
def list_image(root, recursive, exts):
    """Walk an image directory and yield one entry per image file.

    Parameters
    ----------
    root: string
        Directory to scan.
    recursive: bool
        When true, descend into sub-directories (following symlinks) and
        give every directory that contains at least one image its own
        integer label, numbered in the order the directories are first
        seen. When false, only files directly under ``root`` are listed
        and all of them get label 0.
    exts: container of strings
        Accepted lower-case file suffixes, including the leading dot.

    Returns
    -------
    Iterator of ``(index, path_relative_to_root, label)`` tuples. In
    recursive mode the directory/label mapping is printed once the walk
    has finished.
    """

    def _is_image(directory, name):
        # A candidate must be a regular file whose lower-cased suffix is accepted.
        suffix = os.path.splitext(name)[1].lower()
        return os.path.isfile(os.path.join(directory, name)) and (suffix in exts)

    index = 0
    if not recursive:
        for name in sorted(os.listdir(root)):
            if _is_image(root, name):
                yield (index, os.path.relpath(os.path.join(root, name), root), 0)
                index += 1
        return

    labels = {}
    for folder, subdirs, names in os.walk(root, followlinks=True):
        # Sorting in place makes both the traversal order and the labels deterministic.
        subdirs.sort()
        names.sort()
        for name in names:
            if not _is_image(folder, name):
                continue
            label = labels.setdefault(folder, len(labels))
            yield (index, os.path.relpath(os.path.join(folder, name), root), label)
            index += 1
    for folder, label in sorted(labels.items(), key=lambda pair: pair[1]):
        print(os.path.relpath(folder, root), label)
74+
75+
def write_list(path_out, image_list):
    """Helper that writes an image list to a .lst file.

    Each entry becomes one tab-separated line of the form
    integer_image_index \t float_label(s) \t path_to_image
    (the blanks around the tabs are only there for readability).

    Parameters
    ----------
    path_out: string
        File to create or overwrite.
    image_list: iterable
        Entries laid out as ``(index, path, label, ...)``; everything from
        position 2 onwards is written as a float label column.
    """
    with open(path_out, 'w') as fout:
        for entry in image_list:
            columns = ['%d' % entry[0]]
            columns.extend('%f' % label for label in entry[2:])
            # The path goes last and carries the line terminator.
            columns.append('%s\n' % entry[1])
            fout.write('\t'.join(columns))
92+
93+
def make_list(args):
    """Generate the .lst file(s) for the images found under ``args.root``.

    The image list is optionally shuffled with a fixed seed, cut into
    ``args.chunks`` equally sized pieces, and every piece is written
    either as a single ``<prefix>[_<chunk>].lst`` file (train ratio of
    exactly 1.0) or as ``_test`` / ``_train`` / ``_val`` files split
    according to ``args.test_ratio`` and ``args.train_ratio``.

    Parameters
    ----------
    args: object that contains all the arguments
        Reads ``root``, ``recursive``, ``exts``, ``shuffle``, ``chunks``,
        ``train_ratio``, ``test_ratio`` and ``prefix``.
    """
    entries = list(list_image(args.root, args.recursive, args.exts))
    if args.shuffle is True:
        # Fixed seed so repeated runs produce the same ordering.
        random.seed(100)
        random.shuffle(entries)
    total = len(entries)
    per_chunk = (total + args.chunks - 1) // args.chunks  # ceiling division
    for idx in range(args.chunks):
        start = idx * per_chunk
        part = entries[start:start + per_chunk]
        # The chunk suffix is only added when there really are several chunks.
        base = args.prefix + ('_%d' % idx if args.chunks > 1 else '')
        # Split sizes are derived from the nominal chunk size, not len(part).
        n_train = int(per_chunk * args.train_ratio)
        n_test = int(per_chunk * args.test_ratio)
        if args.train_ratio == 1.0:
            write_list(base + '.lst', part)
            continue
        if args.test_ratio:
            write_list(base + '_test.lst', part[:n_test])
        if args.train_ratio + args.test_ratio < 1.0:
            write_list(base + '_val.lst', part[n_test + n_train:])
        write_list(base + '_train.lst', part[n_test:n_test + n_train])
122+
123+
def read_list(path_in):
    """Read a .lst file and yield its entries one by one.

    Parameters
    ----------
    path_in: string
        Path of the tab-separated .lst file.

    Returns
    -------
    Iterator of lists shaped ``[index, path, label, ...]`` where the index
    is an int, the path is the last column of the line and the labels are
    the floats found in between. Lines with fewer than three columns, or
    whose numbers cannot be parsed, are reported on stdout and skipped.
    """
    with open(path_in) as fin:
        for raw in fin:
            fields = [part.strip() for part in raw.strip().split('\t')]
            n_fields = len(fields)
            # check the data format of .lst file
            if n_fields < 3:
                print('lst should have at least has three parts, but only has %s parts for %s' % (n_fields, fields))
                continue
            try:
                index = int(fields[0])
                labels = [float(value) for value in fields[1:-1]]
            except Exception as e:
                print('Parsing lst met error for %s, detail: %s' % (fields, e))
                continue
            yield [index, fields[-1]] + labels
149+
150+
def image_encode(args, i, item, q_out):
    """Read, preprocess and pack one image, then put the result on ``q_out``.

    Exactly one tuple ``(i, packed, item)`` is put on ``q_out`` per call;
    ``packed`` is ``None`` whenever the image could not be read or packed,
    so the consumer can tell failures apart and still account for ``i``.

    Parameters
    ----------
    args: object
        Parsed options; reads ``root``, ``pack_label``, ``pass_through``,
        ``color``, ``center_crop``, ``resize``, ``quality`` and ``encoding``.
    i: int
        Sequence number of the item, passed through unchanged.
    item: list
        ``item[0]`` is the record id, ``item[1]`` the image path relative
        to ``args.root`` and ``item[2:]`` the label value(s).
    q_out: queue
        Any object with a ``put`` method.
    """
    fullpath = os.path.join(args.root, item[1])

    # Store the whole label vector only when requested and when there is
    # more than one label; otherwise keep just the first label.
    if len(item) > 3 and args.pack_label:
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)

    if args.pass_through:
        # Pack the file bytes as they are, without decoding the image.
        try:
            with open(fullpath, 'rb') as fin:
                img = fin.read()
            s = mx.recordio.pack(header, img)
            q_out.put((i, s, item))
        except Exception as e:
            traceback.print_exc()
            print('pack_img error:', item[1], e)
            q_out.put((i, None, item))
        return

    try:
        img = cv2.imread(fullpath, args.color)
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt and
        # SystemExit must still be able to stop the process.
        traceback.print_exc()
        print('imread error trying to load file: %s ' % fullpath)
        q_out.put((i, None, item))
        return
    if img is None:
        print('imread read blank (None) image for file: %s' % fullpath)
        q_out.put((i, None, item))
        return
    if args.center_crop:
        # Cut the longer side down to the length of the shorter one,
        # keeping the central part of the image.
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) // 2
            img = img[margin:margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) // 2
            img = img[:, margin:margin + img.shape[0]]
    if args.resize:
        # Scale so that the shorter edge equals args.resize; newsize is
        # passed to cv2.resize, which takes (width, height).
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
        img = cv2.resize(img, newsize)

    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
        q_out.put((i, s, item))
    except Exception as e:
        traceback.print_exc()
        print('pack_img error on file: %s' % fullpath, e)
        q_out.put((i, None, item))
        return
211+
212+
def read_worker(args, q_in, q_out):
    """Worker loop: take items from the input queue and encode them.

    Keeps pulling ``(index, item)`` pairs from ``q_in`` and hands each one
    to ``image_encode`` until a ``None`` sentinel is received.

    Parameters
    ----------
    args: object
        Options forwarded to ``image_encode``.
    q_in: queue
        Source of ``(index, item)`` pairs, terminated by ``None``.
    q_out: queue
        Destination queue forwarded to ``image_encode``.
    """
    task = q_in.get()
    while task is not None:
        index, entry = task
        image_encode(args, index, entry, q_out)
        task = q_in.get()
227+
228+
def write_worker(q_out, fname, working_dir):
    """Writer loop: take encoded images from the queue and write the .rec file.

    Results may arrive out of order, so they are buffered by sequence
    number and written strictly in increasing order. Entries whose payload
    is ``None`` (failed encodes) are skipped but still advance the counter.
    The loop ends when a ``None`` sentinel is taken from ``q_out``; the
    record writer is always closed before returning so the .rec and .idx
    files are finalised even if writing raised.

    Parameters
    ----------
    q_out: queue
        Source of ``(index, packed, item)`` tuples, terminated by ``None``.
    fname: string
        Name of the .lst file; its base name without extension is used
        for the ``.rec`` and ``.idx`` files.
    working_dir: string
        Directory in which the ``.rec`` and ``.idx`` files are created.
    """
    pre_time = time.time()
    count = 0
    stem = os.path.splitext(os.path.basename(fname))[0]
    fname_rec = stem + '.rec'
    fname_idx = stem + '.idx'
    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                           os.path.join(working_dir, fname_rec), 'w')
    buf = {}
    more = True
    try:
        while more:
            deq = q_out.get()
            if deq is not None:
                i, s, item = deq
                buf[i] = (s, item)
            else:
                more = False
            # Flush everything that is now contiguous with what was written.
            while count in buf:
                s, item = buf.pop(count)
                if s is not None:
                    record.write_idx(item[0], s)

                if count % 1000 == 0:
                    cur_time = time.time()
                    print('time:', cur_time - pre_time, ' count:', count)
                    pre_time = cur_time
                count += 1
    finally:
        # Close explicitly instead of relying on garbage collection.
        record.close()
264+
265+
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns
    -------
    args object that contains all the params; ``prefix`` and ``root`` are
    converted to absolute paths before it is returned.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create an image list or \
        make a record database by reading from an image list')
    cli.add_argument('prefix', help='prefix of input/output lst and rec files.')
    cli.add_argument('root', help='path to folder containing images.')

    # Options that only matter when generating .lst files.
    list_opts = cli.add_argument_group('Options for creating image lists')
    list_opts.add_argument(
        '--list', action='store_true',
        help='If this is set im2rec will create image list(s) by traversing root folder\
        and output to <prefix>.lst.\
        Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')
    list_opts.add_argument(
        '--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],
        help='list of acceptable image extensions.')
    list_opts.add_argument('--chunks', type=int, default=1, help='number of chunks.')
    list_opts.add_argument(
        '--train-ratio', type=float, default=1.0,
        help='Ratio of images to use for training.')
    list_opts.add_argument(
        '--test-ratio', type=float, default=0,
        help='Ratio of images to use for testing.')
    list_opts.add_argument(
        '--recursive', action='store_true',
        help='If true recursively walk through subdirs and assign an unique label\
        to images in each folder. Otherwise only include images in the root folder\
        and give them label 0.')
    list_opts.add_argument(
        '--no-shuffle', dest='shuffle', action='store_false',
        help='If this is passed, \
        im2rec will not randomize the image order in <prefix>.lst')

    # Options that only matter when packing images into a .rec database.
    rec_opts = cli.add_argument_group('Options for creating database')
    rec_opts.add_argument(
        '--pass-through', action='store_true',
        help='whether to skip transformation and save image as is')
    rec_opts.add_argument(
        '--resize', type=int, default=0,
        help='resize the shorter edge of image to the newsize, original images will\
        be packed by default.')
    rec_opts.add_argument(
        '--center-crop', action='store_true',
        help='specify whether to crop the center image to make it rectangular.')
    rec_opts.add_argument(
        '--quality', type=int, default=95,
        help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
    rec_opts.add_argument(
        '--num-thread', type=int, default=1,
        help='number of thread to use for encoding. order of images will be different\
        from the input list if >1. the input list will be modified to match the\
        resulting order.')
    rec_opts.add_argument(
        '--color', type=int, default=1, choices=[-1, 0, 1],
        help='specify the color mode of the loaded image.\
        1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\
        0: Loads image in grayscale mode.\
        -1:Loads image as such including alpha channel.')
    rec_opts.add_argument(
        '--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
        help='specify the encoding of the images.')
    rec_opts.add_argument(
        '--pack-label', action='store_true',
        help='Whether to also pack multi dimensional label in the record file')

    parsed = cli.parse_args()
    parsed.prefix = os.path.abspath(parsed.prefix)
    parsed.root = os.path.abspath(parsed.root)
    return parsed
324+
325+
if __name__ == '__main__':
    # Script entry point: either generate .lst files from a directory tree
    # or pack every matching .lst file into a .rec/.idx pair.
    args = parse_args()
    # if the '--list' is used, it generates .lst file
    if args.list:
        make_list(args)
    # otherwise read .lst file to generates .rec file
    else:
        if os.path.isdir(args.prefix):
            working_dir = args.prefix
        else:
            working_dir = os.path.dirname(args.prefix)
        files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
                 if os.path.isfile(os.path.join(working_dir, fname))]
        count = 0
        for fname in files:
            if fname.startswith(args.prefix) and fname.endswith('.lst'):
                print('Creating .rec file from', fname, 'in', working_dir)
                count += 1
                image_list = read_list(fname)
                # -- write_record -- #
                if args.num_thread > 1 and multiprocessing is not None:
                    q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                    q_out = multiprocessing.Queue(1024)
                    # define the process
                    read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out))
                                    for i in range(args.num_thread)]
                    # process images with num_thread process
                    for p in read_process:
                        p.start()
                    # only use one process to write .rec to avoid race-condition
                    write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
                    write_process.start()
                    # put the image list into input queue, round-robin over the readers
                    for i, item in enumerate(image_list):
                        q_in[i % len(q_in)].put((i, item))
                    # one sentinel per reader so that each of them terminates
                    for q in q_in:
                        q.put(None)
                    for p in read_process:
                        p.join()

                    # all readers are done: tell the writer to finish
                    q_out.put(None)
                    write_process.join()
                else:
                    print('multiprocessing not available, fall back to single threaded encoding')
                    try:
                        import Queue as queue
                    except ImportError:
                        import queue
                    q_out = queue.Queue()
                    lst_stem = os.path.splitext(os.path.basename(fname))[0]
                    fname_rec = lst_stem + '.rec'
                    fname_idx = lst_stem + '.idx'
                    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                                           os.path.join(working_dir, fname_rec), 'w')
                    cnt = 0
                    pre_time = time.time()
                    try:
                        for i, item in enumerate(image_list):
                            image_encode(args, i, item, q_out)
                            if q_out.empty():
                                continue
                            _, s, _ = q_out.get()
                            if s is None:
                                # image_encode already reported the failure;
                                # there is nothing to write for this item.
                                continue
                            record.write_idx(item[0], s)
                            if cnt % 1000 == 0:
                                cur_time = time.time()
                                print('time:', cur_time - pre_time, ' count:', cnt)
                                pre_time = cur_time
                            cnt += 1
                    finally:
                        # Close explicitly so the .rec/.idx files are finalised
                        # before the next list file is processed.
                        record.close()
        if not count:
            print('Did not find any list file with prefix %s' % args.prefix)

0 commit comments

Comments
 (0)