In the link you provided, there's an example of Dict observation space, but what about Dict action spaces?
These two are handled the same way, since we use the same pipeline to deal with all the keys and data in Batch.
Or perhaps I just missed something on the page?
It would be really cool if you'd add an example of how Dict action spaces can be used, including e.g. what the output of nn.Module.forward() should be.
Sure.
In [1]: from tianshou.data import ReplayBuffer
In [2]: buf = ReplayBuffer(10)
In [3]: import numpy as np
In [4]: buf.add(obs={'a': 1, 'b': 'str', 'c': np.array([1,2,3])}, act={'ar': 1, 'as': '666', 'ap': np.array([2,3,4])}, rew=0, done=0)
In [5]: buf
Out[5]:
ReplayBuffer(
obs: Batch(
a: array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
b: array(['str', None, None, None, None, None, None, None, None, None],
dtype=object),
c: array([[1, 2, 3],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]),
),
act: Batch(
ar: array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
as: array(['666', None, None, None, None, None, None, None, None, None],
dtype=object),
ap: array([[2, 3, 4],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]),
),
rew: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
done: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
obs_next: Batch(),
info: Batch(),
policy: Batch(),
)
In [6]: buf[0].act
Out[6]:
Batch(
ar: 1,
as: '666',
ap: array([2, 3, 4]),
)
In [7]: buf[0].act.ap
Out[7]: array([2, 3, 4])
In [8]: buf.act.ap[0]
Out[8]: array([2, 3, 4])
Your network output should be something like Batch(act={'ar': xxx, 'as': xxx, 'ap': xxx}, xxxxxx)
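A minimal sketch of that output shape, using only numpy. It assumes a plain dict can stand in for tianshou's Batch, and `policy_forward` is a hypothetical name, not a tianshou function. The point is that every action entry keeps the batch size as its first dimension, just like the per-key arrays in the ReplayBuffer example.

```python
import numpy as np

def policy_forward(obs):
    """Hypothetical forward pass returning a dict-structured action.

    A plain dict stands in for tianshou's Batch; each action entry keeps the
    batch size as its first dimension, matching the ReplayBuffer example.
    """
    bsz = len(obs["a"])
    act = {
        "ar": np.zeros(bsz, dtype=np.int64),           # one integer per sample
        "as": np.array(["666"] * bsz, dtype=object),   # one string per sample
        "ap": np.tile(np.array([2, 3, 4]), (bsz, 1)),  # one 3-vector per sample
    }
    return {"act": act}

out = policy_forward({"a": np.array([1, 1]), "c": np.zeros((2, 3))})
print({key: value.shape for key, value in out["act"].items()})
# {'ar': (2,), 'as': (2,), 'ap': (2, 3)}
```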
About Batch of Batch: what's the use case of a nested batch?
It's useful when you have a dict of obs and want to get a slice of all its entries at once.
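To illustrate what "getting a slice" of a nested dict means, here is a small sketch that applies one index to every array, however deeply it is nested. `slice_nested` is a hypothetical helper and only a simplified stand-in for slicing a nested Batch, not tianshou's implementation.

```python
import numpy as np

def slice_nested(data, index):
    """Apply one index to every array in a (possibly nested) dict.

    Illustrative stand-in for slicing a nested Batch; not tianshou's code.
    """
    if isinstance(data, dict):
        return {key: slice_nested(value, index) for key, value in data.items()}
    return data[index]

batch = {
    "rew": np.array([-1.0, -1.0]),
    "obs": {
        "achieved_goal": np.arange(6).reshape(2, 3),
        "observation": np.arange(10).reshape(2, 5),
    },
}
first = slice_nested(batch, 0)
print(first["obs"]["achieved_goal"])  # [0 1 2]
```

Without the nesting, you would have to index each obs key by hand; with it, one index reaches every entry.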
I updated a tutorial on how to customize a self-defined state and env: https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation
It means your network's output has 3 different parts, and you can organize them into a dict/Batch to return. How you generate these three parts depends on your network and policy; feel free to customize them.
Thanks. I will add this feature today.
thanks, much appreciated!
@LinxiFan I just added the draft implementation. Currently, it supports an obs that is a dict of (int / np.ndarray).
For example, FetchPickAndPlace-v1 has three elements in obs, "achieved_goal", "desired_goal", and "observation", each of which is an np.ndarray. In the newest replay buffer implementation, they are stored as "_obs@achieved_goal", "_obs@desired_goal", and "_obs@observation".
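A rough sketch of that naming scheme: each key of the dict obs goes under a flat "_obs@key" name, and sampling reassembles the dict. `flatten_obs` and `unflatten_obs` are hypothetical helpers written for illustration; they are not the replay buffer's actual code, and the values are made-up placeholders.

```python
def flatten_obs(obs, prefix="_obs"):
    """Store each key of a dict observation under a flat "prefix@key" name.

    Hypothetical helper mirroring the naming scheme described in this
    thread; not the replay buffer's actual code.
    """
    return {f"{prefix}@{key}": value for key, value in obs.items()}

def unflatten_obs(flat, prefix="_obs"):
    """Rebuild the dict observation from the flat "prefix@key" entries."""
    head = prefix + "@"
    return {key[len(head):]: value for key, value in flat.items() if key.startswith(head)}

obs = {"achieved_goal": [1.2, 0.7, 0.4], "desired_goal": [1.3, 0.7, 0.8], "observation": [0.0] * 25}
flat = flatten_obs(obs)
print(sorted(flat))
# ['_obs@achieved_goal', '_obs@desired_goal', '_obs@observation']
```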
But never mind, this is just the internal representation. When you sample a data batch from this replay buffer, you automatically get an obs like:
Batch(
done: array([False, False]),
info: {'is_success': array([0., 0.])},
obs: {'achieved_goal': array([[1.24767056, 0.67446261, 0.42473605],
[1.39146099, 0.83934225, 0.42473605]]),
'desired_goal': array([[1.25585192, 0.69931522, 0.78731941],
[1.40141382, 0.61938185, 0.57598463]]),
'observation': array([[ 1.36945404e+00, 7.78820325e-01, 5.64114120e-01,
1.24767056e+00, 6.74462609e-01, 4.24736048e-01,
-1.21783482e-01, -1.04357715e-01, -1.39378071e-01,
3.97852140e-02, 4.16259342e-02, -3.85214084e-07,
5.92637053e-07, 1.12208536e-13, -2.49958924e-02,
-2.62562578e-02, -2.64169060e-02, 1.87589293e-07,
-2.88598912e-07, 5.83381976e-19, 2.49958852e-02,
2.62562532e-02, 2.64435715e-02, 6.99533302e-02,
7.08374623e-02],
[ 1.36945404e+00, 7.78820325e-01, 5.64114120e-01,
1.39146099e+00, 8.39342251e-01, 4.24736048e-01,
2.20069458e-02, 6.05219264e-02, -1.39378071e-01,
3.97852140e-02, 4.16259342e-02, -3.85214084e-07,
5.92637053e-07, 1.12208536e-13, -2.49958924e-02,
-2.62562578e-02, -2.64169060e-02, 1.87589293e-07,
-2.88598912e-07, -3.26321805e-18, 2.49958852e-02,
2.62562532e-02, 2.64435715e-02, 6.99533302e-02,
7.08374623e-02]])},
rew: array([-1., -1.], dtype=float32),
)
Here, batch.obs['achieved_goal'], batch.obs['desired_goal'], and batch.obs['observation'] are what you want.
And on the network side, you receive the batch formatted as above. E.g.
import torch.nn as nn

class NN(nn.Module):
    def forward(self, batch, *args, **kwargs):
        # pick the dict entry this network consumes
        s = batch.obs['observation']
        ...
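To make the network side a bit more concrete without depending on torch, here is a numpy-only sketch of a network that reads several entries of the dict obs and concatenates them before a linear layer. `DictObsNet` is hypothetical; the key sizes (25, 3, 3) are taken from the FetchPickAndPlace-v1 output in this thread, and the 4-dim output is an arbitrary choice for illustration.

```python
import numpy as np

class DictObsNet:
    """Hypothetical numpy-only network that consumes a dict observation.

    Assumes the keys and sizes shown in this thread (observation: 25,
    achieved_goal: 3, desired_goal: 3) and an arbitrary 4-dim output.
    """

    def __init__(self, in_dim=31, out_dim=4, seed=0):
        rng = np.random.default_rng(seed)
        self.weight = rng.standard_normal((in_dim, out_dim)) * 0.1
        self.bias = np.zeros(out_dim)

    def forward(self, obs):
        # Pick the entries this network needs and join them along the feature axis.
        x = np.concatenate(
            [obs["observation"], obs["achieved_goal"], obs["desired_goal"]], axis=-1
        )
        return x @ self.weight + self.bias

obs = {
    "observation": np.zeros((2, 25)),
    "achieved_goal": np.zeros((2, 3)),
    "desired_goal": np.zeros((2, 3)),
}
out = DictObsNet().forward(obs)
print(out.shape)  # (2, 4)
```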
There are still a lot of things to be improved. I'll continue to work on this issue.
Now it supports a Batch of Batch:
Batch(
done: array([False, False]),
info: array([{'is_success': 0.0}, {'is_success': 0.0}], dtype=object),
obs: Batch(
achieved_goal: array([[1.42853749, 0.63666553, 0.42473605],
[1.20106979, 0.77055984, 0.42473605]]),
desired_goal: array([[1.47228501, 0.88356362, 0.4490597 ],
[1.37898198, 0.71865667, 0.42469975]]),
observation: array([[ 1.36945404e+00, 7.78820325e-01, 5.64114120e-01,
1.42853749e+00, 6.36665526e-01, 4.24736048e-01,
5.90834503e-02, -1.42154798e-01, -1.39378071e-01,
3.97852140e-02, 4.16259342e-02, -3.85214084e-07,
5.92637053e-07, 1.12208536e-13, -2.49958924e-02,
-2.62562578e-02, -2.64169060e-02, 1.87589293e-07,
-2.88598912e-07, 1.30443021e-18, 2.49958852e-02,
2.62562532e-02, 2.64435715e-02, 6.99533302e-02,
7.08374623e-02],
[ 1.36945404e+00, 7.78820325e-01, 5.64114120e-01,
1.20106979e+00, 7.70559842e-01, 4.24736048e-01,
-1.68384253e-01, -8.26048228e-03, -1.39378071e-01,
3.97852140e-02, 4.16259342e-02, -3.85214084e-07,
5.92637053e-07, 1.12208536e-13, -2.49958924e-02,
-2.62562578e-02, -2.64169060e-02, 1.87589293e-07,
-2.88598912e-07, -3.26321805e-18, 2.49958852e-02,
2.62562532e-02, 2.64435715e-02, 6.99533302e-02,
7.08374623e-02]]),
),
rew: array([-1., -1.], dtype=float32),
)
You can use either batch.obs['observation'] or batch.obs.observation, and you can use batch[0].obs.observation, batch.obs[0].observation, or batch.obs.observation[0] to get a slice of this observation. The last one is preferable because it requires the fewest data copies.
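A simplified cost model of why the last form is cheapest: slicing the whole batch indexes every leaf array, slicing batch.obs indexes only the obs leaves, and indexing batch.obs.observation touches a single array (and numpy's basic indexing returns a view rather than a copy). Plain dicts stand in for Batch here, and `count_indexed_arrays` is a hypothetical helper, not tianshou's code.

```python
import numpy as np

def count_indexed_arrays(data):
    """Count the leaf arrays that one index touches in a nested dict.

    A simplified cost model for slicing a nested Batch; not tianshou's code.
    """
    if isinstance(data, dict):
        return sum(count_indexed_arrays(v) for v in data.values())
    return 1

batch = {
    "done": np.zeros(2, dtype=bool),
    "rew": np.zeros(2, dtype=np.float32),
    "obs": {
        "achieved_goal": np.zeros((2, 3)),
        "desired_goal": np.zeros((2, 3)),
        "observation": np.zeros((2, 25)),
    },
}
print(count_indexed_arrays(batch))         # like batch[0]: 5 arrays indexed
print(count_indexed_arrays(batch["obs"]))  # like batch.obs[0]: 3 arrays indexed
row = batch["obs"]["observation"][0]       # like batch.obs.observation[0]: 1 array
print(np.shares_memory(row, batch["obs"]["observation"]))  # True: basic indexing gives a view
```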
Thank you so much for implementing this!
About Batch of Batch: what's the use case of a nested batch? At least for me, a one-level collated batch is good enough.
I see, right, that's how you get the dot-access syntax. Thanks!
Does Tianshou support Tuple/Dict action spaces as well?
Dict space is naturally supported (refer to https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation and https://tianshou.readthedocs.io/en/master/tutorials/batch.html). Tuple space is not recommended, as discussed in #172.
In the link you provided, there's an example of Dict observation space, but what about Dict action spaces? Or perhaps I just missed something on the page?
It would be really cool if you'd add an example of how Dict action spaces can be used, including e.g. what the output of nn.Module.forward() should be.
Hi there. May I ask what Dict actions mean here? Does {'ar', 'as', 'ap'} mean the network has 3 outputs, or that there are 3 networks?
Thanks for your prompt response.
I ran into the issue attached. In fact, I have a customized Env whose action space is
spaces.Dict({"act_indicator": spaces.Discrete(), "act_txpower": spaces.Box(low, high)})
I want the output of the network to be a continuous scalar (data rate). Then I use a mapping function to map the data rate to "act_indicator" and "act_txpower", so I inserted the mapping function into the forward() of DDPG. The error seems to be caused by the scalar "act_indicator". However, I can't convert the scalar to an np.array successfully.
The output batch from policy.forward is a "batch": the first dimension of every element in the batch is always the batch size. indicator should have shape (bsz,) with dtype=int64, and txpower should have shape (bsz, 1) with dtype=float32.
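A sketch of a mapping step that follows this shape rule. `map_rate_to_action`, the number of indicator levels, and the rate range are all hypothetical placeholders (the actual mapping depends on the env); the part that matters is that both entries come out with the batch size as their first dimension and with the dtypes named above.

```python
import numpy as np

def map_rate_to_action(rate, n_levels=4, low=0.0, high=1.0):
    """Hypothetical mapping from a continuous data rate to a dict action.

    rate: network output of shape (bsz,) or (bsz, 1).
    Returns act_indicator with shape (bsz,) int64 and
    act_txpower with shape (bsz, 1) float32.
    """
    rate = np.asarray(rate, dtype=np.float32).reshape(-1)
    bsz = rate.shape[0]
    clipped = np.clip(rate, low, high)
    # Placeholder rule: split [low, high] into n_levels equal bins.
    levels = np.minimum((clipped - low) / (high - low) * n_levels, n_levels - 1)
    indicator = levels.astype(np.int64)
    # Placeholder rule: reuse the clipped rate as the transmit power.
    txpower = clipped.reshape(bsz, 1).astype(np.float32)
    return {"act_indicator": indicator, "act_txpower": txpower}

act = map_rate_to_action(np.array([0.1, 0.6, 1.5]))
print(act["act_indicator"].shape, act["act_txpower"].shape)  # (3,) (3, 1)
```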
Right. I have checked buf.act and I think it satisfies what you mentioned. When I call policy.learn(buf, batch_size=1, repeat=1), the issue is still there:
TypeError: Object 2 in Batch(
act_indicator: 2,
act_txpower: array([0.00846217], dtype=float32),
) has no len()
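The message says the object 2 has no len(), i.e. a bare scalar appears where a sequence is expected. Assuming, as suggested above, that the scalar act_indicator is the culprit (the thread does not confirm the root cause), a quick check like the following could flag action entries that lack a leading batch dimension before they reach the policy. `check_batch_dim` is a hypothetical helper, not part of tianshou.

```python
import numpy as np

def check_batch_dim(act, bsz):
    """Return the keys whose value lacks a leading batch dimension of size bsz.

    Hypothetical debugging helper, not part of tianshou.
    """
    bad = []
    for key, value in act.items():
        arr = np.asarray(value)
        if arr.ndim == 0 or arr.shape[0] != bsz:
            bad.append(key)
    return bad

act = {"act_indicator": 2, "act_txpower": np.array([0.00846217], dtype=np.float32)}
print(check_batch_dim(act, bsz=1))  # ['act_indicator']
```

Wrapping the scalar, e.g. np.array([2], dtype=np.int64), gives it the (bsz,) shape described above.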
So could you please paste a minimal example to demonstrate this error here?
BTW, it's better to open another issue.
Thanks for your time. Sorry, I think I may need more time since another issue has appeared.
Okay, I will do that.