Comments (20)
No worries, I think the script below would work, using some Atari wrappers. If you are interested in running more retro experiments, feel free to join our discord channel at https://discord.gg/D6RCjA6sVT. I would be happy to feature your retro experiments in our Open RL Benchmark
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
import argparse
from distutils.util import strtobool
import numpy as np
import gym
from gym.wrappers import TimeLimit, Monitor
import retro
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
from stable_baselines3.common.atari_wrappers import (
NoopResetEnv, MaxAndSkipEnv, EpisodicLifeEnv, FireResetEnv, WarpFrame, ClipRewardEnv)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env import VecFrameStack
def _str_to_bool(value):
    """Parse a command-line truth value into a ``bool``.

    Accepts, case-insensitively, the spellings y/yes/t/true/on/1 and
    n/no/f/false/off/0 (the same set ``distutils.util.strtobool`` accepts).
    It is implemented here so that option parsing does not rely on
    ``distutils``, which was removed from the standard library in
    Python 3.12.

    Raises:
        ValueError: if *value* is not one of the recognised spellings;
            argparse reports this as a usage error.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {lowered!r}")


def parse_args(argv=None):
    """Build the PPO command-line interface and parse *argv*.

    Args:
        argv: list of argument strings. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with two derived fields added:
        ``batch_size`` (``num_envs * num_steps``) and ``minibatch_size``
        (``batch_size // n_minibatch``).
    """
    # Boolean options may be given bare (``--cuda``) or with an explicit
    # value (``--cuda False``).
    bool_flag = dict(type=_str_to_bool, nargs='?', const=True)

    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    # splitext removes only the extension; ``rstrip(".py")`` strips any
    # trailing '.', 'p' or 'y' characters and would turn "happy.py" into "ha".
    parser.add_argument('--exp-name', type=str,
                        default=os.path.splitext(os.path.basename(__file__))[0],
                        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="Airstriker-Genesis",
                        help='the id of the gym environment')
    parser.add_argument('--learning-rate', type=float, default=2.5e-4,
                        help='the learning rate of the optimizer')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed of the experiment')
    parser.add_argument('--total-timesteps', type=int, default=10000,
                        help='total timesteps of the experiments')
    parser.add_argument('--torch-deterministic', default=True,
                        help='if toggled, `torch.backends.cudnn.deterministic=False`',
                        **bool_flag)
    parser.add_argument('--cuda', default=True,
                        help='if toggled, cuda will not be enabled by default',
                        **bool_flag)
    parser.add_argument('--prod-mode', default=False,
                        help='run the script in production mode and use wandb to log outputs',
                        **bool_flag)
    parser.add_argument('--capture-video', default=False,
                        help='weather to capture videos of the agent performances (check out `videos` folder)',
                        **bool_flag)
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                        help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--n-minibatch', type=int, default=4,
                        help='the number of mini batch')
    parser.add_argument('--num-envs', type=int, default=8,
                        help='the number of parallel game environment')
    parser.add_argument('--num-steps', type=int, default=128,
                        help='the number of steps per game environment')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
                        help='the lambda for the general advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                        help="coefficient of the entropy")
    parser.add_argument('--vf-coef', type=float, default=0.5,
                        help="coefficient of the value function")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                        help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.1,
                        help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=4,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl',
                        **bool_flag)
    parser.add_argument('--kle-rollback', default=False,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl',
                        **bool_flag)
    parser.add_argument('--target-kl', type=float, default=0.03,
                        help='the target-kl variable that is referred by --kl')
    parser.add_argument('--gae', default=True,
                        help='Use GAE for advantage computation',
                        **bool_flag)
    parser.add_argument('--norm-adv', default=True,
                        help="Toggles advantages normalization",
                        **bool_flag)
    parser.add_argument('--anneal-lr', default=True,
                        help="Toggle learning rate annealing for policy and value networks",
                        **bool_flag)
    parser.add_argument('--clip-vloss', default=True,
                        help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.',
                        **bool_flag)

    args = parser.parse_args(argv)
    # A falsy seed (0) means "pick one from the clock".
    if not args.seed:
        args.seed = int(time.time())
    # Derived sizes used by the rollout storage and the optimisation loop.
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.n_minibatch)
    return args


if __name__ == "__main__":
    args = parse_args()
class VecPyTorch(VecEnvWrapper):
    """Vectorized-env wrapper that exchanges torch tensors with the caller.

    Observations and rewards returned by the wrapped ``venv`` are converted
    from NumPy arrays to float tensors; actions passed in are converted from
    tensors to NumPy arrays before being forwarded.
    """

    def __init__(self, venv, device):
        """Wrap *venv*; observations are moved to *device*."""
        super().__init__(venv)
        self.device = device

    def reset(self):
        """Reset the wrapped envs and return the observations as a float tensor on ``self.device``."""
        obs = self.venv.reset()
        return torch.from_numpy(obs).float().to(self.device)

    def step_async(self, actions):
        """Forward *actions* (a tensor) to the wrapped envs as a NumPy array."""
        self.venv.step_async(actions.cpu().numpy())

    def step_wait(self):
        """Collect the step results from the wrapped envs.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is a float tensor on
            ``self.device``, ``reward`` is a float tensor with an extra
            trailing dimension of size 1, and ``done`` / ``info`` are passed
            through unchanged.
        """
        obs, reward, done, info = self.venv.step_wait()
        obs = torch.from_numpy(obs).float().to(self.device)
        # The reward tensor is not moved to self.device here.
        reward = torch.from_numpy(reward).unsqueeze(dim=1).float()
        return obs, reward, done, info
# TRY NOT TO MODIFY: setup the environment
# The run name combines env id, experiment name, seed and start time.
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
# Record every parsed option as a markdown table in TensorBoard.
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
    '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))
if args.prod_mode:
    # wandb is imported only when production logging is requested.
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, sync_tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True, save_code=True)
    # From here on, scalars are written under /tmp instead of runs/.
    writer = SummaryWriter(f"/tmp/{experiment_name}")

# TRY NOT TO MODIFY: seeding
# Fall back to the CPU when CUDA is unavailable or disabled via --cuda.
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
def make_env(gym_id, seed, idx):
    """Return a zero-argument factory that builds one wrapped retro env.

    Args:
        gym_id: game id passed to ``retro.make``.
        seed: seed applied to the env and to its action/observation spaces.
        idx: index of this env among the parallel envs; only index 0 is
            wrapped with ``Monitor`` when ``--capture-video`` is set.

    Returns:
        A callable that creates and returns the wrapped environment when
        invoked.
    """
    def thunk():
        env = retro.make(gym_id, use_restricted_actions=retro.Actions.DISCRETE)
        env = MaxAndSkipEnv(env, skip=4)
        # Episode statistics are recorded before the reward-clipping wrapper
        # is applied further down.
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if args.capture_video and idx == 0:
            env = Monitor(env, f'videos/{experiment_name}')
        env = WarpFrame(env, width=84, height=84)
        env = ClipRewardEnv(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env
    return thunk
# Build the vectorized environment: one subprocess per env (forked), the
# last 4 frames stacked, and tensors exchanged with the agent.
envs = VecPyTorch(
    VecFrameStack(
        SubprocVecEnv(
            [make_env(args.gym_id, args.seed+i, i) for i in range(args.num_envs)],
            start_method="fork",
        ),
        4,
    ),
    device,
)
# The agent below builds a categorical policy, so it needs a Discrete space.
assert isinstance(envs.action_space, Discrete), "only discrete action space is supported"
# ALGO LOGIC: initialize agent here:
class Scale(nn.Module):
    """Parameter-free module that multiplies its input by a fixed factor."""

    def __init__(self, scale):
        """Store the multiplicative factor *scale*."""
        super().__init__()
        self.scale = scale

    def forward(self, x):
        """Return ``x * scale``."""
        return x * self.scale
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise *layer* in place and return it.

    Args:
        layer: a module with ``weight`` and ``bias`` tensors.
        std: gain passed to the orthogonal weight initialiser.
        bias_const: constant every bias entry is set to.

    Returns:
        The same *layer* object, so the call can be used inline when
        building a network.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer
class Agent(nn.Module):
    """Actor-critic network with a shared convolutional trunk.

    The trunk scales the input by 1/255, applies three conv layers and one
    fully connected layer, and feeds a 512-wide feature vector to two linear
    heads: ``actor`` (one logit per discrete action) and ``critic`` (a single
    value).
    """

    def __init__(self, envs, frames=4):
        """Build the network.

        Args:
            envs: object whose ``action_space.n`` gives the number of
                discrete actions.
            frames: number of input channels of the first conv layer.
        """
        super().__init__()
        self.network = nn.Sequential(
            Scale(1/255),
            layer_init(nn.Conv2d(frames, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            # 3136 = 64 * 7 * 7, the flattened size this conv stack produces
            # for 84x84 inputs.
            layer_init(nn.Linear(3136, 512)),
            nn.ReLU()
        )
        self.actor = layer_init(nn.Linear(512, envs.action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)

    def forward(self, x):
        """Return the trunk features for a channels-last batch *x*."""
        return self.network(x.permute((0, 3, 1, 2)))  # "bhwc" -> "bchw"

    def get_action(self, x, action=None):
        """Sample an action from the policy, or evaluate a given one.

        Args:
            x: batch of observations.
            action: optional actions to evaluate; when ``None`` a new action
                is sampled from the categorical policy.

        Returns:
            ``(action, log_prob, entropy)`` for the policy at *x*.
        """
        logits = self.actor(self.forward(x))
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy()

    def get_value(self, x):
        """Return the critic's value estimate for observations *x*."""
        return self.critic(self.forward(x))
agent = Agent(envs).to(device)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
if args.anneal_lr:
    # https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/defaults.py#L20
    def lr(f):
        """Return the learning rate for remaining-progress fraction *f*."""
        return f * args.learning_rate

# ALGO Logic: Storage for epoch data
# Every buffer is indexed as [step, env]; all are allocated on `device`.
obs = torch.zeros((args.num_steps, args.num_envs) + envs.observation_space.shape, device=device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.action_space.shape, device=device)
logprobs = torch.zeros((args.num_steps, args.num_envs), device=device)
rewards = torch.zeros((args.num_steps, args.num_envs), device=device)
dones = torch.zeros((args.num_steps, args.num_envs), device=device)
values = torch.zeros((args.num_steps, args.num_envs), device=device)

# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
# Note how `next_obs` and `next_done` are used; their usage is equivalent to
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60
next_obs = envs.reset()
next_done = torch.zeros(args.num_envs, device=device)
# Each update consumes one full batch of num_envs * num_steps transitions.
num_updates = args.total_timesteps // args.batch_size
for update in range(1, num_updates+1):
    # Annealing the rate if instructed to do so.
    if args.anneal_lr:
        # Linear decay from the initial rate towards zero over all updates.
        frac = 1.0 - (update - 1.0) / num_updates
        lrnow = lr(frac)
        optimizer.param_groups[0]['lr'] = lrnow

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(0, args.num_steps):
        global_step += 1 * args.num_envs
        obs[step] = next_obs
        dones[step] = next_done

        # ALGO LOGIC: put action logic here
        with torch.no_grad():
            values[step] = agent.get_value(obs[step]).flatten()
            action, logproba, _ = agent.get_action(obs[step])
            # NOTE: the per-step `envs.env_method("set_probs", ...)`
            # visualization call was removed. None of the wrappers applied in
            # make_env defines `set_probs`, so the call could only fail when
            # --capture-video was enabled.

        actions[step] = action
        logprobs[step] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rs, ds, infos = envs.step(action)
        rewards[step], next_done = rs.view(-1), torch.Tensor(ds).to(device)

        # Log at most one finished episode per step.
        for info in infos:
            if 'episode' in info:
                print(f"global_step={global_step}, episode_reward={info['episode']['r']}")
                writer.add_scalar("charts/episode_reward", info['episode']['r'], global_step)
                break

    # bootstrap reward if not done. reached the batch limit
    with torch.no_grad():
        last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)
        if args.gae:
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            # Walk the rollout backwards, accumulating the GAE recursion.
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = last_value
                else:
                    nextnonterminal = 1.0 - dones[t+1]
                    nextvalues = values[t+1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values
        else:
            returns = torch.zeros_like(rewards).to(device)
            # Plain discounted returns, bootstrapped from the last value.
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    next_return = last_value
                else:
                    nextnonterminal = 1.0 - dones[t+1]
                    next_return = returns[t+1]
                returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
            advantages = returns - values

    # flatten the batch
    b_obs = obs.reshape((-1,)+envs.observation_space.shape)
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape((-1,)+envs.action_space.shape)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)

    # Optimizing the policy and value network
    # target_agent holds the weights from the start of each epoch so that
    # --kle-rollback can restore them.
    target_agent = Agent(envs).to(device)
    inds = np.arange(args.batch_size,)
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        target_agent.load_state_dict(agent.state_dict())
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            mb_advantages = b_advantages[minibatch_ind]
            if args.norm_adv:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

            _, newlogproba, entropy = agent.get_action(b_obs[minibatch_ind], b_actions.long()[minibatch_ind])
            ratio = (newlogproba - b_logprobs[minibatch_ind]).exp()

            # Stats
            approx_kl = (b_logprobs[minibatch_ind] - newlogproba).mean()

            # Policy loss
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1-args.clip_coef, 1+args.clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()
            entropy_loss = entropy.mean()

            # Value loss
            new_values = agent.get_value(b_obs[minibatch_ind]).view(-1)
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - b_returns[minibatch_ind]) ** 2)
                v_clipped = b_values[minibatch_ind] + torch.clamp(new_values - b_values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - b_returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = 0.5 * ((new_values - b_returns[minibatch_ind]) ** 2).mean()

            loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
            optimizer.step()

        # The KL checks run once per epoch and use the last minibatch.
        if args.kle_stop:
            if approx_kl > args.target_kl:
                break
        if args.kle_rollback:
            if (b_logprobs[minibatch_ind] - agent.get_action(b_obs[minibatch_ind], b_actions.long()[minibatch_ind])[1]).mean() > args.target_kl:
                # Restore the weights saved at the start of this epoch.
                agent.load_state_dict(target_agent.state_dict())
                break

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
    writer.add_scalar("losses/entropy", entropy.mean().item(), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    print("SPS:", int(global_step / (time.time() - start_time)))
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)

envs.close()
writer.close()
from cleanrl.
Hey, would you mind posting more specifics here? It's kind of hard to help you with just these descriptions. You should specify, for example, what issue or error you were having and what your env initialization code looks like.
from cleanrl.
Yes, I understand. I really don't have a specific example; I just wanted to know some broad adjustments needed to be made to account for Retro Environment; Since they have issues with running multiple environments once. I apologize if this was a vague post.
from cleanrl.
Above is the progress I've made on Mortal-Kombat3-genesis environment. And the top-performing agent is PPG (via your implementation), I am just trying to make more adjustments and compare them IMPALA_Cnn.
from cleanrl.
Awesome! The PPG implementation is a little unstable. I wasn't able to reproduce the authors' results so I'd be careful with that.
from cleanrl.
Yeah, it seems like it's not getting past 400 rewards no matter what I try, unfortunately.
from cleanrl.
How long did you train it for?
from cleanrl.
Well, the longest one was trained for 12 hours straight, and on average it stayed around 400. (PPG). Currently, i'm trying PPO Impala(8 workers), and it's getting the same results(trained on 10 million global_steps so far)
from cleanrl.
mk3 seems to be a hard environment I suppose
from cleanrl.
Could you share the run?
from cleanrl.
code or graph?
from cleanrl.
from cleanrl.
Graph. Or the url of your experiment
from cleanrl.
blue is ppo-impala
from cleanrl.
Looks like some learning happened in the beginning, so this might just be environment-specific.
from cleanrl.
Do you think bringing down the learning rate would make any difference? Or changing any other hyper-parameter would make any changes? Or is this a lost cause?
from cleanrl.
Short answer is I am not sure, but I think they should help
from cleanrl.
Well, I'll keep trying and if anything happens I'll post it here. I'll let PPO-impala go on until it reaches around 40~ ish million steps and sees if it improves. If anyone has any suggestions, please feel free to reply.
from cleanrl.
from cleanrl.
Closing now.
from cleanrl.
Related Issues (20)
- Gymnasium Version Requirement May Need To Be Updated HOT 1
- Gymnasium and PyBullet envs: Problem with video capture
- Docker image is out of date
- KeyError: 'cleanba_ppo_envpool_procgen'
- Unable to allocate 26.3 GiB for an array HOT 1
- Potential bugs in RecordEpisodeStatistics
- Any usage of poetry after installation: No module named 'tomli' HOT 2
- Correct handling of `termination` vs `truncation`? HOT 1
- Truncation not handled correctly when `optimize_memory_usage=True`
- Improve compatibility with CUDA-enabled pytorch on non-CUDA devices HOT 3
- Implement CrossQ HOT 1
- Adding Munchausen Reinforcement Learning
- Why doesn't SAC and DDPG support multi-environment? HOT 1
- mujoco recorded video is blank screen HOT 2
- --capture_video crashes because fps is not set
- Setup issue HOT 1
- PPO Envpool doesn't account for episode change HOT 2
- [Bug] Different network parameter on different GPU in ppo_atari_multigpu.py
- Possible inconsistencies with the PPO implementation HOT 4
- gymnasium.error.NameNotFound: Environment `BreakoutNoFrameskip` doesn't exist. HOT 2
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from cleanrl.