Skip to content

Commit 0b97fc7

Browse files
committed
Merge branch 'master' into dev
2 parents 0f4e630 + faf7c72 commit 0b97fc7

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

48 files changed

+437
-402
lines changed

Jenkinsfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ pipeline {
1010
}
1111
environment {
1212
PYPI_CREDS = credentials('pypi_username_password')
13-
TWINE_USERNAME = '${env.PYPI_CREDS_USR}'
14-
TWINE_PASSWORD = '${env.PYPI_CREDS_PSW}'
13+
TWINE_USERNAME = "${env.PYPI_CREDS_USR}"
14+
TWINE_PASSWORD = "${env.PYPI_CREDS_PSW}"
1515
// See https://github.com/pytorch/pytorch/issues/37377
1616
MKL_SERVICE_FORCE_INTEL = "1"
1717
}

machin/frame/algorithms/trpo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,10 @@ def fvp(v):
241241

242242
# usually 1e-15 is low enough
243243
if t.allclose(loss_grad, t.zeros_like(loss_grad), atol=1e-15):
244-
default_logger.warning("TRPO detects zero gradient.")
244+
default_logger.warning(
245+
"TRPO detects zero gradient, update step skipped."
246+
)
247+
return 0, 0
245248

246249
step_dir = self._conjugate_gradients(
247250
fvp,

machin/model/algorithms/trpo.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ def sample(self, probability: t.tensor, action=None):
2626
Action log probability tensor of shape ``[batch, 1]``.
2727
"""
2828
batch_size = probability.shape[0]
29-
self.action_param = probability
29+
# d/dx (x ln x) = ln x + 1, so x must be > 0
30+
self.action_param = probability + 1e-6
3031
dist = Categorical(probs=probability)
3132
if action is None:
3233
action = dist.sample()
@@ -41,7 +42,7 @@ def get_kl(self, *args, **kwargs):
4142
self.forward(*args, **kwargs)
4243
action_prob1 = self.action_param
4344
action_prob0 = action_prob1.detach()
44-
kl = action_prob0 * (t.log(action_prob0) - t.log(action_prob1))
45+
kl = action_prob0 * (t.log(action_prob0 / action_prob1))
4546
return kl.sum(1, keepdim=True)
4647

4748
def compare_kl(self, params: t.tensor, *args, **kwargs):

test/auto/env/test_openai_gym.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from test.util_platforms import linux_only_forall
1+
from torch.distributions import Categorical, Normal
2+
from pytorch_lightning.callbacks import Callback
3+
from pytorch_lightning.utilities.distributed import ReduceOp
4+
from machin.parallel.distributed import get_cur_rank
5+
from machin.parallel.thread import Thread
6+
from machin.parallel.queue import SimpleQueue, TimeoutError
7+
from machin.utils.logging import default_logger
28
from machin.auto.config import (
39
generate_training_config,
410
generate_algorithm_config,
@@ -11,25 +17,18 @@
1117
gym_env_dataset_creator,
1218
launch,
1319
)
20+
from test.util_run_multi import *
21+
from test.util_fixtures import *
22+
from test.util_platforms import linux_only_forall
23+
1424
import os
1525
import pickle
1626
import os.path as p
1727
import gym
18-
import pytest
1928
import torch as t
2029
import torch.nn as nn
2130
import torch.nn.functional as F
2231
import subprocess as sp
23-
import multiprocessing as mp
24-
from test.util_run_multi import *
25-
from test.util_fixtures import *
26-
from pytorch_lightning.callbacks import Callback
27-
from torch.distributions import Categorical, Normal
28-
from machin.parallel.distributed import get_cur_rank
29-
from machin.parallel.thread import Thread
30-
from machin.parallel.queue import SimpleQueue, TimeoutError
31-
from machin.utils.logging import default_logger
32-
from pytorch_lightning.utilities.distributed import ReduceOp
3332

3433
linux_only_forall()
3534

test/auto/test_launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from unittest import mock
12
from machin.frame.algorithms import DQN
23
from machin.auto.launcher import Launcher
34
from machin.auto.dataset import RLDataset, DatasetResult
4-
from unittest import mock
55
import pytest
66
import torch as t
77
import torch.nn as nn

test/auto/test_pl_logger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from machin.auto.pl_logger import LocalMediaLogger
2-
from pytorch_lightning.loggers.base import DummyExperiment
31
from PIL import Image
2+
from pytorch_lightning.loggers.base import DummyExperiment
3+
from machin.auto.pl_logger import LocalMediaLogger
44
import os
55
import matplotlib.pyplot as plt
66

test/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,10 @@ def pytest_addoption(parser):
33
"--gpu_device",
44
action="store",
55
default=None,
6-
help="Gpu device descriptor in pytorch",
6+
help="GPU device descriptor in pytorch",
7+
)
8+
parser.addoption(
9+
"--multiprocess_method",
10+
default="forkserver",
11+
help="spawn or forkserver, default is forkserver",
712
)

test/data/all.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import os
2-
import re
31
from . import generators, ROOT
42
from .archive import Archive
3+
import os
4+
import re
55

66

77
def first(iterable, condition=lambda x: True):

test/data/archive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import torch as t
21
import os
32
import re
43
import datetime
4+
import torch as t
55

66

77
class Archive:

test/data/generators/generate_gail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
from torch.distributions import Categorical
12
from machin.frame.algorithms import PPO
23
from machin.utils.logging import default_logger as logger
3-
from torch.distributions import Categorical
44
from test.data import ROOT
55
from test.data.archive import Archive, get_time_string
66
import os

test/env/wrappers/test_openai_gym.py

Lines changed: 47 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,16 @@
66
77
Submit an issue if you find any problem.
88
"""
9-
from test.util_platforms import linux_only_forall
10-
11-
linux_only_forall()
12-
9+
from random import choice, sample
1310
from machin.env.wrappers import openai_gym
1411
from machin.utils.logging import default_logger
15-
from random import choice, sample
12+
from test.util_platforms import linux_only_forall
13+
1614
import pytest
1715
import gym
1816
import numpy as np
1917

18+
linux_only_forall()
2019
ENV_NUM = 2
2120
SAMPLE_NUM = 2
2221
WORKER_NUM = 2
@@ -26,69 +25,15 @@ def mock_action(action_space: gym.spaces.Space):
2625
return action_space.sample()
2726

2827

29-
def prepare_envs(env_list):
30-
for env in env_list:
31-
env.reset()
32-
33-
34-
def should_skip(spec):
35-
# From gym/envs/tests/spec_list.py
36-
# Used to check whether a gym environment should be tested.
37-
38-
# We skip tests for envs that require dependencies or are otherwise
39-
# troublesome to run frequently
40-
ep = spec.entry_point
41-
42-
# No need to test unittest environments
43-
if ep.startswith("gym.envs.unittest"):
44-
return True
45-
46-
# Skip not renderable tests
47-
if ep.startswith("gym.envs.algorithmic") or ep.startswith("gym.envs.toy_text"):
48-
return True
49-
50-
# Skip mujoco tests
51-
if ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:"):
52-
return True
53-
54-
# Skip atari tests
55-
if ep.startswith("gym.envs.atari"):
56-
return True
57-
58-
# Skip other tests
59-
if "GoEnv" in ep or "HexEnv" in ep or "CarRacing" in ep:
60-
return True
61-
62-
# Conditionally skip box2d tests
63-
try:
64-
import Box2D
65-
except ImportError:
66-
if ep.startswith("gym.envs.box2d"):
67-
return True
68-
69-
return False
70-
71-
7228
@pytest.fixture(scope="module", autouse=True)
7329
def envs():
74-
all_envs = []
75-
env_map = {}
76-
# Find the newest version of non-skippable environments.
77-
for env_raw_name, env_spec in gym.envs.registry.env_specs.items():
78-
if not should_skip(env_spec):
79-
env_name, env_version = env_raw_name.split("-v")
80-
if env_name not in env_version or int(env_version) > env_map[env_name]:
81-
env_map[env_name] = int(env_version)
30+
names = ["CartPole-v0"]
31+
creators = []
8232

8333
# Create environments.
84-
for env_name, env_version in env_map.items():
85-
env_name = env_name + "-v" + str(env_version)
86-
default_logger.info(f"OpenAI gym {env_name} added")
87-
all_envs.append([lambda *_: gym.make(env_name) for _ in range(ENV_NUM)])
88-
default_logger.info(
89-
"{} OpenAI gym environments to be tested.".format(len(all_envs))
90-
)
91-
return all_envs
34+
for name in names:
35+
creators.append([lambda *_: gym.make(name) for _ in range(ENV_NUM)])
36+
return names, creators
9237

9338

9439
class TestParallelWrapperDummy:
@@ -104,8 +49,9 @@ class TestParallelWrapperDummy:
10449

10550
@pytest.mark.parametrize("idx,reset_num", param_test_reset)
10651
def test_reset(self, envs, idx, reset_num):
107-
for env_list in envs:
108-
dummy_wrapper = openai_gym.ParallelWrapperDummy(env_list)
52+
for name, creators in zip(*envs):
53+
default_logger.info(f"Testing on env {name}")
54+
dummy_wrapper = openai_gym.ParallelWrapperDummy(creators)
10955
obsrvs = dummy_wrapper.reset(idx)
11056
dummy_wrapper.close()
11157

@@ -129,8 +75,9 @@ def test_reset(self, envs, idx, reset_num):
12975

13076
@pytest.mark.parametrize("idx,act_num", param_test_step)
13177
def test_step(self, envs, idx, act_num):
132-
for env_list in envs:
133-
dummy_wrapper = openai_gym.ParallelWrapperDummy(env_list)
78+
for name, creators in zip(*envs):
79+
default_logger.info(f"Testing on env {name}")
80+
dummy_wrapper = openai_gym.ParallelWrapperDummy(creators)
13481
action = [mock_action(dummy_wrapper.action_space) for _ in range(act_num)]
13582
dummy_wrapper.reset(idx)
13683
obsrvs, reward, terminal, info = dummy_wrapper.step(action, idx)
@@ -159,8 +106,9 @@ def test_step(self, envs, idx, act_num):
159106

160107
@pytest.mark.parametrize("idx", param_test_seed)
161108
def test_seed(self, envs, idx):
162-
for env_list in envs:
163-
dummy_wrapper = openai_gym.ParallelWrapperDummy(env_list)
109+
for name, creators in zip(*envs):
110+
default_logger.info(f"Testing on env {name}")
111+
dummy_wrapper = openai_gym.ParallelWrapperDummy(creators)
164112
seeds = dummy_wrapper.seed()
165113
dummy_wrapper.close()
166114
assert len(seeds) == ENV_NUM
@@ -177,8 +125,9 @@ def test_seed(self, envs, idx):
177125

178126
@pytest.mark.parametrize("idx,render_num", param_test_render)
179127
def test_render(self, envs, idx, render_num):
180-
for env_list in envs:
181-
dummy_wrapper = openai_gym.ParallelWrapperDummy(env_list)
128+
for name, creators in zip(*envs):
129+
default_logger.info(f"Testing on env {name}")
130+
dummy_wrapper = openai_gym.ParallelWrapperDummy(creators)
182131
dummy_wrapper.reset(idx)
183132
rendered = dummy_wrapper.render(idx)
184133
dummy_wrapper.close()
@@ -190,16 +139,18 @@ def test_render(self, envs, idx, render_num):
190139
# Test for ParallelWrapperDummy.close
191140
########################################################################
192141
def test_close(self, envs):
193-
for env_list in envs:
194-
dummy_wrapper = openai_gym.ParallelWrapperDummy(env_list)
142+
for name, creators in zip(*envs):
143+
default_logger.info(f"Testing on env {name}")
144+
dummy_wrapper = openai_gym.ParallelWrapperDummy(creators)
195145
dummy_wrapper.close()
196146

197147
########################################################################
198148
# Test for ParallelWrapperDummy.active
199149
########################################################################
200150
def test_active(self, envs):
201-
for env_list in envs:
202-
dummy_wrapper = openai_gym.ParallelWrapperDummy(env_list)
151+
for name, creators in zip(*envs):
152+
default_logger.info(f"Testing on env {name}")
153+
dummy_wrapper = openai_gym.ParallelWrapperDummy(creators)
203154
dummy_wrapper.reset()
204155
active = dummy_wrapper.active()
205156
dummy_wrapper.close()
@@ -209,7 +160,7 @@ def test_active(self, envs):
209160
# Test for ParallelWrapperDummy.size
210161
########################################################################
211162
def test_size(self, envs):
212-
dummy_wrapper = openai_gym.ParallelWrapperDummy(envs[0])
163+
dummy_wrapper = openai_gym.ParallelWrapperDummy(envs[1][0])
213164
assert dummy_wrapper.size() == ENV_NUM
214165
dummy_wrapper.close()
215166

@@ -227,8 +178,9 @@ class TestParallelWrapperSubProc:
227178

228179
@pytest.mark.parametrize("idx,reset_num", param_test_reset)
229180
def test_reset(self, envs, idx, reset_num):
230-
for env_list in envs:
231-
subproc_wrapper = openai_gym.ParallelWrapperSubProc(env_list)
181+
for name, creators in zip(*envs):
182+
default_logger.info(f"Testing on env {name}")
183+
subproc_wrapper = openai_gym.ParallelWrapperSubProc(creators)
232184
obsrvs = subproc_wrapper.reset(idx)
233185
subproc_wrapper.close()
234186

@@ -252,8 +204,9 @@ def test_reset(self, envs, idx, reset_num):
252204

253205
@pytest.mark.parametrize("idx,act_num", param_test_step)
254206
def test_step(self, envs, idx, act_num):
255-
for env_list in envs:
256-
subproc_wrapper = openai_gym.ParallelWrapperSubProc(env_list)
207+
for name, creators in zip(*envs):
208+
default_logger.info(f"Testing on env {name}")
209+
subproc_wrapper = openai_gym.ParallelWrapperSubProc(creators)
257210
action = [mock_action(subproc_wrapper.action_space) for _ in range(act_num)]
258211
subproc_wrapper.reset(idx)
259212
obsrvs, reward, terminal, info = subproc_wrapper.step(action, idx)
@@ -282,8 +235,9 @@ def test_step(self, envs, idx, act_num):
282235

283236
@pytest.mark.parametrize("idx", param_test_seed)
284237
def test_seed(self, envs, idx):
285-
for env_list in envs:
286-
subproc_wrapper = openai_gym.ParallelWrapperSubProc(env_list)
238+
for name, creators in zip(*envs):
239+
default_logger.info(f"Testing on env {name}")
240+
subproc_wrapper = openai_gym.ParallelWrapperSubProc(creators)
287241
seeds = subproc_wrapper.seed()
288242
subproc_wrapper.close()
289243
assert len(seeds) == ENV_NUM
@@ -300,8 +254,9 @@ def test_seed(self, envs, idx):
300254

301255
@pytest.mark.parametrize("idx,render_num", param_test_render)
302256
def test_render(self, envs, idx, render_num):
303-
for env_list in envs:
304-
subproc_wrapper = openai_gym.ParallelWrapperSubProc(env_list)
257+
for name, creators in zip(*envs):
258+
default_logger.info(f"Testing on env {name}")
259+
subproc_wrapper = openai_gym.ParallelWrapperSubProc(creators)
305260
subproc_wrapper.reset(idx)
306261
rendered = subproc_wrapper.render(idx)
307262
subproc_wrapper.close()
@@ -313,22 +268,24 @@ def test_render(self, envs, idx, render_num):
313268
# Test for ParallelWrapperSubProc.close
314269
########################################################################
315270
def test_close(self, envs):
316-
for env_list in envs:
317-
subproc_wrapper = openai_gym.ParallelWrapperSubProc(env_list)
271+
for name, creators in zip(*envs):
272+
default_logger.info(f"Testing on env {name}")
273+
subproc_wrapper = openai_gym.ParallelWrapperSubProc(creators)
318274
subproc_wrapper.close()
319275

320276
########################################################################
321277
# Test for ParallelWrapperSubProc.active
322278
########################################################################
323279
def test_active(self, envs):
324-
for env_list in envs:
325-
subproc_wrapper = openai_gym.ParallelWrapperSubProc(env_list)
280+
for name, creators in zip(*envs):
281+
default_logger.info(f"Testing on env {name}")
282+
subproc_wrapper = openai_gym.ParallelWrapperSubProc(creators)
326283
subproc_wrapper.reset()
327284
active = subproc_wrapper.active()
328285
subproc_wrapper.close()
329286
assert len(active) == ENV_NUM
330287

331288
def test_size(self, envs):
332-
subproc_wrapper = openai_gym.ParallelWrapperSubProc(envs[0])
289+
subproc_wrapper = openai_gym.ParallelWrapperSubProc(envs[1][0])
333290
assert subproc_wrapper.size() == ENV_NUM
334291
subproc_wrapper.close()

0 commit comments

Comments
 (0)