|
7 | 7 | from machin.utils.logging import default_logger
|
8 | 8 | from machin.model.nets.base import static_module_wrapper
|
9 | 9 | from machin.parallel.pool import P2PPool, ThreadPool
|
| 10 | +from machin.frame.transition import Scalar |
10 | 11 |
|
11 | 12 | # pylint: disable=wildcard-import, unused-wildcard-import
|
12 | 13 | from .ddpg import *
|
13 | 14 |
|
14 | 15 |
|
15 | 16 | class SHMBuffer(Buffer):
|
16 |
| - @staticmethod |
17 |
| - def make_tensor_from_batch(batch, device, concatenate): |
| 17 | + def make_tensor_from_batch( |
| 18 | + self, |
| 19 | + batch: List[Union[Scalar, t.Tensor]], |
| 20 | + device: Union[str, t.device], |
| 21 | + concatenate: bool, |
| 22 | + ): |
18 | 23 | # this function is used in post processing, and we will
|
19 | 24 | # move all cpu tensors to shared memory.
|
20 | 25 | if concatenate and len(batch) != 0:
|
@@ -307,11 +312,11 @@ def optimizers(self):
|
307 | 312 | def optimizers(self, optimizers):
|
308 | 313 | counter = 0
|
309 | 314 | for ac in self.actor_optims:
|
310 |
| - for id, _acc in enumerate(ac): |
311 |
| - ac[id] = optimizers[counter] |
| 315 | + for i in range(len(ac)): |
| 316 | + ac[i] = optimizers[counter] |
312 | 317 | counter += 1
|
313 |
| - for id in range(len(self.critic_optims)): |
314 |
| - self.critic_optims[id] = optimizers[counter] |
| 318 | + for i in range(len(self.critic_optims)): |
| 319 | + self.critic_optims[i] = optimizers[counter] |
315 | 320 | counter += 1
|
316 | 321 |
|
317 | 322 | @property
|
@@ -506,18 +511,11 @@ def store_episodes(self, episodes: List[List[Union[Transition, Dict]]]):
|
506 | 511 | assert len(episodes) == len(self.replay_buffers)
|
507 | 512 | all_length = [len(ep) for ep in episodes]
|
508 | 513 | assert len(set(all_length)) == 1, "All episodes must have the same length!"
|
509 |
| - for buff, ep in zip(self.replay_buffers, episodes): |
510 |
| - for trans in ep: |
511 |
| - buff.append( |
512 |
| - trans, |
513 |
| - required_attrs=( |
514 |
| - "state", |
515 |
| - "action", |
516 |
| - "next_state", |
517 |
| - "reward", |
518 |
| - "terminal", |
519 |
| - ), |
520 |
| - ) |
| 514 | + for buffer, episode in zip(self.replay_buffers, episodes): |
| 515 | + buffer.store_episode( |
| 516 | + episode, |
| 517 | + required_attrs=("state", "action", "next_state", "reward", "terminal",), |
| 518 | + ) |
521 | 519 |
|
522 | 520 | def update(
|
523 | 521 | self,
|
@@ -961,7 +959,7 @@ def _check_parameters_device(models):
|
961 | 959 | def _create_sample_method(indexes):
|
962 | 960 | def sample_method(buffer, _len):
|
963 | 961 | nonlocal indexes
|
964 |
| - batch = [buffer[i] for i in indexes if i < len(buffer)] |
| 962 | + batch = [buffer.storage[i] for i in indexes if i < buffer.size()] |
965 | 963 | return len(batch), batch
|
966 | 964 |
|
967 | 965 | return sample_method
|
|
0 commit comments