
Commit 6742f4c

Add license

1 parent 0e04336 commit 6742f4c

File tree (5 files changed: +96 -19 lines)

LICENSE
machin/frame/algorithms/a2c.py
machin/frame/algorithms/ppo.py
test/frame/algorithms/test_a2c.py
test/frame/algorithms/test_ppo.py

LICENSE (+21)

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Iffi
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+OR OTHER DEALINGS IN THE SOFTWARE.
machin/frame/algorithms/a2c.py (+1 -1)

@@ -336,7 +336,7 @@ def update(self,
                 act_policy_loss += (self.entropy_weight *
                                     new_action_entropy.mean())

-            act_policy_loss = act_policy_loss.sum()
+            act_policy_loss = act_policy_loss.mean()

             if self.visualize:
                 self.visualize_model(act_policy_loss, "actor",
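
For context: the a2c.py change reduces the per-sample actor loss with .mean() instead of .sum(), which keeps the gradient scale independent of batch size, so the effective learning rate no longer grows with the batch. A minimal sketch of the idea, using assumed names (actor_loss, log_probs, advantages are illustrative, not machin's internals):

import torch as t

def actor_loss(log_probs: t.Tensor, advantages: t.Tensor) -> t.Tensor:
    # one policy-gradient term per sample in the batch
    per_sample = -log_probs * advantages
    # mean (new behaviour) keeps the gradient magnitude roughly constant as
    # the batch grows; sum (old behaviour) scales it linearly with batch size
    return per_sample.mean()

# usage with made-up data
loss = actor_loss(t.randn(64, requires_grad=True), t.randn(64))
loss.backward()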

machin/frame/algorithms/ppo.py (+8 -9)

@@ -107,9 +107,7 @@ def update(self,
                 "value", "gae"
             ])

-        # normalize target value
-        target_value = ((target_value - target_value.mean()) /
-                        (target_value.std() + 1e-5))
+        # normalize advantage
         advantage = ((advantage - advantage.mean()) /
                      (advantage.std() + 1e-6))

@@ -138,7 +136,7 @@ def update(self,
             surr_loss_1 = sim_ratio * advantage
             surr_loss_2 = t.clamp(sim_ratio,
                                   1 - self.surr_clip,
-                                  1 + self.surr_clip) * target_value
+                                  1 + self.surr_clip) * advantage

             # calculate policy loss using surrogate loss
             act_policy_loss = -t.min(surr_loss_1, surr_loss_2)
@@ -149,11 +147,6 @@ def update(self,

             act_policy_loss = act_policy_loss.mean()

-            # calculate value loss
-            value = self.criticize(state)
-            value_loss = (self.criterion(target_value.to(value.device), value) *
-                          self.value_weight)
-
             if self.visualize:
                 self.visualize_model(act_policy_loss, "actor",
                                      self.visualize_dir)
@@ -168,6 +161,12 @@ def update(self,
             self.actor_optim.step()
             sum_act_policy_loss += act_policy_loss.item()

+            # calculate value loss
+            value = self.criticize(state)
+            value_loss = (self.criterion(target_value.to(value.device),
+                                         value) *
+                          self.value_weight)
+
             if self.visualize:
                 self.visualize_model(value_loss, "critic",
                                      self.visualize_dir)
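
For context: the ppo.py changes normalize the advantage rather than the value target, weight both surrogate branches by that advantage, and move the value-loss computation (which uses the unnormalized target_value) after the actor update. A minimal standalone sketch of the corrected clipped surrogate objective, with assumed tensor arguments rather than machin's internal state:

import torch as t

def clipped_policy_loss(new_log_prob: t.Tensor,
                        old_log_prob: t.Tensor,
                        advantage: t.Tensor,
                        surr_clip: float = 0.2) -> t.Tensor:
    # normalize the advantage, not the value target
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-6)
    sim_ratio = t.exp(new_log_prob - old_log_prob)
    surr_loss_1 = sim_ratio * advantage
    # the clipped branch must also be weighted by the advantage
    surr_loss_2 = t.clamp(sim_ratio, 1 - surr_clip, 1 + surr_clip) * advantage
    return -t.min(surr_loss_1, surr_loss_2).mean()

# usage with made-up log-probabilities and advantages
loss = clipped_policy_loss(t.randn(64), t.randn(64), t.randn(64))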

test/frame/algorithms/test_a2c.py (+1 -2)

@@ -233,15 +233,14 @@ def test_lr_scheduler(self, train_config, lr_a2c):
     def test_full_train(self, train_config, a2c, gae_lambda):
         c = train_config
         a2c.gae_lambda = gae_lambda
+
         # begin training
         episode, step = Counter(), Counter()
         reward_fulfilled = Counter()
         smoother = Smooth()
         terminal = False

         env = c.env
-        t.set_printoptions(sci_mode=False)
-        a2c.update_times = 1
         while episode < c.max_episodes:
             episode.count()

test/frame/algorithms/test_ppo.py (+65 -7)

@@ -5,7 +5,7 @@
 from machin.utils.conf import Config
 from machin.env.utils.openai_gym import disable_view_window
 from torch.nn.functional import softplus
-from torch.distributions import Normal
+from torch.distributions import Normal, Categorical

 import pytest
 import torch as t
@@ -33,11 +33,29 @@ def forward(self, state, action=None):
         a_dist = Normal(a_mu, a_sigma)
         a = action if action is not None else a_dist.sample()
         a_entropy = a_dist.entropy()
-        a = a.clamp(-self.action_range, self.action_range)
         a_log_prob = a_dist.log_prob(a)
         return a, a_log_prob, a_entropy


+# class Actor(nn.Module):
+#     def __init__(self, state_dim, action_num):
+#         super(Actor, self).__init__()
+#
+#         self.fc1 = nn.Linear(state_dim, 16)
+#         self.fc2 = nn.Linear(16, 16)
+#         self.fc3 = nn.Linear(16, action_num)
+#
+#     def forward(self, state, action=None):
+#         a = t.relu(self.fc1(state))
+#         a = t.relu(self.fc2(a))
+#         probs = t.softmax(self.fc3(a), dim=1)
+#         dist = Categorical(probs=probs)
+#         act = (action if action is not None else dist.sample())
+#         act_entropy = dist.entropy()
+#         act_log_prob = dist.log_prob(act)
+#         return act, act_log_prob, act_entropy
+
+
 class Critic(nn.Module):
     def __init__(self, state_dim):
         super(Critic, self).__init__()
@@ -58,20 +76,40 @@ class TestPPO(object):
     @pytest.fixture(scope="class")
     def train_config(self, pytestconfig):
         disable_view_window()
+        t.manual_seed(0)
         c = Config()
         c.env_name = "Pendulum-v0"
         c.env = unwrap_time_limit(gym.make(c.env_name))
+        c.env.seed(0)
         c.observe_dim = 3
         c.action_dim = 1
         c.action_range = 2
         c.max_episodes = 1000
-        c.max_steps = 200
+        c.max_steps = 500
         c.replay_size = 10000
         c.solved_reward = -150
         c.solved_repeat = 5
         c.device = "cpu"
         return c

+    # @pytest.fixture(scope="class")
+    # def train_config(self, pytestconfig):
+    #     disable_view_window()
+    #     c = Config()
+    #     # Note: A2C is not sample efficient, it will not work very well
+    #     # in contiguous spaces such as "Pendulum-v0", PPO is better.
+    #     c.env_name = "CartPole-v1"
+    #     c.env = unwrap_time_limit(gym.make(c.env_name))
+    #     c.observe_dim = 4
+    #     c.action_num = 2
+    #     c.max_episodes = 1000
+    #     c.max_steps = 500
+    #     c.replay_size = 10000
+    #     c.solved_reward = 190
+    #     c.solved_repeat = 5
+    #     c.device = "cpu"
+    #     return c
+
     @pytest.fixture(scope="function")
     def ppo(self, train_config):
         c = train_config
@@ -86,6 +124,20 @@ def ppo(self, train_config):
                   replay_size=c.replay_size)
         return ppo

+    # @pytest.fixture(scope="function")
+    # def ppo(self, train_config):
+    #     c = train_config
+    #     actor = smw(Actor(c.observe_dim, c.action_num)
+    #                 .to(c.device), c.device, c.device)
+    #     critic = smw(Critic(c.observe_dim)
+    #                  .to(c.device), c.device, c.device)
+    #     ppo = PPO(actor, critic,
+    #               t.optim.Adam,
+    #               nn.MSELoss(reduction='sum'),
+    #               replay_device=c.device,
+    #               replay_size=c.replay_size)
+    #     return ppo
+
     @pytest.fixture(scope="function")
     def ppo_vis(self, train_config, tmpdir):
         # not used for training, only used for testing apis
@@ -169,15 +221,16 @@ def test_update(self, train_config, ppo_vis):
     ########################################################################
     def test_full_train(self, train_config, ppo):
         c = train_config
-
         # begin training
         episode, step = Counter(), Counter()
         reward_fulfilled = Counter()
         smoother = Smooth()
         terminal = False

         env = c.env
-        ppo.grad_max = 0.1
+        ppo.gae_lambda = 1.0
+        ppo.update_times = 20
+        ppo.entropy_weight = 1
         while episode < c.max_episodes:
             episode.count()

@@ -192,7 +245,11 @@ def test_full_train(self, train_config, ppo):
                 old_state = state
                 # agent model inference
                 action = ppo.act({"state": old_state.unsqueeze(0)})[0]
-                state, reward, terminal, _ = env.step(action.cpu().numpy())
+                state, reward, terminal, _ = env.step(
+                    action.clamp(-c.action_range, c.action_range).cpu()
+                          .numpy()
+                )
+                #state, reward, terminal, _ = env.step(action.item())
                 state = t.tensor(state, dtype=t.float32, device=c.device) \
                     .flatten()
                 total_reward += float(reward)
@@ -207,7 +264,8 @@ def test_full_train(self, train_config, ppo):

             # update
             ppo.store_episode(tmp_observations)
-            logger.info("{:.6f}, {:.0f}".format(*ppo.update()))
+            if episode.get() % 5 == 0:
+                logger.info("{:.6f}, {:.2f}".format(*ppo.update()))

             smoother.update(total_reward)
             step.reset()
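
For context: the test now samples the action without clamping inside Actor.forward and clamps only the copy passed to env.step, so the stored log-probability matches the action that was actually drawn from the policy. A toy illustration with made-up numbers (not machin code):

import torch as t
from torch.distributions import Normal

# 1-D Gaussian policy and a raw sample outside Pendulum's [-2, 2] action range
dist = Normal(t.tensor([0.0]), t.tensor([1.0]))
a = t.tensor([3.5])
a_env = a.clamp(-2.0, 2.0)     # only the value sent to env.step is clamped

print(dist.log_prob(a))        # log-prob of the action actually sampled
print(dist.log_prob(a_env))    # clamping first would bias the PPO ratio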
