import torch as th
from rllte.xplore.reward import RND
from rllte.env import make_mario_env
from rllte.agent import PPO, DDPG
def main(n_steps: int = 2048 * 16, n_units: int = 153) -> None:
    """Train a PPO agent with an RND intrinsic reward on SuperMarioBros-1-1.

    Args:
        n_steps: Step count used as the unit of the training schedule. It is
            passed to ``agent.train`` as ``save_interval``, and
            ``n_steps // 8`` is passed as ``eval_interval``.
        n_units: Number of ``n_steps`` units to train for. The total step
            budget passed to ``agent.train`` is ``n_steps * n_units``.
    """
    device = 'cuda' if th.cuda.is_available() else 'cpu'

    # Single synchronous Mario environment with 4 stacked grayscale frames.
    envs = make_mario_env(
        'SuperMarioBros-1-1-v0',
        device=device,
        num_envs=1,
        asynchronous=False,
        frame_stack=4,
        gray_scale=True,
    )
    print(device, envs.observation_space, envs.action_space)

    # Create the intrinsic reward module.
    irs = RND(envs, device=device)
    # Create the PPO agent.
    agent = PPO(envs, device=device)
    # Attach the intrinsic reward module to the agent.
    agent.set(reward=irs)

    # NOTE(review): in the recorded run, the first env reset inside
    # agent.train failed in gymnasium's SyncVectorEnv.reset_wait with
    # "too many values to unpack (expected 2)". The wrapped Mario env's
    # reset() therefore did not return an (obs, info) pair. This looks like
    # an old-gym (pre-0.26 reset API) vs gymnasium mismatch among the
    # installed gym / nes_py / gym-super-mario-bros versions. TODO: confirm
    # and pin compatible versions; it cannot be fixed from this script alone.
    agent.train(
        n_steps * n_units,
        eval_interval=n_steps // 8,
        save_interval=n_steps,
    )


if __name__ == '__main__':
    main()
/opt/conda/lib/python3.10/site-packages/gym/envs/registration.py:555: UserWarning: WARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.
logger.warn(
/opt/conda/lib/python3.10/site-packages/gym/envs/registration.py:627: UserWarning: WARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']
logger.warn(
/opt/conda/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.metadata to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.metadata` for environment variables or `env.get_wrapper_attr('metadata')` that will search the reminding wrappers.
logger.warn(
/opt/conda/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.
logger.warn(
/opt/conda/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.
logger.warn(
cuda Box(0, 255, (4, 84, 84), uint8) Discrete(7)
/opt/conda/lib/python3.10/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)
if not isinstance(terminated, (bool, np.bool8)):
[05/24/2024 04:14:52 PM] - [INFO.] - Invoking RLLTE Engine...
[05/24/2024 04:14:52 PM] - [INFO.] - ================================================================================
[05/24/2024 04:14:52 PM] - [INFO.] - Tag : default
[05/24/2024 04:14:52 PM] - [INFO.] - Device : NVIDIA A100-SXM4-40GB
[05/24/2024 04:14:52 PM] - [DEBUG] - Agent : PPO
[05/24/2024 04:14:52 PM] - [DEBUG] - Encoder : MnihCnnEncoder
[05/24/2024 04:14:52 PM] - [DEBUG] - Policy : OnPolicySharedActorCritic
[05/24/2024 04:14:52 PM] - [DEBUG] - Storage : VanillaRolloutStorage
[05/24/2024 04:14:52 PM] - [DEBUG] - Distribution : Categorical
[05/24/2024 04:14:52 PM] - [DEBUG] - Augmentation : None
[05/24/2024 04:14:52 PM] - [DEBUG] - Intrinsic Reward : RND
[05/24/2024 04:14:52 PM] - [DEBUG] - ================================================================================
Traceback (most recent call last):
File "/workdir/got-it-memorized/src/run_rnd2.py", line 20, in <module>
agent.train(n_steps * 153, eval_interval=n_steps // 8, save_interval=n_steps)
File "/opt/conda/lib/python3.10/site-packages/rllte/common/prototype/on_policy_agent.py", line 105, in train
obs, infos = self.env.reset(seed=self.seed)
File "/opt/conda/lib/python3.10/site-packages/rllte/env/utils.py", line 152, in reset
obs, infos = self.env.reset(seed=seed, options=options)
File "/opt/conda/lib/python3.10/site-packages/gymnasium/wrappers/record_episode_statistics.py", line 78, in reset
obs, info = super().reset(**kwargs)
File "/opt/conda/lib/python3.10/site-packages/gymnasium/core.py", line 467, in reset
return self.env.reset(seed=seed, options=options)
File "/opt/conda/lib/python3.10/site-packages/gymnasium/vector/vector_env.py", line 140, in reset
return self.reset_wait(seed=seed, options=options)
File "/opt/conda/lib/python3.10/site-packages/gymnasium/vector/sync_vector_env.py", line 122, in reset_wait
observation, info = env.reset(**kwargs)
ValueError: too many values to unpack (expected 2)