Comments (4)
The tutorial code below is a combination of the TF-Agents 1_dqn_tutorial and the BTgym unreal_stacked_lstm_strat_4_11 example.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
warnings.filterwarnings("ignore") # suppress h5py deprecation warning
import os
import backtrader as bt
import numpy as np
from btgym import BTgymEnv, BTgymDataset
from btgym.strategy.observers import Reward, Position, NormPnL
from btgym.research.strategy_gen_4 import DevStrat_4_11
# Set backtesting engine parameters:
MyCerebro = bt.Cerebro()
# Define strategy and broker account parameters:
MyCerebro.addstrategy(
    DevStrat_4_11,
    start_cash=2000,  # initial broker cash
    commission=0.0001,  # commission to imitate spread
    leverage=10.0,
    order_size=2000,  # fixed stake, mind leverage
    drawdown_call=10,  # max % of initial cash to lose
    target_call=10,  # max % to win, same
    skip_frame=10,
    gamma=0.99,
    reward_scale=7,  # gradient's nitrox, touch with care!
    state_ext_scale=np.linspace(3e3, 1e3, num=5),
)
# Visualisations for reward, position and PnL dynamics:
MyCerebro.addobserver(Reward)
MyCerebro.addobserver(Position)
MyCerebro.addobserver(NormPnL)
# Data: uncomment the filename line below to use up to six months of 1-minute bars:
data_m1_6_month = [
    './data/DAT_ASCII_EURUSD_M1_201701.csv',
    './data/DAT_ASCII_EURUSD_M1_201702.csv',
    './data/DAT_ASCII_EURUSD_M1_201703.csv',
    './data/DAT_ASCII_EURUSD_M1_201704.csv',
    './data/DAT_ASCII_EURUSD_M1_201705.csv',
    './data/DAT_ASCII_EURUSD_M1_201706.csv',
]
# Uncomment a single choice:
MyDataset = BTgymDataset(
    # filename=data_m1_6_month,
    filename='./data/test_sine_1min_period256_delta0002.csv',  # simple sine
    start_weekdays={0, 1, 2, 3, 4, 5, 6},
    episode_duration={'days': 1, 'hours': 23, 'minutes': 40},  # note: ~2-day-long episode
    start_00=False,
    time_gap={'hours': 10},
)
btgym_env = BTgymEnv(
    dataset=MyDataset,
    engine=MyCerebro,
    render_modes=['episode', 'human', 'internal'],  # 'external'
    render_state_as_image=True,
    render_ylabel='OHL_diff. / Internals',
    render_size_episode=(12, 8),
    render_size_human=(9, 4),
    render_size_state=(11, 3),
    render_dpi=75,
    port=5000,
    data_port=4999,
    connect_timeout=90,
    verbose=0,
)
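Before handing the env to TF-Agents, it can help to sanity-check the raw Gym interface; a minimal sketch (pure inspection, nothing BTgym-specific is assumed beyond the standard Gym attributes):
# Optional: inspect the Gym spaces BTgym exposes before wrapping.
print(btgym_env.observation_space)
print(btgym_env.action_space)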
import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import PIL.Image
import pyvirtualdisplay
import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
num_iterations = 20000 # @param {type:"integer"}
initial_collect_steps = 1000 # @param {type:"integer"}
collect_steps_per_iteration = 1 # @param {type:"integer"}
replay_buffer_max_length = 10000 # @param {type:"integer"}
batch_size = 64 # @param {type:"integer"}
learning_rate = 1e-3 # @param {type:"number"}
log_interval = 200 # @param {type:"integer"}
num_eval_episodes = 1 # @param {type:"integer"}
eval_interval = 1000 # @param {type:"integer"}
train_py_env = suite_gym.wrap_env(
    gym_env=btgym_env,
    # discount=1.0,
    # max_episode_steps=0,
    # gym_env_wrappers=(),
    # time_limit_wrapper=wrappers.TimeLimit,
    # env_wrappers=(),
    # spec_dtype_map=None,
    # auto_reset=True
)
# NB: this wraps the same underlying BTgym env instance as train_py_env;
# separate instances (on distinct ports) would likely be safer.
eval_py_env = suite_gym.wrap_env(
    gym_env=btgym_env,
    # discount=1.0,
    # max_episode_steps=0,
    # gym_env_wrappers=(),
    # time_limit_wrapper=wrappers.TimeLimit,
    # env_wrappers=(),
    # spec_dtype_map=None,
    # auto_reset=True
)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
observation_spec = train_env.observation_spec()
action_spec = train_env.action_spec()
print('Observation Spec:')
print(observation_spec)
print('Action Spec:')
print(action_spec)
fc_layer_params = (100,)

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    # preprocessing_combiner=tf.keras.layers.Concatenate(axis=-1),
    fc_layer_params=fc_layer_params)
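BTgym observations typically arrive as a dict (see the printed Observation Spec above), which the plain QNetwork call may reject; the commented-out preprocessing_combiner line hints at the fix. A hedged sketch, assuming observation_spec really is a dict of array specs:
# Sketch only: flatten each dict entry and concatenate, so the QNetwork
# receives a single vector; skip this if the observation spec is flat.
preprocessing_layers = {
    key: tf.keras.layers.Flatten() for key in observation_spec.keys()
}
q_net = q_network.QNetwork(
    observation_spec,
    action_spec,
    preprocessing_layers=preprocessing_layers,
    preprocessing_combiner=tf.keras.layers.Concatenate(axis=-1),
    fc_layer_params=fc_layer_params)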
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)
agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()
eval_policy = agent.policy
collect_policy = agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec())
#time_step = train_env.reset()
#action = random_policy.action(time_step)
#print(action)
#print('ok')
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,  # env batch size, not the sample batch size
    max_length=replay_buffer_max_length
)
def compute_avg_return(environment, policy, num_episodes=10):
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]
def collect_step(environment, policy, buffer):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    # Add trajectory to the replay buffer:
    buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
    for _ in range(steps):
        collect_step(env, policy, buffer)
collect_data(train_env, random_policy, replay_buffer, steps=100)
#print(replay_buffer)
#print('ok')
# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2).prefetch(3)
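# Note: num_steps=2 makes each sample an adjacent (t, t+1) transition pair,
# which is what DqnAgent.train expects for computing TD targets.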
iterator = iter(dataset)
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)
# Reset the train step counter.
agent.train_step_counter.assign(0)
# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]
print(returns)
for _ in range(num_iterations):

    # TODO: this reset was added because collect_step raised BTgym's
    # 'did reset' hint error while filling the buffer; it may belong here
    # in some other form, but for now just reset every iteration.
    train_env.reset()
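    # A hedged alternative (untested here): reset only once the previous
    # episode has actually ended, so the buffer keeps whole episodes:
    # if train_env.current_time_step().is_last():
    #     train_env.reset()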
    # Collect a few steps using collect_policy and save to the replay buffer.
    for _ in range(collect_steps_per_iteration):
        collect_step(train_env, agent.collect_policy, replay_buffer)
        # print(train_env.action_spec())

    # Sample a batch of data from the buffer and update the agent's network.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    step = agent.train_step_counter.numpy()

    if step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss))

    if step % eval_interval == 0:
        avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
        print('step = {0}: Average Return = {1}'.format(step, avg_return))
        returns.append(avg_return)

print('Done')
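matplotlib is imported above but never used; here is a minimal sketch to visualize the evaluation returns collected in returns (it assumes the loop above ran through all num_iterations, so the two lists line up):
# Plot average return vs. training step: one eval point per eval_interval,
# plus the initial evaluation done before training.
iterations = range(0, num_iterations + 1, eval_interval)
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
plt.show()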
@JacobHanouna, would you like to prepare a pull request with a tutorial notebook?
I can make a notebook tutorial, but some of the code needs to be changed for it to work; I thought that might be an issue, which is why I posted it here.
Ok, I see.