 from machin.utils.conf import Config
 from machin.env.utils.openai_gym import disable_view_window
 from torch.nn.functional import softplus
-from torch.distributions import Normal
+from torch.distributions import Normal, Categorical
 
 import pytest
 import torch as t
@@ -33,11 +33,29 @@ def forward(self, state, action=None):
         a_dist = Normal(a_mu, a_sigma)
         a = action if action is not None else a_dist.sample()
         a_entropy = a_dist.entropy()
-        a = a.clamp(-self.action_range, self.action_range)
         a_log_prob = a_dist.log_prob(a)
         return a, a_log_prob, a_entropy
 
 
+# class Actor(nn.Module):
+#     def __init__(self, state_dim, action_num):
+#         super(Actor, self).__init__()
+#
+#         self.fc1 = nn.Linear(state_dim, 16)
+#         self.fc2 = nn.Linear(16, 16)
+#         self.fc3 = nn.Linear(16, action_num)
+#
+#     def forward(self, state, action=None):
+#         a = t.relu(self.fc1(state))
+#         a = t.relu(self.fc2(a))
+#         probs = t.softmax(self.fc3(a), dim=1)
+#         dist = Categorical(probs=probs)
+#         act = (action if action is not None else dist.sample())
+#         act_entropy = dist.entropy()
+#         act_log_prob = dist.log_prob(act)
+#         return act, act_log_prob, act_entropy
+
+
 class Critic(nn.Module):
     def __init__(self, state_dim):
         super(Critic, self).__init__()
@@ -58,20 +76,40 @@ class TestPPO(object):
     @pytest.fixture(scope="class")
     def train_config(self, pytestconfig):
         disable_view_window()
+        t.manual_seed(0)
         c = Config()
         c.env_name = "Pendulum-v0"
         c.env = unwrap_time_limit(gym.make(c.env_name))
+        c.env.seed(0)
         c.observe_dim = 3
         c.action_dim = 1
         c.action_range = 2
         c.max_episodes = 1000
-        c.max_steps = 200
+        c.max_steps = 500
         c.replay_size = 10000
         c.solved_reward = -150
         c.solved_repeat = 5
         c.device = "cpu"
         return c
 
+    # @pytest.fixture(scope="class")
+    # def train_config(self, pytestconfig):
+    #     disable_view_window()
+    #     c = Config()
+    #     # Note: A2C is not sample efficient, so it will not work very well
+    #     # in continuous action spaces such as "Pendulum-v0"; PPO is better.
+    #     c.env_name = "CartPole-v1"
+    #     c.env = unwrap_time_limit(gym.make(c.env_name))
+    #     c.observe_dim = 4
+    #     c.action_num = 2
+    #     c.max_episodes = 1000
+    #     c.max_steps = 500
+    #     c.replay_size = 10000
+    #     c.solved_reward = 190
+    #     c.solved_repeat = 5
+    #     c.device = "cpu"
+    #     return c
+
     @pytest.fixture(scope="function")
     def ppo(self, train_config):
         c = train_config
@@ -86,6 +124,20 @@ def ppo(self, train_config):
                   replay_size=c.replay_size)
         return ppo
 
+    # @pytest.fixture(scope="function")
+    # def ppo(self, train_config):
+    #     c = train_config
+    #     actor = smw(Actor(c.observe_dim, c.action_num)
+    #                 .to(c.device), c.device, c.device)
+    #     critic = smw(Critic(c.observe_dim)
+    #                  .to(c.device), c.device, c.device)
+    #     ppo = PPO(actor, critic,
+    #               t.optim.Adam,
+    #               nn.MSELoss(reduction='sum'),
+    #               replay_device=c.device,
+    #               replay_size=c.replay_size)
+    #     return ppo
+
     @pytest.fixture(scope="function")
     def ppo_vis(self, train_config, tmpdir):
         # not used for training, only used for testing apis
@@ -169,15 +221,16 @@ def test_update(self, train_config, ppo_vis):
     ########################################################################
     def test_full_train(self, train_config, ppo):
         c = train_config
-
         # begin training
         episode, step = Counter(), Counter()
         reward_fulfilled = Counter()
         smoother = Smooth()
         terminal = False
 
         env = c.env
-        ppo.grad_max = 0.1
+        ppo.gae_lambda = 1.0
+        ppo.update_times = 20
+        ppo.entropy_weight = 1
         while episode < c.max_episodes:
             episode.count()
 
@@ -192,7 +245,11 @@ def test_full_train(self, train_config, ppo):
                     old_state = state
                     # agent model inference
                     action = ppo.act({"state": old_state.unsqueeze(0)})[0]
-                    state, reward, terminal, _ = env.step(action.cpu().numpy())
+                    state, reward, terminal, _ = env.step(
+                        action.clamp(-c.action_range, c.action_range).cpu()
+                        .numpy()
+                    )
+                    # state, reward, terminal, _ = env.step(action.item())
                     state = t.tensor(state, dtype=t.float32, device=c.device) \
                         .flatten()
                     total_reward += float(reward)
@@ -207,7 +264,8 @@ def test_full_train(self, train_config, ppo):
 
             # update
             ppo.store_episode(tmp_observations)
-            logger.info("{:.6f}, {:.0f}".format(*ppo.update()))
+            if episode.get() % 5 == 0:
+                logger.info("{:.6f}, {:.2f}".format(*ppo.update()))
 
             smoother.update(total_reward)
             step.reset()
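
A side note on the Actor change above: the clamp was dropped from forward(), so log_prob is now evaluated on the raw Normal sample, and the clamp is applied only when the action is handed to the environment in the training loop. A minimal standalone sketch of that pattern, using only torch (the action_range value mirrors the Pendulum-v0 config above):

    import torch as t
    from torch.distributions import Normal

    action_range = 2
    mu, sigma = t.zeros(1, 1), t.ones(1, 1)
    dist = Normal(mu, sigma)

    raw_action = dist.sample()            # unclamped sample kept for training
    log_prob = dist.log_prob(raw_action)  # log-prob of the action actually drawn
    entropy = dist.entropy()

    # clamp only at the environment boundary, not before log_prob,
    # so the stored log-prob still matches the sampled action
    env_action = raw_action.clamp(-action_range, action_range)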
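
For the commented-out discrete variant, the newly imported Categorical plays the same role as Normal: sample(), log_prob() and entropy() let the discrete actor return the same (action, log_prob, entropy) tuple. A small standalone sketch (the probs tensor here is illustrative, not taken from the test):

    import torch as t
    from torch.distributions import Categorical

    probs = t.softmax(t.randn(1, 2), dim=1)  # e.g. two discrete actions, batch of 1
    dist = Categorical(probs=probs)

    act = dist.sample()                      # shape [1], integer action index
    act_log_prob = dist.log_prob(act)        # shape [1]
    act_entropy = dist.entropy()             # shape [1]

    # a gym environment such as CartPole-v1 expects a plain int:
    # env.step(act.item())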