6
6
import machin .frame .algorithms as algorithms
7
7
8
8
9
- def fill_default (default : Union [Dict [str , Any ], Config ],
10
- config : Union [Dict [str , Any ], Config ]):
9
+ def fill_default (
10
+ default : Union [Dict [str , Any ], Config ], config : Union [Dict [str , Any ], Config ]
11
+ ):
11
12
for key in default :
12
13
if key not in config :
13
14
config [key ] = default [key ]
@@ -18,46 +19,56 @@ def _get_available_algorithms():
18
19
algos = []
19
20
for algo in dir (algorithms ):
20
21
algo_cls = getattr (algorithms , algo )
21
- if (inspect .isclass (algo_cls )
22
- and issubclass (algo_cls , TorchFramework )
23
- and algo_cls != TorchFramework ):
22
+ if (
23
+ inspect .isclass (algo_cls )
24
+ and issubclass (algo_cls , TorchFramework )
25
+ and algo_cls != TorchFramework
26
+ ):
24
27
algos .append (algo )
25
28
return algos
26
29
27
30
28
- def generate_algorithm_config (algorithm : str ,
29
- config : Union [Dict [str , Any ], Config ] = None ):
31
+ def generate_algorithm_config (
32
+ algorithm : str , config : Union [Dict [str , Any ], Config ] = None
33
+ ):
30
34
config = deepcopy (config ) or {}
31
35
if hasattr (algorithms , algorithm ):
32
36
algo_obj = getattr (algorithms , algorithm )
33
37
if issubclass (algo_obj , TorchFramework ):
34
38
return algo_obj .generate_config (config )
35
- raise ValueError ("Invalid algorithm: {}, valid ones are: {}"
36
- .format (algorithm , _get_available_algorithms ()))
39
+ raise ValueError (
40
+ "Invalid algorithm: {}, valid ones are: {}" .format (
41
+ algorithm , _get_available_algorithms ()
42
+ )
43
+ )
37
44
38
45
39
46
def init_algorithm_from_config (config : Union [Dict [str , Any ], Config ]):
40
47
assert_config_complete (config )
41
48
frame = getattr (algorithms , config ["frame" ], None )
42
49
if not inspect .isclass (frame ) or not issubclass (frame , TorchFramework ):
43
- raise ValueError ("Invalid algorithm: {}, valid ones are: {}"
44
- .format (config ["frame" ], _get_available_algorithms ()))
50
+ raise ValueError (
51
+ "Invalid algorithm: {}, valid ones are: {}" .format (
52
+ config ["frame" ], _get_available_algorithms ()
53
+ )
54
+ )
45
55
return frame .init_from_config (config )
46
56
47
57
48
58
def is_algorithm_distributed (config : Union [Dict [str , Any ], Config ]):
49
59
assert_config_complete (config )
50
60
frame = getattr (algorithms , config ["frame" ], None )
51
61
if not inspect .isclass (frame ) or not issubclass (frame , TorchFramework ):
52
- raise ValueError ("Invalid algorithm: {}, valid ones are: {}"
53
- .format (config ["frame" ], _get_available_algorithms ()))
62
+ raise ValueError (
63
+ "Invalid algorithm: {}, valid ones are: {}" .format (
64
+ config ["frame" ], _get_available_algorithms ()
65
+ )
66
+ )
54
67
return frame .is_distributed ()
55
68
56
69
57
70
def assert_config_complete (config : Union [Dict [str , Any ], Config ]):
58
71
assert "frame" in config , 'Missing key "frame" in config.'
59
72
assert "frame_config" in config , 'Missing key "frame_config" in config.'
60
- assert "train_env_config" in config , 'Missing key "train_env_config" ' \
61
- 'in config.'
62
- assert "test_env_config" in config , 'Missing key "test_env_config" ' \
63
- 'in config.'
73
+ assert "train_env_config" in config , 'Missing key "train_env_config" ' "in config."
74
+ assert "test_env_config" in config , 'Missing key "test_env_config" ' "in config."
0 commit comments