Skip to content

Commit eebde02

Browse files
committed
Code to execute
1 parent 3b233e0 commit eebde02

File tree

1 file changed

+375
-0
lines changed

1 file changed

+375
-0
lines changed

Codes/fusion-classify.py

+375
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
## Fusion Model classification
2+
# Arun Aniyan
3+
# SKA SA/ RATT
4+
# arun@ska.ac.za
5+
# 18-02-17
6+
7+
# Input can be either fits image or jpg/png
8+
9+
# Import necessary stuff
10+
11+
import sys
12+
import os
13+
import time
14+
import datetime
15+
from collections import Counter
16+
17+
import PIL.Image
18+
import numpy as np
19+
import scipy.misc
20+
from google.protobuf import text_format
21+
22+
os.environ['GLOG_minloglevel'] = '2' # Suppress most caffe output
23+
import caffe
24+
from caffe.proto import caffe_pb2
25+
26+
from astropy.io import fits
27+
from astropy.stats import sigma_clipped_stats
28+
from scipy.misc import imsave
29+
from skimage.transform import resize
30+
from skimage.color import rgb2gray
31+
32+
33+
# Function definitions
34+
35+
# Load model
36+
def get_net(caffemodel, deploy_file, use_gpu=True):
37+
"""
38+
Returns an instance of caffe.Net
39+
40+
Arguments:
41+
caffemodel -- path to a .caffemodel file
42+
deploy_file -- path to a .prototxt file
43+
44+
Keyword arguments:
45+
use_gpu -- if True, use the GPU for inference
46+
"""
47+
#if use_gpu:
48+
# caffe.set_mode_gpu()
49+
caffe.set_mode_cpu()
50+
51+
# load a new model
52+
return caffe.Net(deploy_file, caffemodel, caffe.TEST)
53+
54+
# Transformer function to perform image transformation
55+
def get_transformer(deploy_file, mean_file=None):
56+
"""
57+
Returns an instance of caffe.io.Transformer
58+
59+
Arguments:
60+
deploy_file -- path to a .prototxt file
61+
62+
Keyword arguments:
63+
mean_file -- path to a .binaryproto file (optional)
64+
"""
65+
network = caffe_pb2.NetParameter()
66+
with open(deploy_file) as infile:
67+
text_format.Merge(infile.read(), network)
68+
69+
if network.input_shape:
70+
71+
dims = network.input_shape[0].dim
72+
else:
73+
dims = network.input_dim[:4]
74+
75+
76+
#dims = network.input_dim
77+
78+
t = caffe.io.Transformer(
79+
inputs = {'data': dims}
80+
)
81+
t.set_transpose('data', (2,0,1)) # transpose to (channels, height, width)
82+
83+
# color images
84+
if dims[1] == 3:
85+
# channel swap
86+
t.set_channel_swap('data', (2,1,0))
87+
88+
if mean_file:
89+
# set mean pixel
90+
with open(mean_file) as infile:
91+
blob = caffe_pb2.BlobProto()
92+
blob.MergeFromString(infile.read())
93+
if blob.HasField('shape'):
94+
blob_dims = blob.shape
95+
assert len(blob_dims) == 4, 'Shape should have 4 dimensions - shape is "%s"' % blob.shape
96+
elif blob.HasField('num') and blob.HasField('channels') and \
97+
blob.HasField('height') and blob.HasField('width'):
98+
blob_dims = (blob.num, blob.channels, blob.height, blob.width)
99+
else:
100+
raise ValueError('blob does not provide shape or 4d dimensions')
101+
pixel = np.reshape(blob.data, blob_dims[1:]).mean(1).mean(1)
102+
t.set_mean('data', pixel)
103+
104+
return t
105+
106+
# Load image to caffe
107+
def load_image(path, height, width, mode='RGB'):
108+
"""
109+
Load an image from disk
110+
111+
Returns an np.ndarray (channels x width x height)
112+
113+
Arguments:
114+
path -- path to an image on disk
115+
width -- resize dimension
116+
height -- resize dimension
117+
118+
Keyword arguments:
119+
mode -- the PIL mode that the image should be converted to
120+
(RGB for color or L for grayscale)
121+
"""
122+
123+
image = PIL.Image.open(path)
124+
image = image.convert(mode)
125+
image = np.array(image)
126+
# squash
127+
image = scipy.misc.imresize(image, (height, width), 'bilinear')
128+
return image
129+
130+
# Forward pass of input through the network
131+
def forward_pass(images, net, transformer, batch_size=1):
132+
"""
133+
Returns scores for each image as an np.ndarray (nImages x nClasses)
134+
135+
Arguments:
136+
images -- a list of np.ndarrays
137+
net -- a caffe.Net
138+
transformer -- a caffe.io.Transformer
139+
140+
Keyword arguments:
141+
batch_size -- how many images can be processed at once
142+
(a high value may result in out-of-memory errors)
143+
"""
144+
caffe_images = []
145+
for image in images:
146+
if image.ndim == 2:
147+
caffe_images.append(image[:,:,np.newaxis])
148+
else:
149+
caffe_images.append(image)
150+
151+
caffe_images = np.array(caffe_images)
152+
153+
dims = transformer.inputs['data'][1:]
154+
155+
scores = None
156+
for chunk in [caffe_images[x:x+batch_size] for x in xrange(0, len(caffe_images), batch_size)]:
157+
new_shape = (len(chunk),) + tuple(dims)
158+
if net.blobs['data'].data.shape != new_shape:
159+
net.blobs['data'].reshape(*new_shape)
160+
for index, image in enumerate(chunk):
161+
image_data = transformer.preprocess('data', image)
162+
net.blobs['data'].data[index] = image_data
163+
output = net.forward()[net.outputs[-1]]
164+
if scores is None:
165+
scores = output
166+
else:
167+
scores = np.vstack((scores, output))
168+
#print 'Processed %s/%s images ...' % (len(scores), len(caffe_images))
169+
170+
return scores
171+
172+
# Resolve labels
173+
def read_labels(labels_file):
174+
"""
175+
Returns a list of strings
176+
177+
Arguments:
178+
labels_file -- path to a .txt file
179+
"""
180+
if not labels_file:
181+
print 'WARNING: No labels file provided. Results will be difficult to interpret.'
182+
return None
183+
184+
labels = []
185+
with open(labels_file) as infile:
186+
for line in infile:
187+
label = line.strip()
188+
if label:
189+
labels.append(label)
190+
assert len(labels), 'No labels found'
191+
return labels
192+
193+
194+
# Decide class based on threshold
195+
def decide(classification):
196+
lbl = []
197+
conf = []
198+
for label, confidence in classification:
199+
lbl.append(label)
200+
conf.append(confidence)
201+
idx = np.argmax(conf)
202+
203+
return lbl[idx],conf[idx]
204+
205+
# Perform Single classification
206+
def classify(caffemodel, deploy_file, image_files,
207+
mean_file=None, labels_file=None, use_gpu=True):
208+
"""
209+
Classify some images against a Caffe model and print the results
210+
211+
Arguments:
212+
caffemodel -- path to a .caffemodel
213+
deploy_file -- path to a .prototxt
214+
image_files -- list of paths to images
215+
216+
Keyword arguments:
217+
mean_file -- path to a .binaryproto
218+
labels_file path to a .txt file
219+
use_gpu -- if True, run inference on the GPU
220+
"""
221+
# Load the model and images
222+
net = get_net(caffemodel, deploy_file, use_gpu)
223+
transformer = get_transformer(deploy_file, mean_file)
224+
_, channels, height, width = transformer.inputs['data']
225+
if channels == 3:
226+
mode = 'RGB'
227+
elif channels == 1:
228+
mode = 'L'
229+
else:
230+
raise ValueError('Invalid number for channels: %s' % channels)
231+
images = [load_image(image_file, height, width, mode) for image_file in image_files]
232+
labels = read_labels(labels_file)
233+
234+
# Classify the image
235+
classify_start_time = time.time()
236+
scores = forward_pass(images, net, transformer)
237+
#print 'Classification took %s seconds.' % (time.time() - classify_start_time,)
238+
239+
### Process the results
240+
241+
indices = (-scores).argsort()[:, :5] # take top 5 results
242+
classifications = []
243+
for image_index, index_list in enumerate(indices):
244+
result = []
245+
for i in index_list:
246+
# 'i' is a category in labels and also an index into scores
247+
if labels is None:
248+
label = 'Class #%s' % i
249+
else:
250+
label = labels[i]
251+
result.append((label, round(100.0*scores[image_index, i],4)))
252+
classifications.append(result)
253+
254+
for index, classification in enumerate(classifications):
255+
256+
#print '{:-^80}'.format(' Prediction for %s ' % image_files[index])
257+
258+
lbl, conf = decide(classification)
259+
260+
return lbl,conf
261+
262+
263+
# Fusion decision model
264+
def vote(ypreds,probs,thresh):
265+
# Find the repeating class among the three models
266+
high_vote = [item for item, count in Counter(ypreds).items() if count >1]
267+
268+
if high_vote !=[]:
269+
final_class = high_vote[0]
270+
# Check if their probability is greater than 60
271+
idx = np.where(np.array(ypreds)==final_class)[0]
272+
if float(probs[idx[0]]) > thresh or float(probs[idx[1]]) > thresh:
273+
final_classification = final_class
274+
final_probability = (float(probs[idx[0]])+float(probs[idx[1]]))/2.0
275+
else:
276+
final_classification = final_class +'?'
277+
final_probability = min(float(probs[idx[0]]),float(probs[idx[1]]))
278+
else:
279+
final_classification = 'Strange'
280+
final_probability = 0
281+
282+
return final_classification,final_probability
283+
284+
def checkfits(filename):
285+
if filename.rsplit('.',1)[1] == 'fits':
286+
image = fits2jpg(filename)
287+
return image
288+
289+
290+
# Clipper function
291+
def clip(data,lim):
292+
data[data<lim] = 0.0
293+
return data
294+
295+
# Convert fits image to png
296+
def fits2jpg(fname):
297+
hdu_list = fits.open(fname)
298+
image = hdu_list[0].data
299+
image = np.squeeze(image)
300+
img = np.copy(image)
301+
idx = np.isnan(img)
302+
img[idx] = 0
303+
img_clip = np.flipud(img)
304+
sigma = 3.0
305+
# Estimate stats
306+
mean, median, std = sigma_clipped_stats(img_clip, sigma=sigma, iters=10)
307+
# Clip off n sigma points
308+
img_clip = clip(img_clip,std*sigma)
309+
if img_clip.shape[0] !=150 or img_clip.shape[1] !=150:
310+
img_clip = resize(img_clip, (150,150))
311+
#img_clip = rgb2gray(img_clip)
312+
313+
outfile = fname[0:-5] +'.png'
314+
imsave(outfile, img_clip)
315+
return img_clip,outfile
316+
317+
318+
319+
320+
# Do the fusion classification
321+
def fusion_classify(image_file):
322+
323+
# Change location of files appropriately
324+
models = ['Prototxt/fr1vsfr2.prototxt','Prototxt/fr1vsbent.prototxt','Prototxt/fr2vsbent.prototxt']
325+
nets = ['Models/fr1vsfr2.caffemodel','Models/fr1vsbent.caffemodel','Models/fr2vsbent.caffemodel']
326+
labels = ['Labels/fr1vsfr2-label.txt','Labels/fr1vsbent-label.txt','Labels/fr2vsbent-label.txt']
327+
328+
thresh = 90 # Decision cut of to make the final classification
329+
330+
ypreds = []
331+
probs = []
332+
333+
334+
if image_file.rsplit('.',1)[1] == 'fits':
335+
image, outfile = fits2jpg(image_file)
336+
image_file = outfile
337+
338+
339+
for i in range(3):
340+
341+
342+
lbl, conf = classify(nets[i], models[i],[image_file],labels_file=labels[i])
343+
344+
ypreds.append(lbl)
345+
probs.append(conf)
346+
347+
classlabel,probability = vote(ypreds, probs, thresh)
348+
349+
probability = round(probability,2)
350+
351+
return classlabel,probability
352+
353+
354+
355+
if __name__ == '__main__':
356+
357+
script_start_time = time.time()
358+
359+
arg = sys.argv
360+
image_file = arg[1]
361+
362+
# Extract filename without extension and root path
363+
filename = os.path.basename(image_file)
364+
filename = os.path.splitext(filename)[0]
365+
366+
classlabel, probability = fusion_classify(image_file)
367+
368+
print '%s is classified as %s with %.2f%% confidence.' %(filename,classlabel,probability)
369+
print 'Script took %s seconds.' % (time.time() - script_start_time,)
370+
371+
372+
373+
374+
375+

0 commit comments

Comments
 (0)