Using NGU Explore with Stable-Baselines3's MultiInputPolicy on Gymnasium-Robotics #25

Open
haneenhassen opened this issue Nov 19, 2024 · 1 comment

haneenhassen commented Nov 19, 2024

Hi,

I'm working on a project that involves using a Stable-Baselines3 agent with a multi-input policy in a Gymnasium-Robotics environment. I'm interested in incorporating the NGU exploration method.

Because the environment returns several observation components, its observation space is a `spaces.Dict` with multiple keys.

Question:
Is it possible to use NGU Explore with Stable-Baselines3's MultiInputPolicy when the observation space is a spaces.Dict?
If so, what steps are involved in integrating NGU Explore with this setup, in particular for observation and reward normalization?

Error Encountered:
```python
Traceback (most recent call last):
  File "/home/bmahdy/Master/M_NGU/RLeXplore-main/rlexplore_with_sb3_FetchPickAndPlace.py", line 115, in <module>
    irs = NGU(envs=envs, device=device, batch_size=batch)
  File "/home/bmahdy/anaconda3/envs/NGU_9/lib/python3.9/site-packages/rllte/xplore/reward/ngu.py", line 87, in __init__
    rnd = RND(
  File "/home/bmahdy/anaconda3/envs/NGU_9/lib/python3.9/site-packages/rllte/xplore/reward/rnd.py", line 78, in __init__
    super().__init__(envs, device, beta, kappa, gamma, rwd_norm_type, obs_norm_type)
  File "/home/bmahdy/anaconda3/envs/NGU_9/lib/python3.9/site-packages/rllte/common/prototype/base_reward.py", line 88, in __init__
    TorchRunningMeanStd(shape=self.obs_shape)
  File "/home/bmahdy/anaconda3/envs/NGU_9/lib/python3.9/site-packages/rllte/common/utils.py", line 52, in __init__
    self.mean = th.zeros(shape, device=device)
TypeError: zeros() received an invalid combination of arguments - got (dict, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
```

Any insights, guidance, or code examples would be greatly appreciated.

Thanks in advance!

yuanmingqi (Collaborator) commented

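This happens because `TorchRunningMeanStd` only supports plain tensor observations. It is constructed with `shape=self.obs_shape`, and with a Dict observation space that shape is a dict rather than a tuple of ints, so `th.zeros` raises the `TypeError` above. You may need to rewrite the `TorchRunningMeanStd` class to handle your Dict observations.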
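
As a rough illustration of what such a rewrite could look like, here is a minimal sketch that keeps one set of running statistics per key of the Dict observation. It is not the rllte implementation: the class names `RunningMeanStd` and `DictRunningMeanStd`, the `normalize` method, and the example shapes are all assumptions made for this example, and it uses NumPy instead of torch to stay short (the same logic should carry over by swapping in tensor operations and a `device` argument).

```python
import numpy as np


class RunningMeanStd:
    """Running mean/variance for a single array-valued observation key (hypothetical helper)."""

    def __init__(self, shape, epsilon=1e-4):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon  # small prior count avoids division by zero

    def update(self, batch):
        # batch has shape (n_samples, *shape)
        batch = np.asarray(batch, dtype=np.float64)
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]

        # Parallel-variance update: merge the batch moments into the running moments.
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = (self.var * self.count + batch_var * batch_count
              + delta ** 2 * self.count * batch_count / total)
        self.mean, self.var, self.count = new_mean, m2 / total, total

    def normalize(self, batch, eps=1e-8):
        return (np.asarray(batch, dtype=np.float64) - self.mean) / np.sqrt(self.var + eps)


class DictRunningMeanStd:
    """One RunningMeanStd per key of a Dict observation (hypothetical name)."""

    def __init__(self, shapes):
        # shapes maps key -> tuple of ints; with Gymnasium this could be built as
        # {k: s.shape for k, s in env.observation_space.spaces.items()}
        self.stats = {key: RunningMeanStd(shape) for key, shape in shapes.items()}

    def update(self, obs):
        for key, rms in self.stats.items():
            rms.update(obs[key])

    def normalize(self, obs):
        return {key: rms.normalize(obs[key]) for key, rms in self.stats.items()}


# Illustrative shapes only; check them against your own environment.
shapes = {"observation": (25,), "achieved_goal": (3,), "desired_goal": (3,)}
obs_rms = DictRunningMeanStd(shapes)

rng = np.random.default_rng(0)
batch = {key: rng.normal(5.0, 2.0, size=(1000,) + shape) for key, shape in shapes.items()}
obs_rms.update(batch)
normalized = obs_rms.normalize(batch)
```

Note that this only addresses the normalization step that fails in the traceback. Other parts of the reward module may also assume a single tensor observation, in which case they would need a similar per-key treatment, or the Dict observation could be flattened into one vector before it is passed to the reward module.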
