Comments (19)

Trinkle23897 commented on July 16, 2024 3

In the link you provided, there's an example of Dict observation space, but what about Dict action spaces?

The two work the same way, since Batch handles every key and value through the same pipeline.

Or perhaps I just missed something on the page?

#172 (comment)

It would be really cool if you'd add an example of how Dict action spaces can be used, including e.g. what the output of nn.Module.forward() should be.

Sure.

In [1]: from tianshou.data import ReplayBuffer

In [2]: buf = ReplayBuffer(10)

In [3]: import numpy as np

In [4]: buf.add(obs={'a': 1, 'b': 'str', 'c': np.array([1,2,3])}, act={'ar': 1, 'as': '666', 'ap': np.array([2,3,4])}, rew=0, done=0)

In [5]: buf
Out[5]: 
ReplayBuffer(
    obs: Batch(
             a: array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
             b: array(['str', None, None, None, None, None, None, None, None, None],
                      dtype=object),
             c: array([[1, 2, 3],
                       [0, 0, 0],
                       [0, 0, 0],
                       [0, 0, 0],
                       [0, 0, 0],
                       [0, 0, 0],
                       [0, 0, 0],
                       [0, 0, 0],
                       [0, 0, 0],
                       [0, 0, 0]]),
         ),
    act: Batch(
             ar: array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
             as: array(['666', None, None, None, None, None, None, None, None, None],
                       dtype=object),
             ap: array([[2, 3, 4],
                        [0, 0, 0],
                        [0, 0, 0],
                        [0, 0, 0],
                        [0, 0, 0],
                        [0, 0, 0],
                        [0, 0, 0],
                        [0, 0, 0],
                        [0, 0, 0],
                        [0, 0, 0]]),
         ),
    rew: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    done: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    obs_next: Batch(),
    info: Batch(),
    policy: Batch(),
)

In [6]: buf[0].act
Out[6]: 
Batch(
    ar: 1,
    as: '666',
    ap: array([2, 3, 4]),
)

In [7]: buf[0].act.ap
Out[7]: array([2, 3, 4])

In [8]: buf.act.ap[0]
Out[8]: array([2, 3, 4])

Your network output should be something like Batch(act={'ar': xxx, 'as': xxx, 'ap': xxx}, xxxxxx)

from tianshou.

Trinkle23897 commented on July 16, 2024 1

About Batch of Batch: what's the use case of a nested batch?

It is useful when you have a dict of observations and want to take a slice from it.

I updated the tutorial on customizing the state and environment: https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation

Trinkle23897 commented on July 16, 2024 1

It means your network's output has 3 different parts, and you can organize them into a dict/batch to return. As for how you generate these three parts, it depends on your network and policy -- feel free to customize them.
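
As a rough illustration of the point above, the sketch below splits one raw network output into three named parts and returns them as a dict. It uses plain NumPy rather than Tianshou; the function name `split_action_heads`, the column layout, and the meaning assigned to each key are assumptions made up for this example, not part of the library.

```python
import numpy as np


def split_action_heads(raw):
    """Split a (bsz, 6) raw output into three action parts (hypothetical layout)."""
    raw = np.asarray(raw, dtype=np.float32)
    return {
        # columns 0-1: scores for a discrete choice, reduced to one index per sample
        "ar": raw[:, :2].argmax(axis=1),
        # column 2: a single continuous value per sample, kept as shape (bsz, 1)
        "as": raw[:, 2:3],
        # columns 3-5: a continuous vector per sample, shape (bsz, 3)
        "ap": raw[:, 3:6],
    }


raw_output = np.arange(12, dtype=np.float32).reshape(2, 6)
act = split_action_heads(raw_output)
print({key: value.shape for key, value in act.items()})
```

Every value keeps the batch size as its first dimension, so a dict like this could be passed as the act field of the returned Batch.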

Trinkle23897 commented on July 16, 2024

Thanks. I will add this feature today.

DrJimFan commented on July 16, 2024

thanks, much appreciated!

Trinkle23897 commented on July 16, 2024

@LinxiFan I just added a draft implementation. Currently, it supports an obs that is a dict of (int / np.array).
For example, FetchPickAndPlace-v1 has three elements in obs, "achieved_goal", "desired_goal", and "observation", each of which is an np.ndarray. In the newest replay buffer implementation, they are stored under "_obs@achieved_goal", "_obs@desired_goal", and "_obs@observation".
But never mind, this is just the internal representation. When you sample a data batch from this replay buffer, you automatically get an obs like:

Batch(
    done: array([False, False]),
    info: {'is_success': array([0., 0.])},
    obs: {'achieved_goal': array([[1.24767056, 0.67446261, 0.42473605],
                [1.39146099, 0.83934225, 0.42473605]]),
          'desired_goal': array([[1.25585192, 0.69931522, 0.78731941],
                [1.40141382, 0.61938185, 0.57598463]]),
          'observation': array([[ 1.36945404e+00,  7.78820325e-01,  5.64114120e-01,
                  1.24767056e+00,  6.74462609e-01,  4.24736048e-01,
                 -1.21783482e-01, -1.04357715e-01, -1.39378071e-01,
                  3.97852140e-02,  4.16259342e-02, -3.85214084e-07,
                  5.92637053e-07,  1.12208536e-13, -2.49958924e-02,
                 -2.62562578e-02, -2.64169060e-02,  1.87589293e-07,
                 -2.88598912e-07,  5.83381976e-19,  2.49958852e-02,
                  2.62562532e-02,  2.64435715e-02,  6.99533302e-02,
                  7.08374623e-02],
                [ 1.36945404e+00,  7.78820325e-01,  5.64114120e-01,
                  1.39146099e+00,  8.39342251e-01,  4.24736048e-01,
                  2.20069458e-02,  6.05219264e-02, -1.39378071e-01,
                  3.97852140e-02,  4.16259342e-02, -3.85214084e-07,
                  5.92637053e-07,  1.12208536e-13, -2.49958924e-02,
                 -2.62562578e-02, -2.64169060e-02,  1.87589293e-07,
                 -2.88598912e-07, -3.26321805e-18,  2.49958852e-02,
                  2.62562532e-02,  2.64435715e-02,  6.99533302e-02,
                  7.08374623e-02]])},
    rew: array([-1., -1.], dtype=float32),
)

Here, batch.obs['achieved_goal'], batch.obs['desired_goal'], and batch.obs['observation'] are what you want.

On the network side, you receive the batch in the format shown above, e.g.:

import torch.nn as nn

class NN(nn.Module):
    def forward(self, batch, **kwargs):
        # Pick the array the network needs out of the dict observation.
        s = batch.obs['observation']
        ...
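
To make the internal naming mentioned above concrete, here is a minimal sketch of how a dict observation could be flattened into "_obs@..." keys and rebuilt when sampling. It is not Tianshou's actual code: `flatten_obs`, `restore_obs`, and the `storage` dict are hypothetical names used only for illustration.

```python
def flatten_obs(obs, prefix="_obs"):
    """Map {'key': value} to {'_obs@key': value} (illustrative naming only)."""
    return {f"{prefix}@{key}": value for key, value in obs.items()}


def restore_obs(storage, prefix="_obs"):
    """Rebuild the original dict from the flattened '_obs@key' entries."""
    marker = prefix + "@"
    return {
        name[len(marker):]: value
        for name, value in storage.items()
        if name.startswith(marker)
    }


obs = {"achieved_goal": [1.2, 0.6, 0.4], "desired_goal": [1.3, 0.7, 0.8]}
storage = flatten_obs(obs)
print(sorted(storage))              # the flattened key names
print(restore_obs(storage) == obs)  # round trip gives back the same dict
```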

Trinkle23897 commented on July 16, 2024

There are still a lot of things to be improved. I'll continue to work on this issue.

Trinkle23897 commented on July 16, 2024

Now it supports a Batch of Batch:

Batch(
    done: array([False, False]),
    info: array([{'is_success': 0.0}, {'is_success': 0.0}], dtype=object),
    obs: Batch(
             achieved_goal: array([[1.42853749, 0.63666553, 0.42473605],
                                   [1.20106979, 0.77055984, 0.42473605]]),
             desired_goal: array([[1.47228501, 0.88356362, 0.4490597 ],
                                  [1.37898198, 0.71865667, 0.42469975]]),
             observation: array([[ 1.36945404e+00,  7.78820325e-01,  5.64114120e-01,
                                   1.42853749e+00,  6.36665526e-01,  4.24736048e-01,
                                   5.90834503e-02, -1.42154798e-01, -1.39378071e-01,
                                   3.97852140e-02,  4.16259342e-02, -3.85214084e-07,
                                   5.92637053e-07,  1.12208536e-13, -2.49958924e-02,
                                  -2.62562578e-02, -2.64169060e-02,  1.87589293e-07,
                                  -2.88598912e-07,  1.30443021e-18,  2.49958852e-02,
                                   2.62562532e-02,  2.64435715e-02,  6.99533302e-02,
                                   7.08374623e-02],
                                 [ 1.36945404e+00,  7.78820325e-01,  5.64114120e-01,
                                   1.20106979e+00,  7.70559842e-01,  4.24736048e-01,
                                  -1.68384253e-01, -8.26048228e-03, -1.39378071e-01,
                                   3.97852140e-02,  4.16259342e-02, -3.85214084e-07,
                                   5.92637053e-07,  1.12208536e-13, -2.49958924e-02,
                                  -2.62562578e-02, -2.64169060e-02,  1.87589293e-07,
                                  -2.88598912e-07, -3.26321805e-18,  2.49958852e-02,
                                   2.62562532e-02,  2.64435715e-02,  6.99533302e-02,
                                   7.08374623e-02]]),
         ),
    rew: array([-1., -1.], dtype=float32),
)

You can use either batch.obs['observation'] or batch.obs.observation. To get a slice from this observation, any of
batch[0].obs.observation,
batch.obs[0].observation, and
batch.obs.observation[0] works. The last one is best because it requires the fewest data copies.

DrJimFan commented on July 16, 2024

Thank you so much for implementing this!
About Batch of Batch: what's the use case of a nested batch? At least for me, a one-level collated batch is good enough.

DrJimFan commented on July 16, 2024

I see, right, that's how you get the dot-access syntax. Thanks!

ArniDagur commented on July 16, 2024

Does Tianshou support Tuple/Dict action spaces as well?

Trinkle23897 commented on July 16, 2024

Dict space is naturally supported (refer to https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation and https://tianshou.readthedocs.io/en/master/tutorials/batch.html). Tuple space is not recommended, as discussed in #172.

ArniDagur commented on July 16, 2024

In the link you provided, there's an example of Dict observation space, but what about Dict action spaces? Or perhaps I just missed something on the page?

It would be really cool if you'd add an example of how Dict action spaces can be used, including e.g. what the output of nn.Module.forward() should be.

hanghoo commented on July 16, 2024

Hi there. May I ask what Dict actions mean here? Does {'ar', 'as', 'ap'} mean the network has 3 outputs, or that there are 3 networks?

hanghoo commented on July 16, 2024

Thanks for your prompt response.
I ran into an issue (attached). I have a customized Env whose action space is
spaces.Dict({"act_indicator": spaces.Discrete(), "act_txpower": spaces.Box(low, high)})
I want the network to output a continuous scalar (data rate), which a mapping function then maps to "act_indicator" and "act_txpower". I therefore inserted the mapping function into the forward() of DDPG. The error seems to be caused by the scalar "act_indicator", but I have not managed to convert the scalar to an np.array.

Trinkle23897 commented on July 16, 2024

The output of policy.forward is a "batch": the first dimension of every element in it is always the batch size. act_indicator should have shape (bsz,) with dtype=int64, and act_txpower should have shape (bsz, 1) with dtype=float32.
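
A minimal sketch of what those shape requirements could look like in plain NumPy follows. The helper `to_batched_action` is hypothetical (it is not a Tianshou function), and it assumes the mapping function yields one indicator and one power value per sample.

```python
import numpy as np


def to_batched_action(indicator, txpower):
    """Give both action parts a leading batch dimension (hypothetical helper)."""
    # A bare scalar such as 2 becomes array([2]): shape (bsz,), dtype int64.
    act_indicator = np.atleast_1d(np.asarray(indicator, dtype=np.int64))
    # Reshape to (bsz, 1) so each sample carries one float32 power value.
    act_txpower = np.asarray(txpower, dtype=np.float32).reshape(-1, 1)
    if len(act_indicator) != len(act_txpower):
        raise ValueError("batch sizes of the two action parts differ")
    return {"act_indicator": act_indicator, "act_txpower": act_txpower}


act = to_batched_action(2, [0.00846217])
print(act["act_indicator"].shape, act["act_indicator"].dtype)
print(act["act_txpower"].shape, act["act_txpower"].dtype)
```

The point of the sketch is that even with a batch size of 1, the indicator is stored as a length-1 array rather than a bare Python or NumPy scalar, so calling len() on it works.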

hanghoo commented on July 16, 2024

Right. I have checked buf.act, and I think it satisfies what you mentioned. When I call policy.learn(buf, batch_size=1, repeat=1), the issue still occurs:

TypeError: Object 2 in Batch(
    act_indicator: 2,
    act_txpower: array([0.00846217], dtype=float32),
) has no len()

Trinkle23897 commented on July 16, 2024

So could you please paste a minimal example to demonstrate this error here?
BTW, it's better to open another issue.

hanghoo commented on July 16, 2024

So could you please paste a minimal example to demonstrate this error here? BTW, it's better to open another issue.

Thanks for your time. Sorry, I think I may need more time since another issue has appeared.
Okay, I will do that.
