-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
105 lines (91 loc) · 4.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import glob
import os
import random
import warnings

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from detector_cleanse import detector_cleanse
from model import FasterRCNNVGG16
from transform import preprocess
# Silence every Python warning for the whole process (action "ignore",
# no message/category/module restriction).
warnings.simplefilter("ignore")
def parse_arguments(argv=None):
    """Parse the command-line options for running Detector Cleanse.

    Args:
        argv: Optional sequence of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the normal
            command-line behaviour; passing a list enables tests and
            programmatic use.

    Returns:
        argparse.Namespace with attributes ``n``, ``m``, ``delta``, ``alpha``,
        ``iouthresh``, ``image_path``, ``clean_feature_path`` and ``weight``.
    """
    parser = argparse.ArgumentParser(description='Run Detector Cleanse on an image')
    parser.add_argument('--n', type=int, required=True, help='Number of features to randomly select')
    parser.add_argument('--m', type=float, required=True, help='Detection mean')
    parser.add_argument('--delta', type=float, required=True, help='Detection threshold')
    parser.add_argument('--alpha', type=float, default=0.5, help='Blending ratio')
    parser.add_argument('--iouthresh', type=float, default=0.5, help='Threshold iou')
    parser.add_argument('--image_path', type=str, default='images', help='Path to the image(s) to be analyzed')
    parser.add_argument('--clean_feature_path', type=str, default='clean_feature_images', help='Path to the clean_feature image folder')
    parser.add_argument('--weight', type=str, required=True, help='Path to weight of the model')
    # parse_args(None) falls back to sys.argv[1:], preserving CLI behaviour.
    return parser.parse_args(argv)
def _load_image(path):
    """Open the image at ``path`` and return it preprocessed for the detector.

    The image is converted to RGB, turned into a float32 array, transposed
    with axes (2, 0, 1) (HWC -> CHW) and passed through ``preprocess``.
    """
    # The context manager closes the file handle deterministically instead of
    # leaving it to Pillow / the garbage collector.
    with Image.open(path) as image:
        array = np.asarray(image.convert('RGB'), dtype=np.float32)
    return preprocess(array.transpose((2, 0, 1)))


def _load_clean_features(folder, n):
    """Randomly select ``n`` ``.jpg`` images from ``folder`` and preprocess them.

    Raises:
        ValueError: if ``n`` is negative or exceeds the number of ``.jpg``
            files found in ``folder``.
    """
    # Sorting makes the candidate order filesystem-independent, so a seeded
    # ``random`` module yields reproducible selections.
    candidates = sorted(glob.glob(f'{folder}/*.jpg'))
    if not 0 <= n <= len(candidates):
        raise ValueError(
            f"--n must be between 0 and {len(candidates)} "
            f"(number of .jpg files in '{folder}'), got {n}"
        )
    return [_load_image(path) for path in random.sample(candidates, n)]


def _load_model(weight_path):
    """Build ``FasterRCNNVGG16(n_fg_class=20)`` and load weights from ``weight_path``.

    Accepts either a checkpoint dict holding the weights under ``'model'`` or
    a bare state dict (legacy format, kept for backward compatibility).
    """
    model = FasterRCNNVGG16(n_fg_class=20)
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict then copies the tensors onto the model's own parameters.
    # NOTE(review): torch.load unpickles data -- only load trusted weight files.
    state_dict = torch.load(weight_path, map_location='cpu')
    if 'model' in state_dict:
        state_dict = state_dict['model']
    model.load_state_dict(state_dict)
    # Inference mode (disables training-only behaviour such as dropout).
    # NOTE(review): detector_cleanse may also set the mode itself -- confirm.
    model.eval()
    return model


def _evaluate_directory(args, model, clean_features):
    """Run Detector Cleanse on every ``.jpg`` in ``args.image_path`` and report metrics.

    A file whose path contains ``"modified"`` is counted as poisoned ground
    truth, every other file as clean. FAR = poisoned files judged clean /
    poisoned files; FRR = clean files judged poisoned / clean files.
    Prints, in order: clean total, poisoned total, successes,
    false accepts, false rejects.
    """
    image_files = sorted(glob.glob(f'{args.image_path}/*.jpg'))
    total_clean, total_poison = 0, 0
    false_accept, false_reject, success = 0, 0, 0
    pbar = tqdm(image_files)
    for image_file in pbar:
        img = _load_image(image_file)
        poisoned, _ = detector_cleanse(img, model, clean_features, args.m, args.delta, args.alpha, args.iouthresh)
        # NOTE(review): the label check inspects the full path, directory included.
        if "modified" in image_file:
            total_poison += 1
            if poisoned:
                success += 1
            else:
                false_accept += 1
        else:
            total_clean += 1
            if poisoned:
                false_reject += 1
            else:
                success += 1
        far = false_accept / total_poison if total_poison != 0 else 0
        frr = false_reject / total_clean if total_clean != 0 else 0
        # At least one total was just incremented, so the sum is never zero.
        pbar.set_description(f"accuracy {success/(total_clean + total_poison)} FAR {far},{total_poison} FRR {frr},{total_clean}")
    print(total_clean)
    print(total_poison)
    print(success)
    print(false_accept)
    print(false_reject)


def _analyze_single_image(args, model, clean_features):
    """Run Detector Cleanse on the single image file ``args.image_path`` and print the verdict."""
    img = _load_image(args.image_path)
    poisoned, coordinates = detector_cleanse(img, model, clean_features, args.m, args.delta, args.alpha, args.iouthresh)
    if poisoned:
        print("\nImage is poisoned")
        print(f"Coordinate : {coordinates}")
    else:
        print("\nImage is clean")


def main():
    """CLI entry point: load clean features and the model, then analyse images.

    If ``--image_path`` is a directory, every ``.jpg`` inside it is evaluated
    and aggregate metrics are printed; otherwise it is treated as a single
    image file and a poisoned/clean verdict is printed.
    """
    args = parse_arguments()
    print("Loading clean feature files...")
    clean_features = _load_clean_features(args.clean_feature_path, args.n)
    print("Complete")
    print("Loading model...")
    model = _load_model(args.weight)
    print("Complete")
    print("Detecting")
    # A real directory check replaces the old `'jpg' not in path` substring
    # test, which misrouted .png/.jpeg files and directories named "*jpg*".
    if os.path.isdir(args.image_path):
        _evaluate_directory(args, model, clean_features)
    else:
        _analyze_single_image(args, model, clean_features)
if __name__ == "__main__":
    # Launch the CLI only when run as a script, not when imported as a module.
    main()