
Commit 7fa986b

New buffer implementation.
1 parent 0b97fc7 commit 7fa986b

21 files changed: +835 −434 lines
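Most of the diff is a mechanical migration at the algorithm call sites: instead of looping over an episode and calling Buffer.append once per transition, the algorithms now hand the whole episode to Buffer.store_episode. Below is a minimal sketch of the new call pattern; the import path follows the file layout in this commit, and the tensor shapes and dict layout of each transition are placeholders, not part of the diff.

import torch as t
from machin.frame.buffers.buffer import Buffer

# Placeholder transition data; real algorithms fill these with environment
# observations and network outputs.
episode = [
    {
        "state": {"state": t.zeros(1, 4)},
        "action": {"action": t.zeros(1, 2)},
        "next_state": {"state": t.zeros(1, 4)},
        "reward": 0.0,
        "terminal": True,
    }
]

replay_buffer = Buffer(buffer_size=1000000, buffer_device="cpu")

# Before this commit: one append call per transition (see the removed lines
# in the hunks below). After: a single call for the whole episode.
replay_buffer.store_episode(
    episode,
    required_attrs=("state", "action", "next_state", "reward", "terminal"),
)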

machin/frame/algorithms/a2c.py

+12 −13

@@ -312,19 +312,18 @@ def store_episode(self, episode: List[Union[Transition, Dict]]):
                 + gae_delta
             )
 
-        for trans in episode:
-            self.replay_buffer.append(
-                trans,
-                required_attrs=(
-                    "state",
-                    "action",
-                    "next_state",
-                    "reward",
-                    "value",
-                    "gae",
-                    "terminal",
-                ),
-            )
+        self.replay_buffer.store_episode(
+            episode,
+            required_attrs=(
+                "state",
+                "action",
+                "next_state",
+                "reward",
+                "value",
+                "gae",
+                "terminal",
+            ),
+        )
 
     def update(
         self, update_value=True, update_policy=True, concatenate_samples=True, **__

machin/frame/algorithms/ddpg.py

+4 −5

@@ -370,11 +370,10 @@ def store_episode(self, episode: List[Union[Transition, Dict]]):
         """
         Add a full episode of transition samples to the replay buffer.
         """
-        for trans in episode:
-            self.replay_buffer.append(
-                trans,
-                required_attrs=("state", "action", "reward", "next_state", "terminal"),
-            )
+        self.replay_buffer.store_episode(
+            episode,
+            required_attrs=("state", "action", "reward", "next_state", "terminal"),
+        )
 
     def update(
         self,

machin/frame/algorithms/ddpg_per.py

+1 −1

@@ -82,7 +82,7 @@ def __init__(
             )
         else:
             # A loss defined in ``torch.nn.modules.loss``
-            if self.criterion.reduction != "none":
+            if getattr(self.criterion, "reduction") != "none":
                 default_logger.warning(
                     "The reduction property of criterion is not 'none', "
                     "automatically corrected."

machin/frame/algorithms/dqn.py

+4 −5

@@ -325,11 +325,10 @@ def store_episode(self, episode: List[Union[Transition, Dict]]):
         """
         Add a full episode of transition samples to the replay buffer.
         """
-        for trans in episode:
-            self.replay_buffer.append(
-                trans,
-                required_attrs=("state", "action", "reward", "next_state", "terminal"),
-            )
+        self.replay_buffer.store_episode(
+            episode,
+            required_attrs=("state", "action", "reward", "next_state", "terminal"),
+        )
 
     def update(
         self, update_value=True, update_target=True, concatenate_samples=True, **__

machin/frame/algorithms/dqn_per.py

+1 −1

@@ -79,7 +79,7 @@ def __init__(
             )
         else:
             # A loss defined in ``torch.nn.modules.loss``
-            if self.criterion.reduction != "none":
+            if getattr(self.criterion, "reduction") != "none":
                 default_logger.warning(
                     "The reduction property of criterion is not 'none', "
                     "automatically corrected."

machin/frame/algorithms/gail.py

+6 −4

@@ -237,10 +237,12 @@ def store_expert_episode(self, episode: List[Union[ExpertTransition, Dict]]):
 
         Only states and actions are required.
         """
-        for trans in episode:
-            if isinstance(trans, dict):
-                trans = ExpertTransition(**trans)
-            self.expert_replay_buffer.append(trans, required_attrs=("state", "action"))
+        episode = [
+            ExpertTransition(**trans) for trans in episode if isinstance(trans, dict)
+        ]
+        self.expert_replay_buffer.store_episode(
+            episode, required_attrs=("state", "action")
+        )
 
     def update(
         self,

machin/frame/algorithms/impala.py

+1 −10

@@ -40,7 +40,7 @@ def sample_batch(
         *_,
         **__,
     ) -> Any:
-        super().sample_batch(
+        return super().sample_batch(
            batch_size=batch_size,
            concatenate=concatenate,
            device=device,
@@ -310,15 +310,6 @@ def update(self, update_value=True, update_policy=True, **__):
                " an unknown error has occurred."
            )
 
-        for major_attr in (state, action, next_state):
-            for k, v in major_attr.items():
-                major_attr[k] = t.cat(v, dim=0)
-                assert major_attr[k].shape[0] == sum_length
-
-        terminal = t.cat(terminal, dim=0).view(sum_length, 1)
-        reward = t.cat(reward, dim=0).view(sum_length, 1)
-        action_log_prob = t.cat(action_log_prob, dim=0).view(sum_length, 1)
-
         # Below are the v-trace process
 
         # Calculate c and rho first, because there is no dependency

machin/frame/algorithms/maddpg.py

+17 −19

@@ -7,14 +7,19 @@
 from machin.utils.logging import default_logger
 from machin.model.nets.base import static_module_wrapper
 from machin.parallel.pool import P2PPool, ThreadPool
+from machin.frame.transition import Scalar
 
 # pylint: disable=wildcard-import, unused-wildcard-import
 from .ddpg import *
 
 
 class SHMBuffer(Buffer):
-    @staticmethod
-    def make_tensor_from_batch(batch, device, concatenate):
+    def make_tensor_from_batch(
+        self,
+        batch: List[Union[Scalar, t.Tensor]],
+        device: Union[str, t.device],
+        concatenate: bool,
+    ):
         # this function is used in post processing, and we will
         # move all cpu tensors to shared memory.
         if concatenate and len(batch) != 0:
@@ -307,11 +312,11 @@ def optimizers(self):
     def optimizers(self, optimizers):
         counter = 0
         for ac in self.actor_optims:
-            for id, _acc in enumerate(ac):
-                ac[id] = optimizers[counter]
+            for i in range(len(ac)):
+                ac[i] = optimizers[counter]
                 counter += 1
-        for id in range(len(self.critic_optims)):
-            self.critic_optims[id] = optimizers[counter]
+        for i in range(len(self.critic_optims)):
+            self.critic_optims[i] = optimizers[counter]
             counter += 1
 
     @property
@@ -506,18 +511,11 @@ def store_episodes(self, episodes: List[List[Union[Transition, Dict]]]):
         assert len(episodes) == len(self.replay_buffers)
         all_length = [len(ep) for ep in episodes]
         assert len(set(all_length)) == 1, "All episodes must have the same length!"
-        for buff, ep in zip(self.replay_buffers, episodes):
-            for trans in ep:
-                buff.append(
-                    trans,
-                    required_attrs=(
-                        "state",
-                        "action",
-                        "next_state",
-                        "reward",
-                        "terminal",
-                    ),
-                )
+        for buffer, episode in zip(self.replay_buffers, episodes):
+            buffer.store_episode(
+                episode,
+                required_attrs=("state", "action", "next_state", "reward", "terminal",),
+            )
 
     def update(
         self,
@@ -961,7 +959,7 @@ def _check_parameters_device(models):
     def _create_sample_method(indexes):
         def sample_method(buffer, _len):
             nonlocal indexes
-            batch = [buffer[i] for i in indexes if i < len(buffer)]
+            batch = [buffer.storage[i] for i in indexes if i < buffer.size()]
             return len(batch), batch
 
         return sample_method

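The last maddpg hunk also adapts the index-based sample method to the new buffer internals: transitions are read through buffer.storage and the current fill level through buffer.size(), instead of indexing and calling len() on the buffer itself. Below is a minimal sketch of a custom sample method written against those two members only; the "take the newest k transitions" policy is invented for illustration and follows the same (buffer, _len) -> (count, batch) convention used by _create_sample_method above.

def make_newest_sample_method(k):
    # Returns a sample method compatible with the convention used by
    # _create_sample_method in the hunk above.
    def sample_method(buffer, _len):
        size = buffer.size()
        start = max(0, size - k)
        # buffer.storage supports integer indexing, as in the hunk above.
        batch = [buffer.storage[i] for i in range(start, size)]
        return len(batch), batch

    return sample_method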
machin/frame/algorithms/rainbow.py

+11 −12

@@ -188,18 +188,17 @@ def store_episode(self, episode: List[Union[Transition, Dict]]):
                 value_sum = value_sum * self.discount + episode[i + j]["reward"]
             episode[i]["value"] = value_sum
 
-        for trans in episode:
-            self.replay_buffer.append(
-                trans,
-                required_attrs=(
-                    "state",
-                    "action",
-                    "next_state",
-                    "reward",
-                    "value",
-                    "terminal",
-                ),
-            )
+        self.replay_buffer.store_episode(
+            episode,
+            required_attrs=(
+                "state",
+                "action",
+                "next_state",
+                "reward",
+                "value",
+                "terminal",
+            ),
+        )
 
     def update(
         self, update_value=True, update_target=True, concatenate_samples=True, **__

machin/frame/algorithms/sac.py

+5 −6

@@ -271,11 +271,10 @@ def store_episode(self, episode: List[Union[Transition, Dict]]):
         """
         Add a full episode of transition samples to the replay buffer.
         """
-        for trans in episode:
-            self.replay_buffer.append(
-                trans,
-                required_attrs=("state", "action", "next_state", "reward", "terminal"),
-            )
+        self.replay_buffer.store_episode(
+            episode,
+            required_attrs=("state", "action", "next_state", "reward", "terminal"),
+        )
 
     def update(
         self,
@@ -395,7 +394,7 @@ def update(
         self.critic.eval()
         self.critic2.eval()
         # use .item() to prevent memory leakage
-        return (-act_policy_loss.item(), (value_loss.item() + value_loss2.item()) / 2)
+        return -act_policy_loss.item(), (value_loss.item() + value_loss2.item()) / 2
 
     def update_lr_scheduler(self):
         """

machin/frame/algorithms/td3.py

+1 −1

@@ -257,7 +257,7 @@ def update(
         self.critic.eval()
         self.critic2.eval()
         # use .item() to prevent memory leakage
-        return (-act_policy_loss.item(), (value_loss.item() + value_loss2.item()) / 2)
+        return -act_policy_loss.item(), (value_loss.item() + value_loss2.item()) / 2
 
     @staticmethod
     def policy_noise_function(actions, *_):

machin/frame/buffers/buffer.py

+3 −4

@@ -15,7 +15,6 @@ def __init__(
         buffer_size: int = 1000000,
         buffer_device: Union[str, t.device] = "cpu",
         storage: TransitionStorageBase = None,
-        *_,
         **__,
     ):
         """
@@ -79,15 +78,15 @@ def store_episode(
             elif isinstance(transition, TransitionBase):
                 pass
             else: # pragma: no cover
-                raise RuntimeError(
+                raise ValueError(
                     "Transition object must be a dict or an instance"
-                    " of the Transition class"
+                    " of the Transition class."
                 )
             if not transition.has_keys(required_attrs):
                 missing_keys = set(required_attrs) - set(transition.keys())
                 raise ValueError(
                     f"Transition object missing attributes: {missing_keys}, "
-                    f"object is {transition}"
+                    f"object is {transition}."
                 )
             episode[idx] = transition

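As the second buffer.py hunk shows, store_episode now raises ValueError both for unsupported transition types and for missing required attributes. A small sketch of the missing-attribute case follows; it assumes, as the surrounding code suggests, that dict transitions are converted to Transition objects before the has_keys check, and the tensor shapes and dict layout are placeholders.

import torch as t
from machin.frame.buffers.buffer import Buffer

buffer = Buffer(buffer_size=10)
# A transition with the base attributes but without the extra "value"
# attribute that algorithms such as Rainbow list in required_attrs.
episode = [
    {
        "state": {"state": t.zeros(1, 4)},
        "action": {"action": t.zeros(1, 2)},
        "next_state": {"state": t.zeros(1, 4)},
        "reward": 0.0,
        "terminal": True,
    }
]
try:
    buffer.store_episode(
        episode,
        required_attrs=("state", "action", "next_state", "reward", "value", "terminal"),
    )
except ValueError as e:
    # e.g. "Transition object missing attributes: {'value'}, object is ..."
    print(e)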
machin/frame/buffers/buffer_d.py

+4 −3

@@ -21,8 +21,7 @@ def __init__(
         group: RpcGroup,
         buffer_size: int = 1000000,
         storage: TransitionStorageBase = None,
-        *_,
-        **__,
+        **kwargs,
     ):
         """
         Create a distributed replay buffer instance.
@@ -58,7 +57,9 @@ def __init__(
             storage: Custom storage, not compatible with `buffer_size` and
                 `buffer_device`.
         """
-        super().__init__(buffer_size=buffer_size, buffer_device="cpu", storage=storage)
+        super().__init__(
+            buffer_size=buffer_size, buffer_device="cpu", storage=storage, **kwargs
+        )
         self.buffer_name = buffer_name
         self.group = group
