-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
90 lines (77 loc) · 2.19 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import torch
from transformers import CLIPModel, CLIPProcessor
from utils.tools import eval_all
import os
from peft import PeftModel, PeftConfig, get_peft_model
from utils.lora_moe import LoraMoEConfig
import random
import numpy as np
# Turn off HuggingFace tokenizers' internal parallelism for this process
# (presumably to avoid its fork warning once data-loader workers spawn).
os.environ.update(TOKENIZERS_PARALLELISM="false")
# NOTE(review): main() seeds the RNGs with its own value and never reads
# this module-level name.
seed = 42
def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)


def main(args):
    """Load CLIP plus a PEFT adapter checkpoint and run ``eval_all`` on it.

    Args:
        args: Namespace as produced by ``get_args()``. Reads ``args.resume``
            (adapter path passed to ``PeftModel.from_pretrained``),
            ``args.task_id`` (passed to every LoRA-MoE layer's
            ``set_task_id``) and ``args.dataset`` (forwarded to
            ``eval_all`` as ``datasets``).
    """
    # Prompt templates handed to eval_all; "{}" is presumably filled with
    # each class name there (eval_all lives in utils.tools).
    template = [
        "itap of a {}.",
        "a bad photo of the {}.",
        "a origami {}.",
        "a photo of the large {}.",
        "a {} in a video game.",
        "art of the {}.",
        "a photo of the small {}.",
    ]

    # This evaluation deliberately uses seed 1; it does not read the
    # module-level ``seed`` name.
    _seed_everything(1)

    model_name = "openai/clip-vit-base-patch16"
    model = CLIPModel.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ).eval()
    processor = CLIPProcessor.from_pretrained(model_name)

    # Cast again after attaching the adapter so its weights match the base
    # model's bfloat16 dtype (presumably they load in another dtype — TODO
    # confirm).
    model = PeftModel.from_pretrained(model, args.resume).to(torch.bfloat16)
    print(model)

    # Every submodule exposing a ``loramoe_router`` is first reset with
    # close_task() and then pinned to the requested task id.
    for module in model.modules():
        if hasattr(module, "loramoe_router"):
            module.close_task()
            module.set_task_id(args.task_id)

    # Eval
    eval_all(
        model=model,
        processor=processor,
        datasets=args.dataset,
        use_cache=False,
        template=template,
    )
def get_args(argv=None):
    """Parse command-line options for the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace whose ``dataset`` attribute is a list of dataset
        names: the ``--dataset`` value split on commas, with surrounding
        whitespace stripped and empty entries dropped.

    Raises:
        SystemExit: via argparse when ``--resume`` is missing or when
            ``--dataset`` names no dataset at all.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        default="imagenet_a",
        help="Dataset name or list of dataset names",
    )
    # NOTE(review): not read by main() in this file.
    parser.add_argument(
        "--cache_config",
        default="./configs",
        help="Path to cache configuration file",
    )
    parser.add_argument(
        "--task_id",
        default=None,
        type=int,
        help="Task ID for the lora",
    )
    # main() hands this straight to PeftModel.from_pretrained, which needs a
    # real checkpoint path; fail fast with a usage error instead of a late
    # crash on None.
    parser.add_argument(
        "--resume",
        type=str,
        required=True,
        help="Path to the pretrained model",
    )
    args = parser.parse_args(argv)

    # Tolerate "a, b" and trailing commas when listing several datasets.
    names = (name.strip() for name in args.dataset.split(","))
    args.dataset = [name for name in names if name]
    if not args.dataset:
        parser.error("--dataset must name at least one dataset")
    return args
# Script entry point: parse the command line, then run the evaluation.
if __name__ == "__main__":
    cli_args = get_args()
    main(cli_args)