Commit 5d88b19

Authored by Iffi

Merge pull request #10 from ikamensh/black-formatting

Black formatting for the whole repository

2 parents: 18c0999 + e16d3b2

File tree

113 files changed: +7804, -6031 lines

Note: large commits have some content hidden by default, so only a subset of the 113 changed files is shown below.

.pre-commit-config.yaml (+6)

@@ -0,0 +1,6 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: 20.8b1 # Replace by any tag/version: https://github.com/psf/black/tags
+    hooks:
+      - id: black
+        language_version: python3 # Should be a command that runs python3.6+

DEVELOPMENT.md (+5)

@@ -0,0 +1,5 @@
+# Code style
+The code follows code styling by [black](https://github.com/psf/black).
+
+To automate code formatting, [pre-commit](https://github.com/pre-commit/pre-commit) is used to run code checks before committing changes.
+If you have pre-commit installed from requirements-dev.txt, simply run ``pre-commit install`` to install the hooks for this repo.
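
Alongside the commit hook, Black can also be invoked programmatically, which is handy for checking a snippet without committing. This is a small illustrative sketch, not part of this commit; it assumes the black package (for example the 20.8b1 version pinned in the hook config) is importable and uses Black's documented format_str / FileMode API:

import black

# An illustrative snippet in the pre-Black style this repo used:
# arguments wrapped to line up with the opening parenthesis.
source = (
    "def example(first_argument,\n"
    "            second_argument):\n"
    "    return first_argument\n"
)

# format_str applies the same rules the pre-commit black hook enforces;
# here it collapses the wrapped signature onto a single line.
print(black.format_str(source, mode=black.FileMode()))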

machin/auto/__init__.py (+1, -3)

@@ -4,6 +4,4 @@
 from . import pl_logger
 from . import pl_plugin
 
-__all__ = [
-    "config", "dataset", "launcher", "pl_logger", "pl_plugin"
-]
+__all__ = ["config", "dataset", "launcher", "pl_logger", "pl_plugin"]

machin/auto/config.py (+28, -17)

@@ -6,8 +6,9 @@
 import machin.frame.algorithms as algorithms
 
 
-def fill_default(default: Union[Dict[str, Any], Config],
-                 config: Union[Dict[str, Any], Config]):
+def fill_default(
+    default: Union[Dict[str, Any], Config], config: Union[Dict[str, Any], Config]
+):
     for key in default:
         if key not in config:
             config[key] = default[key]
@@ -18,46 +19,56 @@ def _get_available_algorithms():
     algos = []
     for algo in dir(algorithms):
         algo_cls = getattr(algorithms, algo)
-        if (inspect.isclass(algo_cls)
-                and issubclass(algo_cls, TorchFramework)
-                and algo_cls != TorchFramework):
+        if (
+            inspect.isclass(algo_cls)
+            and issubclass(algo_cls, TorchFramework)
+            and algo_cls != TorchFramework
+        ):
             algos.append(algo)
     return algos
 
 
-def generate_algorithm_config(algorithm: str,
-                              config: Union[Dict[str, Any], Config] = None):
+def generate_algorithm_config(
+    algorithm: str, config: Union[Dict[str, Any], Config] = None
+):
     config = deepcopy(config) or {}
     if hasattr(algorithms, algorithm):
         algo_obj = getattr(algorithms, algorithm)
         if issubclass(algo_obj, TorchFramework):
             return algo_obj.generate_config(config)
-    raise ValueError("Invalid algorithm: {}, valid ones are: {}"
-                     .format(algorithm, _get_available_algorithms()))
+    raise ValueError(
+        "Invalid algorithm: {}, valid ones are: {}".format(
+            algorithm, _get_available_algorithms()
+        )
+    )
 
 
 def init_algorithm_from_config(config: Union[Dict[str, Any], Config]):
     assert_config_complete(config)
     frame = getattr(algorithms, config["frame"], None)
     if not inspect.isclass(frame) or not issubclass(frame, TorchFramework):
-        raise ValueError("Invalid algorithm: {}, valid ones are: {}"
-                         .format(config["frame"], _get_available_algorithms()))
+        raise ValueError(
+            "Invalid algorithm: {}, valid ones are: {}".format(
+                config["frame"], _get_available_algorithms()
+            )
+        )
     return frame.init_from_config(config)
 
 
 def is_algorithm_distributed(config: Union[Dict[str, Any], Config]):
     assert_config_complete(config)
     frame = getattr(algorithms, config["frame"], None)
     if not inspect.isclass(frame) or not issubclass(frame, TorchFramework):
-        raise ValueError("Invalid algorithm: {}, valid ones are: {}"
-                         .format(config["frame"], _get_available_algorithms()))
+        raise ValueError(
+            "Invalid algorithm: {}, valid ones are: {}".format(
+                config["frame"], _get_available_algorithms()
+            )
+        )
     return frame.is_distributed()
 
 
 def assert_config_complete(config: Union[Dict[str, Any], Config]):
     assert "frame" in config, 'Missing key "frame" in config.'
     assert "frame_config" in config, 'Missing key "frame_config" in config.'
-    assert "train_env_config" in config, 'Missing key "train_env_config" ' \
-                                         'in config.'
-    assert "test_env_config" in config, 'Missing key "test_env_config" ' \
-                                        'in config.'
+    assert "train_env_config" in config, 'Missing key "train_env_config" ' "in config."
+    assert "test_env_config" in config, 'Missing key "test_env_config" ' "in config."

machin/auto/dataset.py (+13, -12)

@@ -18,11 +18,12 @@ def determine_precision(models):
             dtype.add(v.dtype)
     dtype = list(dtype)
     if len(dtype) > 1:
-        raise RuntimeError("Multiple data types of parameters detected "
-                           "in models: {}, this is currently not supported "
-                           "since we need to determine the data type of your "
-                           "model input from your model parameter data type."
-                           .format(dtype))
+        raise RuntimeError(
+            "Multiple data types of parameters detected "
+            "in models: {}, this is currently not supported "
+            "since we need to determine the data type of your "
+            "model input from your model parameter data type.".format(dtype)
+        )
     return dtype[0]
 
 
@@ -43,9 +44,7 @@ def log_video(module, name, video_frames: List[np.ndarray]):
     # create video temp file
     _fd, path = tempfile.mkstemp(suffix=".gif")
    try:
-        create_video(video_frames,
-                     os.path.dirname(path),
-                     os.path.basename(path))
+        create_video(video_frames, os.path.dirname(path), os.path.basename(path))
     except Exception as e:
         print(e)
         os.remove(path)
@@ -58,10 +57,11 @@ def log_video(module, name, video_frames: List[np.ndarray]):
 
 
 class DatasetResult:
-    def __init__(self,
-                 observations: List[Dict[str, Any]] = None,
-                 logs: List[Dict[str, Union[Scalar, Tuple[Scalar, str]]]]
-                 = None):
+    def __init__(
+        self,
+        observations: List[Dict[str, Any]] = None,
+        logs: List[Dict[str, Union[Scalar, Tuple[Scalar, str]]]] = None,
+    ):
         self.observations = observations or []
         self.logs = logs or []
 
@@ -79,6 +79,7 @@ class RLDataset(IterableDataset):
     """
     Base class for all RL Datasets.
     """
+
     def __init__(self, **_kwargs):
         super(RLDataset, self).__init__()
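
One small note on the DatasetResult change above (again, not part of the commit): the reformatted __init__ keeps the `observations or []` / `logs or []` pattern, so each instance gets its own fresh list instead of sharing a mutable default. A trimmed, self-contained stand-in that mirrors only the lines shown (the real class in machin/auto/dataset.py may have additional members, and its logs type hint is richer):

from typing import Any, Dict, List


class DatasetResult:
    # Simplified stand-in for the __init__ shown in the diff above.
    def __init__(
        self,
        observations: List[Dict[str, Any]] = None,
        logs: List[Dict[str, Any]] = None,
    ):
        self.observations = observations or []
        self.logs = logs or []


a = DatasetResult()
b = DatasetResult()
a.observations.append({"state": 0})
# Each instance built its own list via `or []`, so b stays empty.
print(len(a.observations), len(b.observations))  # 1 0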
