ntt123 / a0-jax
AlphaZero in JAX
Home Page: https://go.ntt123.repl.co
License: MIT License
I'm on a budget, so I'm running train_agent.py for Caro on Colab with a TPU.
However, the process always gets killed during iteration 1, at around 64% of self-play, without a useful stack trace.
Any experience or theories on why this might happen?
!TF_CPP_MIN_LOG_LEVEL=0
!time python3 train_agent.py \
--game-class="caro_game.CaroGame" \
--agent-class="resnet_policy.ResnetPolicyValueNet128" \
--selfplay-batch-size=1024 \
--training-batch-size=1024 \
--num-simulations-per-move=32 \
--num-self-plays-per-iteration=102400 \
--learning-rate=1e-2 \
--random-seed=42 \
--ckpt-filename="./caro_agent_9x9_128.ckpt" \
--num-iterations=100 \
--lr-decay-steps=500000
2022-11-25 08:59:37.077139: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Cores: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Loading weights at ./caro_agent_9x9_128.ckpt
Iteration 1
self play [######################--------------] 63% 00:09:41
/bin/bash: line 1: 2377 Killed python3 train_agent.py --game-class="caro_game.CaroGame" --agent-class="resnet_policy.ResnetPolicyValueNet128" --selfplay-batch-size=1024 --training-batch-size=1024 --num-simulations-per-move=32 --num-self-plays-per-iteration=102400 --learning-rate=1e-2 --random-seed=42 --ckpt-filename="./caro_agent_9x9_128.ckpt" --num-iterations=100 --lr-decay-steps=500000
real 17m19.797s
user 10m5.645s
sys 5m3.467s
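A bare "Killed" from bash with no Python traceback is often the Linux out-of-memory killer ending the process, so host RAM is one thing worth checking. The sketch below is only a back-of-envelope estimate of how much memory one iteration of self-play data might need. Everything in it is an assumption, not the repo's actual storage layout: the helper name estimate_selfplay_bytes is hypothetical, each stored position is assumed to hold a float32 board, a float32 policy target and one float32 value, and the average game length of 60 moves is a guess.

```python
def estimate_selfplay_bytes(num_games: int, avg_moves: int, board_cells: int) -> int:
    """Rough host-RAM estimate for one iteration of self-play data.

    Assumed layout per position (hypothetical, not taken from the repo):
    a float32 board, a float32 policy target, and one float32 value.
    """
    bytes_per_float = 4
    board_bytes = board_cells * bytes_per_float
    policy_bytes = board_cells * bytes_per_float
    value_bytes = bytes_per_float
    bytes_per_position = board_bytes + policy_bytes + value_bytes
    return num_games * avg_moves * bytes_per_position


# 102400 games comes from --num-self-plays-per-iteration in the command above;
# 60 moves per game and 81 cells (a 9x9 board) are assumptions.
total_bytes = estimate_selfplay_bytes(num_games=102400, avg_moves=60, board_cells=81)
print(f"{total_bytes / 2**30:.2f} GiB")
```

If the estimate is in the same range as the RAM available on the Colab instance, lowering --num-self-plays-per-iteration is a cheap experiment to see whether the kill point moves.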
I've run into a concretization issue while trying to implement a new environment that calls an external application for its game logic. I need to call the external application in step to get a new state, but at that point action is a batched tracer, so I can't extract its concrete value, because the batched input doesn't support that.
class CheckersGame(Environment):
    ...

    def _step(self, action: chex.Array) -> Tuple["CheckersGame", chex.Array]:
        action = self._prepare_action(action)  # get a concrete value of action
        new_state, reward = call_external_env(action)
        return self, jnp.array(reward, dtype=jnp.int32)

    @pax.pure
    def step(self, action: chex.Array) -> Tuple["CheckersGame", chex.Array]:
        # batched action comes in, but concrete value is required
        env, reward = jax.vmap(lambda a: self._step(a))(action.reshape(-1, 1))
        return self, reward

    ...
I can tap into action with id_print or id_tap here, but I can't block _step that way.
What's the correct way to do this?
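One possible direction, sketched here under assumptions rather than as a confirmed answer, is to hand the whole batch of actions to the host with jax.pure_callback instead of vmapping over a function that needs concrete values. In the sketch, call_external_env is a hypothetical stand-in for the external game engine, and _host_step and batched_step are illustrative names that are not part of a0-jax.

```python
import jax
import jax.numpy as jnp
import numpy as np


def call_external_env(action: int) -> int:
    """Hypothetical stand-in for the external game engine; returns a reward."""
    return 1 if action % 2 == 0 else -1


def _host_step(actions):
    # Runs on the host with concrete values, so plain Python calls are fine.
    actions = np.asarray(actions)
    rewards = [call_external_env(int(a)) for a in actions.reshape(-1)]
    return np.asarray(rewards, dtype=np.int32).reshape(actions.shape)


@jax.jit
def batched_step(actions):
    # Declare the result shape and dtype up front; JAX cannot infer them
    # from an opaque host function.
    out_spec = jax.ShapeDtypeStruct(actions.shape, jnp.int32)
    return jax.pure_callback(_host_step, out_spec, actions)


rewards = batched_step(jnp.arange(4, dtype=jnp.int32))
print(rewards)
```

pure_callback assumes the host function has no side effects, so JAX may skip or repeat the call. If the external engine keeps its own state, jax.experimental.io_callback may be the better fit. Either way the game state would live outside JAX, which may not work with how the tree search replays and batches states.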
Thanks for the nice project.
Have you tried using the default qtransform_completed_by_mix_value for the gumbel_muzero_policy?
The qtransform_by_min_max gives zero values to unvisited actions. That does not have a good theoretical justification.
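To make the difference concrete, here is a minimal pure-Python sketch of the idea behind completing Q-values with a mixed value, based on my reading of the Gumbel MuZero paper. It is not the mctx implementation, which also rescales the completed values, and the function name completed_qvalues and the example numbers are made up for illustration.

```python
from typing import List


def completed_qvalues(
    prior: List[float],
    qvalues: List[float],
    visits: List[int],
    raw_value: float,
) -> List[float]:
    """Fill in Q-values of unvisited actions with a mixed value estimate.

    Simplified illustration only: the mixed value interpolates between the
    network's raw value and the prior-weighted average Q of visited actions.
    """
    sum_visits = sum(visits)
    visited_prob = sum(p for p, n in zip(prior, visits) if n > 0)
    if sum_visits == 0 or visited_prob <= 0.0:
        # Nothing useful has been visited yet: fall back to the raw value.
        mixed_value = raw_value
    else:
        weighted_q = sum(
            p * q for p, q, n in zip(prior, qvalues, visits) if n > 0
        ) / visited_prob
        mixed_value = (raw_value + sum_visits * weighted_q) / (1 + sum_visits)
    # Visited actions keep their searched Q; unvisited ones get the mixed value.
    return [q if n > 0 else mixed_value for q, n in zip(qvalues, visits)]


completed = completed_qvalues(
    prior=[0.5, 0.3, 0.2],
    qvalues=[0.4, 0.0, 0.0],
    visits=[3, 0, 0],
    raw_value=0.2,
)
print(completed)
```

In this toy case the two unvisited actions get a value between the raw value and the visited action's Q rather than zero, which is the behaviour the question is about.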
I've implemented a game which doesn't have a strictly alternating turn order (some actions change player, others don't). How could this be used in your framework? I think it's the discount, but wanted to check. Should the discount returned be 1 for any action that doesn't change player and -1 otherwise?
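The proposal in the question can be checked with a small sketch of how a value is backed up along a search path when the discount is +1 for a move that keeps the same player and -1 for a move that hands the turn over. This only illustrates the sign convention being asked about. Whether a0-jax handles it this way is not confirmed here, and backup_value is a hypothetical helper.

```python
from typing import List


def backup_value(
    leaf_value: float,
    rewards: List[float],
    player_changes: List[bool],
) -> float:
    """Back up a leaf value to the root along one path.

    leaf_value is from the perspective of the player to move at the leaf.
    player_changes[i] says whether move i hands the turn to the other player.
    """
    value = leaf_value
    for reward, changed in zip(reversed(rewards), reversed(player_changes)):
        # Flip the sign only when the move changes whose turn it is.
        discount = -1.0 if changed else 1.0
        value = reward + discount * value
    return value


# Player A moves twice, the second move passes the turn, then B moves once
# and keeps the turn. The leaf is worth +1 to B, so it is worth -1 to A.
root_value = backup_value(
    leaf_value=1.0,
    rewards=[0.0, 0.0, 0.0],
    player_changes=[False, True, False],
)
print(root_value)
```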
Great job! I learned a lot from your repo. Where can I find an implementation of MuZero using mctx? Thanks a lot.
The 9x9 Go agent is pretty strong! How many iterations was it trained for? How long does training take (I saw it runs on TPUs)?
Hello Mr NTT,
Can I contact you directly? It's about AlphaZero.
My email is [email protected]