yiwiz-sai commented on August 17, 2024

I changed the code and it seems to work now. Please correct me if I have misunderstood anything, thank you!

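# pip3 install torchrl
from typing import Any, Optional, Union

import numpy as np
import torch
from torch.distributions import Categorical
from tianshou.data import Batch
from tianshou.policy import DiscreteSACPolicy
from torchrl.modules import MaskedCategorical


class MyDiscreteSACPolicy(DiscreteSACPolicy):
    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        obs = batch[input]
        # The actor call is the same with or without a mask, so do it once.
        logits, hidden = self.actor(obs, state=state, info=batch.info)
        if hasattr(batch.info, "mask"):
            # Boolean mask from the env's info["mask"] (True = action allowed),
            # moved to the logits' device so masking also works on GPU.
            # Caveat: this assumes the stored mask belongs to batch[input]; worth
            # checking when the policy is called with input="obs_next" while learning.
            mask = torch.as_tensor(batch.info.mask, dtype=torch.bool, device=logits.device)
            # MaskedCategorical assigns zero probability where mask is False.
            dist = MaskedCategorical(logits=logits, mask=mask)
        else:
            dist = Categorical(logits=logits)

        if self._deterministic_eval and not self.training:
            # dist.logits already has the mask applied, so argmax skips invalid actions.
            act = dist.logits.argmax(dim=-1)
        else:
            act = dist.sample()
        return Batch(act=act, state=hidden, dist=dist)
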
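For readers without torchrl, here is a rough NumPy sketch of the idea behind the masked distribution: disallowed actions get a logit of -inf, so they receive probability exactly 0 and can be neither sampled nor picked by argmax. The helpers `masked_categorical` and `select_action` are hypothetical names and this is not torchrl's implementation; it only illustrates the technique. It assumes every row has at least one allowed action (an all-False row would produce NaNs).

```python
import numpy as np


def masked_categorical(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return action probabilities with disallowed actions forced to zero.

    logits: float array of shape (batch, n_actions)
    mask:   bool array of the same shape, True = action is allowed
    Assumes each row of mask has at least one True entry.
    """
    # Disallowed actions get logit -inf, so exp() maps them to exactly 0.
    masked_logits = np.where(mask, logits, -np.inf)
    # Numerically stable softmax over the last axis.
    shifted = masked_logits - masked_logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def select_action(logits, mask, deterministic, rng):
    probs = masked_categorical(logits, mask)
    if deterministic:
        # Greedy choice; masked actions have probability 0, so they never win.
        return probs.argmax(axis=-1)
    # Sample one action per row from the masked distribution.
    return np.array([rng.choice(len(p), p=p) for p in probs])


rng = np.random.default_rng(0)
logits = np.array([[2.0, 5.0, 1.0], [0.5, 0.1, 3.0]])
mask = np.array([[True, False, True], [True, True, False]])
print(select_action(logits, mask, deterministic=True, rng=rng))  # → [0 0]
```

Note that action 1 in the first row has the largest raw logit (5.0) but is never chosen, because masking happens before the softmax and the argmax.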
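class MyEnv:
    def reset(self, seed=None, options=None):
        ...
        # action_mask: bool array of shape (n_actions,), True = allowed for obs
        info = {"mask": action_mask}
        obs = state_arr
        return obs, info

    def step(self, action):
        ...
        # Recompute the mask for the new observation and return it in info
        info = {"mask": action_mask}
        obs = state_arr
        return obs, reward, terminated, truncated, info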
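A runnable stand-in for the elided parts of `MyEnv` could look like the hypothetical toy environment below (`ToyBudgetEnv`, its budget rule and its reward are all invented for illustration). It recomputes the mask after every `reset` and `step` and returns it under `info["mask"]`, the key `MyDiscreteSACPolicy` reads from `batch.info`.

```python
import numpy as np


class ToyBudgetEnv:
    """Hypothetical toy env: taking action k spends k units of a budget.

    Actions that would overspend are reported as invalid through info["mask"].
    """

    def __init__(self, n_actions: int = 4, budget: int = 5, max_steps: int = 20):
        self.n_actions = n_actions
        self.initial_budget = budget
        self.max_steps = max_steps
        self.budget = budget
        self.t = 0

    def _action_mask(self) -> np.ndarray:
        # Action k is allowed only if it does not exceed the remaining budget.
        return np.arange(self.n_actions) <= self.budget

    def _obs(self) -> np.ndarray:
        return np.array([self.budget], dtype=np.float32)

    def reset(self, seed=None, options=None):
        self.budget = self.initial_budget
        self.t = 0
        return self._obs(), {"mask": self._action_mask()}

    def step(self, action: int):
        if not self._action_mask()[action]:
            raise ValueError(f"action {action} is masked out")
        self.budget -= action
        self.t += 1
        reward = float(action)
        terminated = self.budget == 0
        truncated = self.t >= self.max_steps and not terminated
        # The returned mask describes the new observation, not the old one.
        return self._obs(), reward, terminated, truncated, {"mask": self._action_mask()}
```

Action 0 is always allowed in this toy, so every mask has at least one True entry; a real environment should likewise never return an all-False mask.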

from tianshou.

MischaPanch commented on August 17, 2024

Is your question resolved by the corrected code?

yiwiz-sai commented on August 17, 2024

> Is your question resolved by the corrected code?

It works for action masking, but I am not sure whether it affects RL training in other ways; I am not an RL expert.
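One concrete place where masking interacts with discrete SAC is the entropy term. With masked actions at logit -inf, a naive sum of p * log p evaluates 0 * (-inf) = NaN; as far as we know, PyTorch's `Categorical.entropy` avoids this by clamping logits to the smallest finite float, but a hand-written loss would need to handle it. The maximum achievable entropy also drops to log(number of allowed actions) instead of log(total actions), which may matter if automatic alpha tuning uses a target entropy derived from the full action space. The NumPy sketch here (`masked_entropy` is a hypothetical helper, not a tianshou or torchrl function) computes the entropy over allowed actions only.

```python
import numpy as np


def masked_entropy(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Entropy of a categorical distribution restricted to allowed actions.

    logits, mask: arrays of shape (batch, n_actions); mask is True where allowed.
    Assumes each row has at least one allowed action.
    """
    masked_logits = np.where(mask, logits, -np.inf)
    # Stable log-softmax over the allowed entries only.
    max_l = masked_logits.max(axis=-1, keepdims=True)
    log_z = max_l + np.log(np.exp(masked_logits - max_l).sum(axis=-1, keepdims=True))
    log_p = masked_logits - log_z  # -inf for masked actions
    p = np.exp(log_p)  # exactly 0 for masked actions
    # Naively, p * log_p would be 0 * -inf = nan for masked actions;
    # replacing log_p by 0 there makes those terms contribute 0 instead.
    safe_log_p = np.where(mask, log_p, 0.0)
    return -(p * safe_log_p).sum(axis=-1)
```

For example, equal logits over two allowed actions out of three give an entropy of log 2, not log 3, regardless of the masked action's logit.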

MischaPanch commented on August 17, 2024

I'll take that as a yes. In general, a tutorial on action masking would make sense; I'll open a separate issue for that.
