-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_cell.py
108 lines (94 loc) · 4.18 KB
/
test_cell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import print_function
import sys
import os
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.autograd import Variable
from data import CELL_ROOT, CELL_CLASS as labelmap
# from PIL import Image
from data import CELLAnnotationTransform, CELLDetection, BaseTransform, CELL_CLASS, cell
import torch.utils.data as data
from ssd import build_ssd
from utils.draw_boundingbox import draw_bound
import cv2
def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so ``--cuda False`` would
    silently enable CUDA.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    # The empty string is kept as False, matching the old bool('') behaviour.
    if lowered in ('no', 'false', 'f', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


# Command-line options of the evaluation script.
parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection')
parser.add_argument('--trained_model', default='weights/LargeData512/conv2_2.pth',
                    type=str, help='Trained state_dict file path to open')
parser.add_argument('--save_folder', default='eval/', type=str,
                    help='Dir to save results')
parser.add_argument('--visual_threshold', default=0.4, type=float,
                    help='Final confidence threshold')
parser.add_argument('--cuda', default=False, type=_str2bool,
                    help='Use cuda to train model')
parser.add_argument('--cell_root', default=CELL_ROOT, help='Location of CELL root directory')
parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks")
args = parser.parse_args()

# Newly created tensors live on the GPU only when CUDA was requested AND is
# actually available; otherwise everything stays on the CPU.
if args.cuda and torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# makedirs (rather than mkdir) also creates missing parent directories, so a
# nested --save_folder such as 'out/run1/' works.
if not os.path.exists(args.save_folder):
    os.makedirs(args.save_folder)
def test_net(save_folder, net, cuda, testset, transform, thresh):
    """Run ``net`` on every image of ``testset`` and dump the results.

    For each image, the ground-truth annotation and every detection whose
    score is at least ``thresh`` are written to ``<save_folder>/test1.txt``
    (the file is recreated on each call). Each kept detection is passed to
    ``draw_bound`` and the image is then saved as
    ``<save_folder>/<img_id>_anno.jpg``.

    Args:
        save_folder: Directory receiving the text report and the images.
        net: Callable model. ``net(x).data`` is indexed as
            ``[0, class, slot, 0]`` for the score and
            ``[0, class, slot, 1:]`` for the box.
        cuda: When truthy, the input batch is moved to the GPU.
        testset: Dataset providing ``len()``, ``pull_image(i)`` and
            ``pull_anno(i)`` (the latter returning ``(img_id, annotation)``).
        transform: Callable; element ``[0]`` of its result is the
            preprocessed image as a NumPy array.
        thresh: Minimum score a detection needs to be reported.
    """
    # dump predictions and assoc. ground truth to text file for now
    report_path = os.path.join(save_folder, 'test1.txt')
    num_images = len(testset)

    # Opening once in 'w' mode both erases old contents and avoids reopening
    # the file for every single line.
    with open(report_path, 'w') as report:
        for img_index in range(num_images):
            print('Testing image {:d}/{:d}....'.format(img_index + 1, num_images))
            img = testset.pull_image(img_index)
            img_id, annotation = testset.pull_anno(img_index)

            # Move the last axis first (presumably HWC -> CHW) and add a
            # batch dimension of size 1.
            x = torch.from_numpy(transform(img)[0]).permute(2, 0, 1)
            x = x.unsqueeze(0)

            report.write('\nGROUND TRUTH FOR: ' + img_id + '\n')
            for box in annotation:
                report.write('label: ' + ' || '.join(str(b) for b in box) + '\n')

            if cuda:
                x = x.cuda()

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                y = net(x)  # forward pass
            detections = y.data

            # scale each detection back up to the image
            # (assumes img is an H x W x C array, so this is [w, h, w, h])
            scale = torch.Tensor([img.shape[1], img.shape[0],
                                  img.shape[1], img.shape[0]])

            num_classes = detections.size(1)
            slots_per_class = detections.size(2)
            pred_num = 0
            # Class index 0 is skipped: labels are looked up as
            # labelmap[class_idx - 1], so index 0 has no label of its own
            # (it would wrap around to labelmap[-1]).
            for class_idx in range(1, num_classes):
                slot = 0
                # Stops at the first score below thresh, i.e. relies on the
                # scores of a class being sorted in descending order. The
                # explicit bound prevents an IndexError when every slot of a
                # class passes the threshold.
                while (slot < slots_per_class
                       and detections[0, class_idx, slot, 0] >= thresh):
                    if pred_num == 0:
                        report.write('PREDICTIONS: ' + '\n')
                    score = detections[0, class_idx, slot, 0]
                    label_name = labelmap[class_idx - 1]
                    pt = (detections[0, class_idx, slot, 1:] * scale).cpu().numpy()
                    coords = (pt[0], pt[1], pt[2], pt[3])
                    pred_num += 1
                    report.write(str(pred_num) + ' label: ' + label_name + ' score: ' +
                                 str(score) + ' ' + ' || '.join(str(c) for c in coords) + '\n')
                    slot += 1
                    # presumably draws the box on img in place -- confirm in
                    # utils.draw_boundingbox
                    draw_bound(img, coords, label_name, score)

            # Use the save_folder parameter (not the global args) so the
            # images always land next to the text report.
            cv2.imwrite(os.path.join(save_folder, img_id + '_anno.jpg'), img)
# Size argument handed to build_ssd() in test_cell(); presumably the network
# input resolution (the default weights path mentions 'LargeData512') -- confirm.
ssd_size = 512
def test_cell():
    """Build the SSD, load the trained weights and evaluate the 'test' image set.

    Reads its configuration from the module-level ``args`` (model path, CELL
    root, save folder, threshold, CUDA flag) and delegates the per-image work
    to ``test_net``.
    """
    # Same condition as the module-level tensor-type selection: fall back to
    # the CPU instead of crashing in net.cuda() when CUDA was requested but
    # is not available.
    use_cuda = bool(args.cuda and torch.cuda.is_available())

    # load net
    num_classes = cell['num_classes']  # +1 background
    net = build_ssd('test', ssd_size, num_classes)  # initialize SSD
    # map_location='cpu' lets a checkpoint saved from a GPU be loaded on a
    # CPU-only machine; load_state_dict then copies the values into the
    # model's own parameters, wherever those live.
    state_dict = torch.load(args.trained_model, map_location='cpu')
    net.load_state_dict(state_dict)
    net.eval()
    print('Finished loading model!')

    # load data
    testset = CELLDetection(args.cell_root, 'test', None, CELLAnnotationTransform())

    if use_cuda:
        net = net.cuda()
        cudnn.benchmark = True

    # evaluation
    test_net(args.save_folder, net, use_cuda, testset,
             BaseTransform(net.size, (104, 117, 123)),
             thresh=args.visual_threshold)
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    test_cell()