Skip to content

Commit f2036a3

Browse files
committed
Finish adding auto config generator
1 parent 8197d5a commit f2036a3

15 files changed

+212
-78
lines changed

machin/__main__.py

Whitespace-only changes.

machin/auto/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
from . import envs
12
from . import config
23
from . import dataset
34
from . import launcher
45
from . import pl_logger
56
from . import pl_plugin
67

7-
__all__ = ["config", "dataset", "launcher", "pl_logger", "pl_plugin"]
8+
# Public submodules of this package. Every entry has to be the name of a
# module imported above, otherwise `from machin.auto import *` fails with an
# AttributeError; the environments package is imported as `envs`, not `env`.
__all__ = ["envs", "config", "dataset", "launcher", "pl_logger", "pl_plugin"]

machin/auto/__main__.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import json
2+
import argparse
3+
from pprint import pprint
4+
from machin.auto.config import (
5+
get_available_algorithms,
6+
get_available_environments,
7+
generate_algorithm_config,
8+
generate_env_config,
9+
generate_training_config,
10+
)
11+
12+
if __name__ == "__main__":
    # Command line entry point with two sub-commands:
    #   list      print the names of available algorithms or environments
    #   generate  build a config for an algorithm/environment pair and save
    #             it as JSON
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")

    p_list = subparsers.add_parser(
        "list", help="List available algorithms or environments."
    )
    p_list.add_argument(
        "--algo", action="store_true", help="List available algorithms.",
    )
    p_list.add_argument(
        "--env", action="store_true", help="List available environments."
    )

    p_generate = subparsers.add_parser("generate", help="Generate configuration.")
    p_generate.add_argument(
        "--algo", type=str, required=True, help="Algorithm name to use."
    )
    p_generate.add_argument(
        "--env", type=str, required=True, help="Environment name to use."
    )
    p_generate.add_argument(
        "--print", action="store_true", help="Direct config output to screen."
    )
    p_generate.add_argument(
        "--output",
        type=str,
        default="config.json",
        help="JSON config file output path.",
    )

    args = parser.parse_args()
    if args.command == "list":
        # --env takes precedence when both flags are given.
        if args.env:
            print("Available environments are:")
            for env in get_available_environments():
                print(env)
        elif args.algo:
            print("Available algorithms are:")
            for algo in get_available_algorithms():
                print(algo)
        else:
            print("You can list --algo or --env.")

    elif args.command == "generate":
        # Raising SystemExit with a message prints it to stderr and exits with
        # status 1, so an invalid name is reported as a failure instead of
        # the success status that exit(0) signalled.
        if args.algo not in get_available_algorithms():
            raise SystemExit(
                f"{args.algo} is not a valid algorithm, use list "
                "--algo to get a list of available algorithms."
            )
        if args.env not in get_available_environments():
            raise SystemExit(
                f"{args.env} is not a valid environment, use list "
                "--env to get a list of available environments."
            )

        # Each generator receives the config built so far and returns the
        # extended one: environment first, then algorithm, then training.
        config = {}
        config = generate_env_config(args.env, config=config)
        config = generate_algorithm_config(args.algo, config=config)
        config = generate_training_config(config=config)

        if args.print:
            pprint(config)

        # The config is written to the output file whether or not it was
        # also printed.
        with open(args.output, "w") as f:
            json.dump(config, f, indent=4, sort_keys=True)
        print(f"Config saved to {args.output}")

    else:
        # No sub-command was given: show usage instead of exiting silently.
        parser.print_help()

machin/auto/config.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Dict, Any, Union
33
from machin.frame.algorithms import TorchFramework
44
from machin.utils.conf import Config
5+
from . import envs
56
import inspect
67
import torch as t
78
import machin.frame.algorithms as algorithms
@@ -16,7 +17,7 @@ def fill_default(
1617
return config
1718

1819

19-
def _get_available_algorithms():
20+
def get_available_algorithms():
2021
algos = []
2122
for algo in dir(algorithms):
2223
algo_cls = getattr(algorithms, algo)
@@ -29,8 +30,17 @@ def _get_available_algorithms():
2930
return algos
3031

3132

33+
def get_available_environments():
    """
    Collect the names of usable environment modules.

    Returns:
        A list with every name from ``dir(envs)`` whose attribute provides
        both ``launch`` and ``generate_env_config``.
    """

    def _is_environment(candidate):
        # An environment must be launchable and able to describe its config.
        return hasattr(candidate, "launch") and hasattr(
            candidate, "generate_env_config"
        )

    return [name for name in dir(envs) if _is_environment(getattr(envs, name))]
40+
41+
3242
def generate_training_config(
33-
root_dir: str = "./trial",
43+
root_dir: str = "trial",
3444
episode_per_epoch: int = 10,
3545
max_episodes: int = 10000,
3646
config: Union[Dict[str, Any], Config] = None,
@@ -56,10 +66,23 @@ def generate_algorithm_config(
5666
config["gpus"] = [0, 0, 0]
5767
config["num_processes"] = 3
5868
config["num_nodes"] = 1
59-
config["batch_num"] = {"sampler": 10, "learner": 1}
69+
else:
70+
config["gpus"] = [0]
6071
return config
6172
raise ValueError(
62-
f"Invalid algorithm: {algorithm}, valid ones are: {_get_available_algorithms()}"
73+
f"Invalid algorithm: {algorithm}, valid ones are: {get_available_algorithms()}"
74+
)
75+
76+
77+
def generate_env_config(environment: str, config: Union[Dict[str, Any], Config] = None):
    """
    Generate an example config for the given environment.

    Args:
        environment: Name of an environment module available under ``envs``.
        config: Existing config to extend. It is deep-copied first, so the
            caller's object is not modified by this function.

    Returns:
        The value returned by the environment module's own
        ``generate_env_config`` for the copied config.

    Raises:
        ValueError: If ``environment`` is not an attribute of ``envs`` that
            provides both ``launch`` and ``generate_env_config``.
    """
    config = deepcopy(config) or {}
    e_module = getattr(envs, environment, None)
    if hasattr(e_module, "launch") and hasattr(e_module, "generate_env_config"):
        # Pass the config by keyword: an environment's generator may take
        # other leading parameters (the gym one takes ``env_name`` first), and
        # a positional call would bind the config to that parameter instead.
        return e_module.generate_env_config(config=config)
    raise ValueError(
        f"Invalid environment: {environment}, "
        f"valid ones are: {get_available_environments()}"
    )
6487

6588

@@ -71,7 +94,7 @@ def init_algorithm_from_config(
7194
if not inspect.isclass(frame) or not issubclass(frame, TorchFramework):
7295
raise ValueError(
7396
f"Invalid algorithm: {config['frame']}, "
74-
f"valid ones are: {_get_available_algorithms()}"
97+
f"valid ones are: {get_available_algorithms()}"
7598
)
7699
return frame.init_from_config(config, model_device=model_device)
77100

@@ -82,7 +105,7 @@ def is_algorithm_distributed(config: Union[Dict[str, Any], Config]):
82105
if not inspect.isclass(frame) or not issubclass(frame, TorchFramework):
83106
raise ValueError(
84107
f"Invalid algorithm: {config['frame']}, "
85-
f"valid ones are: {_get_available_algorithms()}"
108+
f"valid ones are: {get_available_algorithms()}"
86109
)
87110
return frame.is_distributed()
88111

machin/auto/envs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from . import openai_gym
2+
3+
__all__ = ["openai_gym"]

machin/auto/env/openai_gym.py renamed to machin/auto/envs/openai_gym.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -244,34 +244,6 @@ def __next__(self):
244244
return result
245245

246246

247-
def generate_gym_env_config(
248-
env_name: str = None, config: Union[Dict[str, Any], Config] = None
249-
):
250-
"""
251-
Generate example OpenAI gym config.
252-
"""
253-
config = deepcopy(config) or {}
254-
return fill_default(
255-
{
256-
"trials_dir": "trials",
257-
"gpus": 0,
258-
"episode_per_epoch": 100,
259-
"max_episodes": 1000000,
260-
"train_env_config": {
261-
"env_name": env_name or "CartPole-v1",
262-
"render_every_episode": 100,
263-
"act_kwargs": {},
264-
},
265-
"test_env_config": {
266-
"env_name": env_name or "CartPole-v1",
267-
"render_every_episode": 100,
268-
"act_kwargs": {},
269-
},
270-
},
271-
config,
272-
)
273-
274-
275247
def gym_env_dataset_creator(frame, env_config):
276248
env = gym.make(env_config["env_name"])
277249
if _is_discrete_space(env.action_space):
@@ -295,9 +267,31 @@ def gym_env_dataset_creator(frame, env_config):
295267
)
296268

297269

298-
def launch_gym(
299-
config: Union[Dict[str, Any], Config], pl_callbacks: List[Callback] = None
270+
def generate_env_config(
    env_name: str = None, config: Union[Dict[str, Any], Config] = None
):
    """
    Build an example OpenAI gym environment config.

    Args:
        env_name: Gym environment name used for both the training and the
            testing section; "CartPole-v1" is used when it is not given.
        config: Config passed to ``fill_default`` together with the default
            values. It is deep-copied first.

    Returns:
        The result of ``fill_default`` applied to the defaults and the
        copied config.
    """
    config = deepcopy(config) or {}
    name = env_name or "CartPole-v1"
    # Train and test sections carry the same defaults but must be separate
    # dict objects, so each one is built on its own.
    defaults = {
        section: {
            "env_name": name,
            "render_every_episode": 100,
            "act_kwargs": {},
        }
        for section in ("train_env_config", "test_env_config")
    }
    return fill_default(defaults, config)
292+
293+
294+
def launch(config: Union[Dict[str, Any], Config], pl_callbacks: List[Callback] = None):
301295
"""
302296
Args:
303297
config: All configs needed to launch a gym environment and initialize

machin/auto/pl_plugin.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import logging
33
import pytorch_lightning as pl
4+
from time import sleep
45
from torch import distributed
56
from pytorch_lightning.utilities.seed import seed_everything
67
from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -105,8 +106,12 @@ def pre_dispatch(self):
105106

106107
# initialize framework in the launcher
107108
self._model.init_frame()
108-
self._model.trainer.accelerator.optimizers = self._model.frame.optimizers
109-
self._model.trainer.accelerator.lr_schedulers = self._model.frame.lr_schedulers
109+
if self._model.frame.optimizers is not None:
110+
self._model.trainer.accelerator.optimizers = self._model.frame.optimizers
111+
if self._model.frame.lr_schedulers is not None:
112+
self._model.trainer.accelerator.lr_schedulers = (
113+
self._model.frame.lr_schedulers
114+
)
110115

111116
self.barrier()
112117

@@ -199,8 +204,12 @@ def new_process(self, process_idx, trainer, mp_queue):
199204

200205
# initialize framework in the launcher
201206
self._model.init_frame()
202-
trainer.accelerator.optimizers = self._model.frame.optimizers
203-
trainer.accelerator.lr_schedulers = self._model.frame.lr_schedulers
207+
if self._model.frame.optimizers is not None:
208+
self._model.trainer.accelerator.optimizers = self._model.frame.optimizers
209+
if self._model.frame.lr_schedulers is not None:
210+
self._model.trainer.accelerator.lr_schedulers = (
211+
self._model.frame.lr_schedulers
212+
)
204213

205214
self.barrier()
206215

@@ -231,7 +240,7 @@ def _spawn(self):
231240
]
232241
for p in processes:
233242
p.start()
234-
while all([p.is_alive() for p in processes]):
243+
while True:
235244
should_exit = False
236245
for p in processes:
237246
try:
@@ -240,9 +249,14 @@ def _spawn(self):
240249
traceback.print_exc()
241250
should_exit = True
242251
if should_exit:
252+
for p in processes:
253+
p.terminate()
254+
p.join()
255+
raise RuntimeError("One or more exceptions raised in sub-processes.")
256+
elif not all([p.is_alive() for p in processes]):
243257
break
258+
sleep(0.1)
244259
for p in processes:
245-
p.kill()
246260
p.join()
247261

248262
def training_step(self, *args, **kwargs):
@@ -266,4 +280,3 @@ def post_training_step(self):
266280
# before the trainer is initialized.
267281
pl.trainer.connectors.accelerator_connector.DDPPlugin = DDPPlugin
268282
pl.trainer.connectors.accelerator_connector.DDPSpawnPlugin = DDPSpawnPlugin
269-
pl_logger.info("DDP plugin patched.")

machin/frame/algorithms/a3c.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def optimizers(self, optimizers):
123123
def lr_schedulers(self):
124124
return []
125125

126+
@classmethod
def is_distributed(cls):
    """
    Returns:
        Always ``True`` for this framework class.
    """
    return True
129+
126130
def set_sync(self, is_syncing):
127131
self.is_syncing = is_syncing
128132

machin/frame/algorithms/apex.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def generate_config(cls, config: Dict[str, Any]):
187187
}
188188
config = deepcopy(config)
189189
config["frame"] = "DQNApex"
190+
config["batch_num"] = {"sampler": 10, "learner": 1}
190191
if "frame_config" not in config:
191192
config["frame_config"] = default_values
192193
else:
@@ -461,6 +462,7 @@ def generate_config(cls, config: Union[Dict[str, Any], Config]):
461462
}
462463
config = deepcopy(config)
463464
config["frame"] = "DDPGApex"
465+
config["batch_num"] = {"sampler": 10, "learner": 1}
464466
if "frame_config" not in config:
465467
config["frame_config"] = default_values
466468
else:

machin/frame/algorithms/impala.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ def generate_config(cls, config: Union[Dict[str, Any], Config]):
484484
}
485485
config = deepcopy(config)
486486
config["frame"] = "IMPALA"
487+
config["batch_num"] = {"sampler": 10, "learner": 1}
487488
if "frame_config" not in config:
488489
config["frame_config"] = default_values
489490
else:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
"psutil",
5858
"numpy",
5959
"torch>=1.6.0",
60-
"pytorch-lightning>=1.0",
60+
"pytorch-lightning>=1.2.0",
6161
"torchviz",
6262
"moviepy",
6363
"matplotlib",

test/auto/_pl_plugin_runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from machin.parallel.distributed import get_world, get_cur_rank
2+
from machin.utils.helper_classes import Object
23
from torch.utils.data import DataLoader, TensorDataset
34
import os
45
import sys
@@ -24,6 +25,7 @@ class ParallelModule(pl.LightningModule):
2425
def __init__(self):
2526
super().__init__()
2627
self.nn_model = NNModule()
28+
self.frame = Object({"optimizers": None, "lr_schedulers": None})
2729

2830
def train_dataloader(self):
2931
return DataLoader(
@@ -41,6 +43,9 @@ def training_step(self, batch, _batch_idx):
4143
raise RuntimeError("World not initialized.")
4244
return None
4345

46+
def init_frame(self):
    """Do nothing; exists so that ``init_frame`` can be called on this test module."""
    return None
48+
4449
def configure_optimizers(self):
4550
return None
4651

0 commit comments

Comments
 (0)