|
| 1 | +## Fusion Model classification |
| 2 | +# Arun Aniyan |
| 3 | +# SKA SA/ RATT |
| 4 | +# arun@ska.ac.za |
| 5 | +# 18-02-17 |
| 6 | + |
| 7 | +# Input can be either fits image or jpg/png |
| 8 | + |
| 9 | +# Import necessary stuff |
| 10 | + |
| 11 | +import sys |
| 12 | +import os |
| 13 | +import time |
| 14 | +import datetime |
| 15 | +from collections import Counter |
| 16 | + |
| 17 | +import PIL.Image |
| 18 | +import numpy as np |
| 19 | +import scipy.misc |
| 20 | +from google.protobuf import text_format |
| 21 | + |
| 22 | +os.environ['GLOG_minloglevel'] = '2' # Suppress most caffe output |
| 23 | +import caffe |
| 24 | +from caffe.proto import caffe_pb2 |
| 25 | + |
| 26 | +from astropy.io import fits |
| 27 | +from astropy.stats import sigma_clipped_stats |
| 28 | +from scipy.misc import imsave |
| 29 | +from skimage.transform import resize |
| 30 | +from skimage.color import rgb2gray |
| 31 | + |
| 32 | + |
| 33 | +# Function definitions |
| 34 | + |
| 35 | +# Load model |
def get_net(caffemodel, deploy_file, use_gpu=True):
    """
    Build a caffe.Net in TEST phase from a deploy file and trained weights

    Arguments:
    caffemodel -- path to a .caffemodel file
    deploy_file -- path to a .prototxt file

    Keyword arguments:
    use_gpu -- accepted for interface compatibility but currently ignored;
               the CPU is always selected regardless of its value
    """
    # GPU selection is disabled here: every call forces CPU inference.
    caffe.set_mode_cpu()

    network = caffe.Net(deploy_file, caffemodel, caffe.TEST)
    return network
| 53 | + |
| 54 | +# Transformer function to perform image transformation |
def get_transformer(deploy_file, mean_file=None):
    """
    Returns an instance of caffe.io.Transformer set up for the 'data' input

    Arguments:
    deploy_file -- path to a .prototxt file

    Keyword arguments:
    mean_file -- path to a .binaryproto file (optional)

    Raises:
    ValueError -- if the mean blob does not describe a 4-dimensional shape
    """
    network = caffe_pb2.NetParameter()
    with open(deploy_file) as infile:
        text_format.Merge(infile.read(), network)

    # Use input_shape when the deploy file declares it, otherwise fall back
    # to the first four input_dim entries.
    if network.input_shape:
        dims = network.input_shape[0].dim
    else:
        dims = network.input_dim[:4]

    t = caffe.io.Transformer(inputs={'data': dims})
    t.set_transpose('data', (2, 0, 1))  # transpose to (channels, height, width)

    # color images: reverse the channel order
    if dims[1] == 3:
        t.set_channel_swap('data', (2, 1, 0))

    if mean_file:
        # The mean file is a serialized protobuf, so it must be read as bytes.
        with open(mean_file, 'rb') as infile:
            blob = caffe_pb2.BlobProto()
            blob.MergeFromString(infile.read())

        if blob.HasField('shape'):
            # blob.shape is a message; the actual sizes are in its `dim` field.
            blob_dims = tuple(blob.shape.dim)
            if len(blob_dims) != 4:
                raise ValueError(
                    'Shape should have 4 dimensions - shape is "%s"' % (blob_dims,))
        elif blob.HasField('num') and blob.HasField('channels') and \
                blob.HasField('height') and blob.HasField('width'):
            blob_dims = (blob.num, blob.channels, blob.height, blob.width)
        else:
            raise ValueError('blob does not provide shape or 4d dimensions')

        # Collapse the (channels, height, width) mean image to one value per channel.
        pixel = np.reshape(blob.data, blob_dims[1:]).mean(1).mean(1)
        t.set_mean('data', pixel)

    return t
| 105 | + |
| 106 | +# Load image to caffe |
def load_image(path, height, width, mode='RGB'):
    """
    Load an image from disk and squash it to the requested size

    Returns an np.ndarray of shape (height x width x channels) for RGB,
    or (height x width) for single-band modes such as L

    Arguments:
    path -- path to an image on disk
    height -- resize dimension
    width -- resize dimension

    Keyword arguments:
    mode -- the PIL mode that the image should be converted to
            (RGB for color or L for grayscale)
    """
    image = PIL.Image.open(path)
    image = image.convert(mode)
    # squash (aspect ratio is not preserved). PIL takes the size as
    # (width, height). This replaces scipy.misc.imresize, which has been
    # removed from SciPy and delegated to the same PIL resize for 8-bit
    # RGB / L images.
    image = image.resize((int(width), int(height)), PIL.Image.BILINEAR)
    return np.array(image)
| 129 | + |
| 130 | +# Forward pass of input through the network |
def forward_pass(images, net, transformer, batch_size=1):
    """
    Returns scores for each image as an np.ndarray (nImages x nClasses),
    or None when images is empty

    Arguments:
    images -- a list of np.ndarrays
    net -- a caffe.Net
    transformer -- a caffe.io.Transformer

    Keyword arguments:
    batch_size -- how many images can be processed at once
                  (a high value may result in out-of-memory errors)
    """
    # Give 2-D (grayscale) images an explicit trailing channel axis.
    caffe_images = []
    for image in images:
        if image.ndim == 2:
            caffe_images.append(image[:, :, np.newaxis])
        else:
            caffe_images.append(image)
    caffe_images = np.array(caffe_images)

    # Per-image input shape, i.e. the 'data' shape without the batch axis.
    dims = transformer.inputs['data'][1:]

    scores = None
    for start in range(0, len(caffe_images), batch_size):
        chunk = caffe_images[start:start + batch_size]
        new_shape = (len(chunk),) + tuple(dims)
        if net.blobs['data'].data.shape != new_shape:
            net.blobs['data'].reshape(*new_shape)
        for index, image in enumerate(chunk):
            net.blobs['data'].data[index] = transformer.preprocess('data', image)
        output = net.forward()[net.outputs[-1]]
        if scores is None:
            # Copy: the returned array may be backed by the network's output
            # buffer, which a later forward pass would overwrite.
            scores = np.copy(output)
        else:
            scores = np.vstack((scores, output))

    return scores
| 171 | + |
| 172 | +# Resolve labels |
def read_labels(labels_file):
    """
    Returns a list of strings (one per non-blank line, whitespace stripped),
    or None when no labels file is given

    Arguments:
    labels_file -- path to a .txt file

    Raises:
    ValueError -- if the file contains no non-blank lines
    """
    if not labels_file:
        print('WARNING: No labels file provided. Results will be difficult to interpret.')
        return None

    with open(labels_file) as infile:
        stripped = (line.strip() for line in infile)
        labels = [label for label in stripped if label]

    # Raise explicitly instead of asserting so the check survives `python -O`.
    if not labels:
        raise ValueError('No labels found')
    return labels
| 192 | + |
| 193 | + |
| 194 | +# Decide class based on threshold |
def decide(classification):
    """
    Pick the entry with the highest confidence

    Arguments:
    classification -- iterable of (label, confidence) pairs

    Returns the (label, confidence) pair whose confidence is largest;
    on ties the earliest pair wins.
    """
    names, scores = zip(*classification)
    best = np.argmax(scores)
    return names[best], scores[best]
| 204 | + |
| 205 | +# Perform Single classification |
def classify(caffemodel, deploy_file, image_files,
             mean_file=None, labels_file=None, use_gpu=True):
    """
    Classify some images against a Caffe model and return the best prediction

    Returns a (label, confidence) tuple, where confidence is the winning
    score multiplied by 100 and rounded to 4 decimals. When several images
    are given, only the result for the last image is returned.

    Arguments:
    caffemodel -- path to a .caffemodel
    deploy_file -- path to a .prototxt
    image_files -- list of paths to images

    Keyword arguments:
    mean_file -- path to a .binaryproto
    labels_file -- path to a .txt file
    use_gpu -- forwarded to get_net

    Raises:
    ValueError -- if image_files is empty or the network input does not
                  have 1 or 3 channels
    """
    if len(image_files) == 0:
        raise ValueError('No image files provided')

    # Load the model and images
    net = get_net(caffemodel, deploy_file, use_gpu)
    transformer = get_transformer(deploy_file, mean_file)
    _, channels, height, width = transformer.inputs['data']
    if channels == 3:
        mode = 'RGB'
    elif channels == 1:
        mode = 'L'
    else:
        raise ValueError('Invalid number for channels: %s' % channels)
    images = [load_image(image_file, height, width, mode) for image_file in image_files]
    labels = read_labels(labels_file)

    # Classify the images
    scores = forward_pass(images, net, transformer)

    # Process the results: per image, keep the (up to) 5 highest scores
    indices = (-scores).argsort()[:, :5]
    classifications = []
    for image_index, index_list in enumerate(indices):
        result = []
        for i in index_list:
            # 'i' is a category in labels and also an index into scores
            if labels is None:
                label = 'Class #%s' % i
            else:
                label = labels[i]
            result.append((label, round(100.0 * scores[image_index, i], 4)))
        classifications.append(result)

    # Only one (label, confidence) pair is returned: the last image's.
    return decide(classifications[-1])
| 261 | + |
| 262 | + |
| 263 | +# Fusion decision model |
def vote(ypreds, probs, thresh):
    """
    Fuse the predictions of several models into one decision

    Arguments:
    ypreds -- predicted class label from each model
    probs -- confidence of each prediction, aligned with ypreds
    thresh -- confidence at least one agreeing model must exceed

    Returns a (classification, probability) tuple:
    - no class predicted by more than one model: ('Strange', 0)
    - agreeing models, at least one above thresh: (class, mean confidence
      of the agreeing models)
    - agreeing models, none above thresh: (class + '?', lowest confidence
      of the agreeing models)
    """
    # Most frequent prediction; it only counts if at least two models agree.
    tally = Counter(ypreds).most_common(1)
    if not tally or tally[0][1] < 2:
        return 'Strange', 0

    final_class = tally[0][0]
    agreeing = [float(prob) for pred, prob in zip(ypreds, probs)
                if pred == final_class]

    if max(agreeing) > thresh:
        return final_class, sum(agreeing) / len(agreeing)
    return final_class + '?', min(agreeing)
| 283 | + |
def checkfits(filename):
    """
    Convert filename with fits2jpg when it ends in '.fits'

    Arguments:
    filename -- path to an image file

    Returns whatever fits2jpg returns for a '.fits' file, or None for any
    other file name (including names without an extension).
    """
    if not filename.endswith('.fits'):
        return None
    return fits2jpg(filename)
| 288 | + |
| 289 | + |
| 290 | +# Clipper function |
def clip(data, lim):
    """
    Zero out, in place, every element of data that is below lim

    Arguments:
    data -- np.ndarray to modify
    lim -- lower limit; elements strictly below it become 0.0

    Returns the same (modified) array object.
    """
    below = data < lim
    data[below] = 0.0
    return data
| 294 | + |
| 295 | +# Convert fits image to png |
def fits2jpg(fname):
    """
    Convert the first HDU of a FITS file into a clipped 150x150 PNG

    NaNs are replaced by 0, the image is flipped vertically, values below
    3 times the sigma-clipped standard deviation are set to 0, and the
    result is resized to 150x150 if needed. The PNG is written to the
    path of fname with its extension replaced by '.png'.

    Arguments:
    fname -- path to a FITS file

    Returns a tuple (img_clip, outfile): the processed np.ndarray and the
    path of the PNG that was written.
    """
    with fits.open(fname) as hdu_list:
        # Copy while the file is still open so the data outlives the handle.
        img = np.copy(np.squeeze(hdu_list[0].data))

    img[np.isnan(img)] = 0
    img_clip = np.flipud(img)

    sigma = 3.0
    # Estimate stats; only the standard deviation is used.
    _, _, std = sigma_clipped_stats(img_clip, sigma=sigma, iters=10)
    # Zero every pixel below sigma * std
    img_clip = clip(img_clip, std * sigma)

    if img_clip.shape[0] != 150 or img_clip.shape[1] != 150:
        img_clip = resize(img_clip, (150, 150))

    outfile = os.path.splitext(fname)[0] + '.png'
    imsave(outfile, img_clip)
    return img_clip, outfile
| 316 | + |
| 317 | + |
| 318 | + |
| 319 | + |
| 320 | +# Do the fusion classification |
def fusion_classify(image_file, thresh=90):
    """
    Classify one image with three binary models and fuse their decisions

    Arguments:
    image_file -- path to a .fits, .jpg or .png image; a '.fits' file is
                  first converted to PNG with fits2jpg

    Keyword arguments:
    thresh -- decision cut-off passed to vote (default 90)

    Returns a tuple (classlabel, probability) as produced by vote, with
    the probability rounded to 2 decimals.
    """
    # Change location of files appropriately (paths are relative to the
    # current working directory).
    models = ['Prototxt/fr1vsfr2.prototxt', 'Prototxt/fr1vsbent.prototxt', 'Prototxt/fr2vsbent.prototxt']
    nets = ['Models/fr1vsfr2.caffemodel', 'Models/fr1vsbent.caffemodel', 'Models/fr2vsbent.caffemodel']
    labels = ['Labels/fr1vsfr2-label.txt', 'Labels/fr1vsbent-label.txt', 'Labels/fr2vsbent-label.txt']

    # endswith also copes with names that have no extension at all.
    if image_file.endswith('.fits'):
        _, image_file = fits2jpg(image_file)

    ypreds = []
    probs = []
    for net, model, label_file in zip(nets, models, labels):
        lbl, conf = classify(net, model, [image_file], labels_file=label_file)
        ypreds.append(lbl)
        probs.append(conf)

    classlabel, probability = vote(ypreds, probs, thresh)
    return classlabel, round(probability, 2)
| 352 | + |
| 353 | + |
| 354 | + |
if __name__ == '__main__':
    # Command-line entry point: classify the image given as first argument.
    script_start_time = time.time()

    if len(sys.argv) < 2:
        sys.exit('Usage: %s <image.fits|image.jpg|image.png>' % sys.argv[0])
    image_file = sys.argv[1]

    # Extract filename without extension and root path
    filename = os.path.splitext(os.path.basename(image_file))[0]

    classlabel, probability = fusion_classify(image_file)

    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('%s is classified as %s with %.2f%% confidence.' % (filename, classlabel, probability))
    print('Script took %s seconds.' % (time.time() - script_start_time,))
| 370 | + |
| 371 | + |
| 372 | + |
| 373 | + |
| 374 | + |
| 375 | + |
0 commit comments