-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
51 lines (45 loc) · 1.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from src.imports import *
from src.navqt import NAVQT
def _read_cli_args():
    """Return the ``|``-separated tokens of the last command-line argument.

    Only real arguments (``sys.argv[1:]``) are considered: when the script is
    run without arguments an empty list is returned instead of treating the
    script path itself as a setting. Empty tokens (e.g. from ``"a=1||b=2"``)
    are dropped.
    """
    if len(sys.argv) < 2:
        return []
    return [token for token in sys.argv[-1].split("|") if token]


def _coerce(raw, default):
    """Convert the string *raw* to the exact type of *default*.

    Supported types are ``bool`` (case-insensitive ``"true"`` -> ``True``,
    anything else -> ``False``), ``int``, ``float`` and ``str``.

    Raises:
        TypeError: if ``type(default)`` is not one of the supported types.
        ValueError: if *raw* cannot be parsed as ``int`` / ``float``.
    """
    # Exact type comparison on purpose: bool is a subclass of int and must
    # not be parsed with int().
    kind = type(default)
    if kind is bool:
        return raw.lower() == "true"
    if kind in (int, float, str):
        return kind(raw)
    raise TypeError(f"unsupported setting type: {kind.__name__}")


def _history_stem(qc):
    """Return the path prefix (without extension) of *qc*'s history files."""
    return qc.savepth + "history---" + qc.settings


def run():
    """Build a NAVQT model from command-line settings, then train and plot it.

    The last command-line argument is read as ``key=value`` pairs separated
    by ``|`` (e.g. ``"max_iter=100|savepth=./out/"``). Each key must name an
    attribute of a default-constructed ``NAVQT``; the value is converted to
    that attribute's current type (bool/int/float/str) and passed to the
    ``NAVQT`` constructor as a keyword argument. ``savepth`` defaults to
    ``"./results/"``. Malformed, unknown or unsupported settings are reported
    and skipped rather than forwarded to the constructor.

    If ``<savepth>history---<settings>.pdf`` already exists the run is
    skipped; otherwise the model is trained (``n_epochs=qc.max_iter``, early
    stopping and gradient norms enabled) and ``plot_history(save=True)`` is
    called.
    """
    args = _read_cli_args()
    print("ARGUMENTS")
    print(args)
    print("---------")

    # A default instance is only used to look up each setting's default type.
    defaults = NAVQT()
    kwargs = {"savepth": "./results/"}
    for arg in args:
        # partition keeps any further "=" inside the value intact.
        var, sep, raw = arg.partition("=")
        var = var.strip()
        if not sep:
            print("Trouble with " + arg)
            continue
        try:
            current = getattr(defaults, var)
        except AttributeError:
            print("Trouble with " + arg)
            continue
        try:
            val = _coerce(raw, current)
        except TypeError:
            # Attribute exists but its type cannot be parsed from a string
            # (e.g. None, containers, methods): keep the default instead of
            # overriding it with None.
            print("COULD NOT FIND VARIABLE:", var)
            continue
        except ValueError:
            print("Trouble with " + arg)
            continue
        kwargs[var] = val
        print(var, ":", val)

    qc = NAVQT(**kwargs)
    if os.path.isfile(_history_stem(qc) + ".pdf"):
        print("File exists!")
        return
    print(qc)
    qc.train(n_epochs=qc.max_iter, early_stop=True, grad_norm=True)
    qc.plot_history(save=True)
    # Recomputed after training in case ``settings`` depends on model state.
    print(
        "Succesfully saved file(s) to:",
        _history_stem(qc) + ".*",
    )
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    run()