|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 4 | +# or more contributor license agreements. See the NOTICE file |
| 5 | +# distributed with this work for additional information |
| 6 | +# regarding copyright ownership. The ASF licenses this file |
| 7 | +# to you under the Apache License, Version 2.0 (the |
| 8 | +# "License"); you may not use this file except in compliance |
| 9 | +# with the License. You may obtain a copy of the License at |
| 10 | +# |
| 11 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +# |
| 13 | +# Unless required by applicable law or agreed to in writing, |
| 14 | +# software distributed under the License is distributed on an |
| 15 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 16 | +# KIND, either express or implied. See the License for the |
| 17 | +# specific language governing permissions and limitations |
| 18 | +# under the License. |
| 19 | + |
| 20 | +from __future__ import print_function |
| 21 | +import os |
| 22 | +import sys |
| 23 | + |
| 24 | +curr_path = os.path.abspath(os.path.dirname(__file__)) |
| 25 | +sys.path.append(os.path.join(curr_path, "../python")) |
| 26 | +import mxnet as mx |
| 27 | +import random |
| 28 | +import argparse |
| 29 | +import cv2 |
| 30 | +import time |
| 31 | +import traceback |
| 32 | + |
| 33 | +try: |
| 34 | + import multiprocessing |
| 35 | +except ImportError: |
| 36 | + multiprocessing = None |
| 37 | + |
def list_image(root, recursive, exts):
    """Traverses the root of directory that contains images and
    generates image list iterator.

    Parameters
    ----------
    root: string
        Directory to scan.
    recursive: bool
        If True, walk every sub-directory (following symlinks) and give the
        images of each directory their own integer label, numbered in the
        order in which directories with images are first met. If False, only
        files directly inside ``root`` are listed and all get label 0.
    exts: list of string (a single string is also accepted)
        Accepted file extensions including the leading dot, e.g. ``['.jpg']``.
        Matching is case-insensitive.

    Returns
    -------
    iterator over ``(index, path_relative_to_root, label)`` tuples

    Notes
    -----
    In recursive mode the directory -> label mapping is printed once the
    traversal is finished, i.e. when the iterator is exhausted.
    """
    # A bare string would otherwise be matched by substring (so even an empty
    # suffix would match); treat it as a single extension instead.
    if isinstance(exts, str):
        exts = [exts]
    # File suffixes are lower-cased before the comparison, so the accepted
    # extensions have to be lower-cased as well or '.JPG' would never match.
    allowed = set(ext.lower() for ext in exts)

    i = 0
    if recursive:
        cat = {}
        for path, dirs, files in os.walk(root, followlinks=True):
            # sorting in place makes both the walk order and the labels deterministic
            dirs.sort()
            files.sort()
            for fname in files:
                fpath = os.path.join(path, fname)
                suffix = os.path.splitext(fname)[1].lower()
                if os.path.isfile(fpath) and suffix in allowed:
                    if path not in cat:
                        cat[path] = len(cat)
                    yield (i, os.path.relpath(fpath, root), cat[path])
                    i += 1
        for k, v in sorted(cat.items(), key=lambda x: x[1]):
            print(os.path.relpath(k, root), v)
    else:
        for fname in sorted(os.listdir(root)):
            fpath = os.path.join(root, fname)
            suffix = os.path.splitext(fname)[1].lower()
            if os.path.isfile(fpath) and suffix in allowed:
                yield (i, os.path.relpath(fpath, root), 0)
                i += 1
| 74 | + |
def write_list(path_out, image_list):
    """Helper function that dumps an image list into a text file.

    Each entry is written as one line made of TAB-separated fields:
    the integer image index, then every label formatted as a float,
    and finally the path of the image.

    Parameters
    ----------
    path_out: string
        File to create (an existing file is overwritten).
    image_list: list
        Entries shaped like ``(index, path, label, ...)``.
    """
    with open(path_out, 'w') as fout:
        for entry in image_list:
            fields = ['%d' % entry[0]]
            fields.extend('%f' % label for label in entry[2:])
            # the path goes last and terminates the line
            fields.append('%s\n' % entry[1])
            fout.write('\t'.join(fields))
| 92 | + |
def make_list(args):
    """Generates .lst file(s) from the images found under ``args.root``.

    The (optionally shuffled) image list is cut into ``args.chunks``
    consecutive chunks. Each chunk is either written as a whole
    (``train_ratio == 1.0``) or split, in this order, into a test part,
    a train part and a validation part (the remainder).

    Parameters
    ----------
    args: object that contains all the arguments
        Attributes read here: root, recursive, exts, shuffle, chunks,
        train_ratio, test_ratio and prefix.

    Raises
    ------
    ValueError
        If ``args.chunks`` is smaller than 1.
    """
    if args.chunks < 1:
        raise ValueError('chunks must be a positive integer, got %s' % args.chunks)
    image_list = list(list_image(args.root, args.recursive, args.exts))
    if args.shuffle is True:
        # fixed seed so that repeated runs produce the same order
        random.seed(100)
        random.shuffle(image_list)
    total = len(image_list)
    # ceil division: the last chunk may hold fewer items than the others
    chunk_size = (total + args.chunks - 1) // args.chunks
    for i in range(args.chunks):
        chunk = image_list[i * chunk_size:(i + 1) * chunk_size]
        if args.chunks > 1:
            str_chunk = '_%d' % i
        else:
            str_chunk = ''
        if args.train_ratio == 1.0:
            write_list(args.prefix + str_chunk + '.lst', chunk)
            continue
        # Split on the real length of this chunk; using the nominal chunk_size
        # would give a short final chunk the wrong train/val proportions.
        sep = int(len(chunk) * args.train_ratio)
        sep_test = int(len(chunk) * args.test_ratio)
        if args.test_ratio:
            write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
        if args.train_ratio + args.test_ratio < 1.0:
            write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
        write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])
| 122 | + |
def read_list(path_in):
    """Lazily parses a .lst file.

    Every line is expected to hold TAB-separated fields: an integer index,
    one or more float labels and the image path. Lines with fewer than
    three fields, or whose numbers cannot be parsed, are reported on
    stdout and skipped.

    Parameters
    ----------
    path_in: string
        Path of the .lst file.

    Returns
    -------
    iterator over ``[index, path, label, ...]`` lists
    """
    with open(path_in) as fin:
        for raw in fin:
            fields = [part.strip() for part in raw.strip().split('\t')]
            n_fields = len(fields)
            # check the data format of .lst file
            if n_fields < 3:
                print('lst should have at least has three parts, but only has %s parts for %s' % (n_fields, fields))
                continue
            try:
                index = int(fields[0])
                labels = [float(value) for value in fields[1:-1]]
            except Exception as e:
                print('Parsing lst met error for %s, detail: %s' % (fields, e))
                continue
            yield [index, fields[-1]] + labels
| 149 | + |
def image_encode(args, i, item, q_out):
    """Reads, preprocesses, packs the image and put it back in output queue.

    Exactly one tuple ``(i, packed_record, item)`` is put on ``q_out`` per
    call; ``packed_record`` is None when the image could not be read or
    packed (the error is printed).

    Parameters
    ----------
    args: object
        Parsed command line arguments. Attributes read here: root,
        pack_label, pass_through, color, center_crop, resize, quality
        and encoding.
    i: int
        Position of the item in the list; passed through unchanged.
    item: list
        ``[index, relative_path, label, ...]`` as produced by ``read_list``.
    q_out: queue
        Queue receiving the result.
    """
    fullpath = os.path.join(args.root, item[1])

    # All labels go into the header only when requested and when there is
    # more than one; otherwise just the first label is stored.
    if len(item) > 3 and args.pack_label:
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)

    if args.pass_through:
        # store the file content as is, without decoding/re-encoding
        try:
            with open(fullpath, 'rb') as fin:
                img = fin.read()
            s = mx.recordio.pack(header, img)
            q_out.put((i, s, item))
        except Exception as e:
            traceback.print_exc()
            print('pack_img error:', item[1], e)
            q_out.put((i, None, item))
        return

    try:
        img = cv2.imread(fullpath, args.color)
    except Exception:
        # not a bare except: KeyboardInterrupt/SystemExit must still be able
        # to stop the process
        traceback.print_exc()
        print('imread error trying to load file: %s ' % fullpath)
        q_out.put((i, None, item))
        return
    if img is None:
        print('imread read blank (None) image for file: %s' % fullpath)
        q_out.put((i, None, item))
        return
    if args.center_crop:
        # cut the longer dimension down to the shorter one, around the center
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) // 2
            img = img[margin:margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) // 2
            img = img[:, margin:margin + img.shape[0]]
    if args.resize:
        # the shorter edge becomes args.resize, the other one is scaled
        # proportionally (integer division)
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
        img = cv2.resize(img, newsize)

    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
        q_out.put((i, s, item))
    except Exception as e:
        traceback.print_exc()
        print('pack_img error on file: %s' % fullpath, e)
        q_out.put((i, None, item))
        return
| 211 | + |
def read_worker(args, q_in, q_out):
    """Worker loop: takes ``(i, item)`` tasks from the input queue and
    encodes each of them with ``image_encode``, which puts the result on
    the output queue. The loop ends when ``None`` is received.

    Parameters
    ----------
    args: object
    q_in: queue
    q_out: queue
    """
    # iter(callable, sentinel) keeps calling q_in.get() until it returns None
    for task in iter(q_in.get, None):
        idx, entry = task
        image_encode(args, idx, entry, q_out)
| 227 | + |
def write_worker(q_out, fname, working_dir):
    """Function that will be spawned to fetch processed image
    from the output queue and write to the .rec file.

    Results may arrive out of order; they are buffered and written strictly
    by increasing position so that the record file follows the list order.
    Items whose payload is None (failed encoding) are skipped. The loop ends
    when ``None`` is received from the queue.

    Parameters
    ----------
    q_out: queue
        Queue of ``(position, packed_record, item)`` tuples.
    fname: string
        Name of the .lst file; only its base name is used to derive the
        ``.rec`` / ``.idx`` file names.
    working_dir: string
        Directory in which the ``.rec`` / ``.idx`` files are created.
    """
    pre_time = time.time()
    count = 0
    fname = os.path.basename(fname)
    fname_rec = os.path.splitext(fname)[0] + '.rec'
    fname_idx = os.path.splitext(fname)[0] + '.idx'
    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                           os.path.join(working_dir, fname_rec), 'w')
    buf = {}
    more = True
    try:
        while more:
            deq = q_out.get()
            if deq is not None:
                i, s, item = deq
                buf[i] = (s, item)
            else:
                more = False
            # flush every buffered result that is next in line
            while count in buf:
                s, item = buf[count]
                del buf[count]
                if s is not None:
                    record.write_idx(item[0], s)

                if count % 1000 == 0:
                    cur_time = time.time()
                    print('time:', cur_time - pre_time, ' count:', count)
                    pre_time = cur_time
                count += 1
    finally:
        # close explicitly instead of relying on garbage collection, so the
        # files are released even if writing fails
        record.close()
| 264 | + |
def parse_args(argv=None):
    """Defines and parses all arguments.

    Parameters
    ----------
    argv: list of string, optional
        Arguments to parse. When None (the default) argparse reads them
        from ``sys.argv[1:]``.

    Returns
    -------
    args object that contains all the params; ``prefix`` and ``root`` are
    converted to absolute paths.
    """
    # Help texts use implicit string concatenation: a backslash continuation
    # inside a literal would embed the source indentation in the string.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create an image list or '
                    'make a record database by reading from an image list')
    parser.add_argument('prefix', help='prefix of input/output lst and rec files.')
    parser.add_argument('root', help='path to folder containing images.')

    cgroup = parser.add_argument_group('Options for creating image lists')
    cgroup.add_argument('--list', action='store_true',
                        help='If this is set im2rec will create image list(s) by traversing root folder '
                             'and output to <prefix>.lst. '
                             'Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')
    cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],
                        help='list of acceptable image extensions.')
    cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
    cgroup.add_argument('--train-ratio', type=float, default=1.0,
                        help='Ratio of images to use for training.')
    cgroup.add_argument('--test-ratio', type=float, default=0,
                        help='Ratio of images to use for testing.')
    cgroup.add_argument('--recursive', action='store_true',
                        help='If true recursively walk through subdirs and assign an unique label '
                             'to images in each folder. Otherwise only include images in the root folder '
                             'and give them label 0.')
    cgroup.add_argument('--no-shuffle', dest='shuffle', action='store_false',
                        help='If this is passed, '
                             'im2rec will not randomize the image order in <prefix>.lst')

    rgroup = parser.add_argument_group('Options for creating database')
    rgroup.add_argument('--pass-through', action='store_true',
                        help='whether to skip transformation and save image as is')
    rgroup.add_argument('--resize', type=int, default=0,
                        help='resize the shorter edge of image to the newsize, original images will '
                             'be packed by default.')
    rgroup.add_argument('--center-crop', action='store_true',
                        help='specify whether to crop the center image to make it rectangular.')
    rgroup.add_argument('--quality', type=int, default=95,
                        help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
    rgroup.add_argument('--num-thread', type=int, default=1,
                        help='number of thread to use for encoding. order of images will be different '
                             'from the input list if >1. the input list will be modified to match the '
                             'resulting order.')
    rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],
                        help='specify the color mode of the loaded image. '
                             '1: Loads a color image. Any transparency of image will be neglected. '
                             'It is the default flag. '
                             '0: Loads image in grayscale mode. '
                             '-1:Loads image as such including alpha channel.')
    rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
                        help='specify the encoding of the images.')
    rgroup.add_argument('--pack-label', action='store_true',
                        help='Whether to also pack multi dimensional label in the record file')
    args = parser.parse_args(argv)
    args.prefix = os.path.abspath(args.prefix)
    args.root = os.path.abspath(args.root)
    return args
| 324 | + |
if __name__ == '__main__':
    args = parse_args()
    if args.list:
        # '--list' mode: scan the image folder and generate the .lst file(s)
        make_list(args)
    else:
        # otherwise read every <prefix>*.lst file and generate a .rec/.idx pair for it
        if os.path.isdir(args.prefix):
            working_dir = args.prefix
        else:
            working_dir = os.path.dirname(args.prefix)
        files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
                 if os.path.isfile(os.path.join(working_dir, fname))]
        count = 0
        for fname in files:
            if not (fname.startswith(args.prefix) and fname.endswith('.lst')):
                continue
            print('Creating .rec file from', fname, 'in', working_dir)
            count += 1
            image_list = read_list(fname)
            # -- write_record -- #
            if args.num_thread > 1 and multiprocessing is not None:
                # one bounded input queue per reader, a single shared output queue
                q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                q_out = multiprocessing.Queue(1024)
                # define the reader processes
                read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out))
                                for i in range(args.num_thread)]
                # process images with num_thread processes
                for p in read_process:
                    p.start()
                # only use one process to write .rec to avoid race-condition
                write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
                write_process.start()
                # distribute the image list round-robin over the input queues
                for i, item in enumerate(image_list):
                    q_in[i % len(q_in)].put((i, item))
                # one None per reader tells it to stop
                for q in q_in:
                    q.put(None)
                for p in read_process:
                    p.join()

                # all readers are done, so nothing more will be queued: stop the writer
                q_out.put(None)
                write_process.join()
            else:
                if args.num_thread > 1:
                    # only claim multiprocessing is missing when it was actually requested
                    print('multiprocessing not available, fall back to single threaded encoding')
                try:
                    import Queue as queue
                except ImportError:
                    import queue
                q_out = queue.Queue()
                stem = os.path.splitext(os.path.basename(fname))[0]
                record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, stem + '.idx'),
                                                       os.path.join(working_dir, stem + '.rec'), 'w')
                cnt = 0
                pre_time = time.time()
                try:
                    for i, item in enumerate(image_list):
                        image_encode(args, i, item, q_out)
                        if q_out.empty():
                            continue
                        _, s, _ = q_out.get()
                        if s is None:
                            # encoding failed (already reported by image_encode);
                            # skip the item exactly like write_worker does
                            continue
                        record.write_idx(item[0], s)
                        if cnt % 1000 == 0:
                            cur_time = time.time()
                            print('time:', cur_time - pre_time, ' count:', cnt)
                            pre_time = cur_time
                        cnt += 1
                finally:
                    # release the .rec/.idx files before moving on to the next list
                    record.close()
        if not count:
            print('Did not find any list file with prefix %s' % args.prefix)
0 commit comments