
blumu commented on May 23, 2024

@forrestmckee That's a good suggestion. It's not currently planned, but it feels like something we ought to have. Work on stable_baselines3 is currently postponed until they complete the upgrade to the latest version of gym. Once this is done, the plan is to upgrade cyberbattlesim to the latest versions of gym and stable_baselines, which will then allow for further improvements like the one you mentioned.

from cyberbattlesim.

forrestmckee commented on May 23, 2024

@Screamer-Y the difference between the built-in algorithms and Stable-Baselines3 is that the built-in ones have a check to ensure the chosen action is valid. SB3 doesn't, so a large portion of the time the agent performs an action that is impossible given the current state of the environment.
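As a rough illustration of that difference, the sketch below shows generic action masking over a flattened `Discrete` action space. It assumes a boolean validity mask for the current state is already available; the helpers `sample_valid` and `greedy_valid` are hypothetical and are not part of CyberBattleSim or Stable-Baselines3.

```python
import numpy as np


def sample_valid(valid_mask: np.ndarray, rng: np.random.Generator) -> int:
    """Hypothetical helper: sample uniformly among actions marked valid."""
    valid_indices = np.flatnonzero(valid_mask)
    if valid_indices.size == 0:
        raise ValueError("no valid action in the current state")
    return int(rng.choice(valid_indices))


def greedy_valid(q_values: np.ndarray, valid_mask: np.ndarray) -> int:
    """Hypothetical helper: pick the highest-valued action among valid ones."""
    if not valid_mask.any():
        raise ValueError("no valid action in the current state")
    # Invalid actions get -inf so argmax can never select them.
    masked = np.where(valid_mask, q_values, -np.inf)
    return int(np.argmax(masked))


mask = np.array([False, True, False, True])  # assumed: only actions 1 and 3 are valid
q = np.array([5.0, 1.0, 9.0, 2.0])
print(greedy_valid(q, mask))  # 3, because actions 0 and 2 are masked out
print(sample_valid(mask, np.random.default_rng(0)))  # either 1 or 3
```

An unmasked policy has no such filter, so with a large flattened space it can spend many steps on impossible actions, consistent with what is described above.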

2twentytwo2 commented on May 23, 2024

Screamer-Y commented on May 23, 2024

Hi everyone,
I noticed that PR #86 already adds support for stable-baselines3 and provides corresponding examples in the notebooks folder, which is of great help. Sorry for missing that before...

kvas7andy commented on May 23, 2024

Hi, @Screamer-Y, that is actually an interesting topic for me to investigate. If you run into any problems or find any literature using this simulator, feel free to share, and I can do the same if you are interested.

Edit: I found the source of the error; it was caused purely by my own additions to the simulator.

Screamer-Y commented on May 23, 2024

Hi, @kvas7andy, I'm glad to know you're interested in this topic too.
Now I'm working on https://github.com/microsoft/CyberBattleSim/blob/4fd228bccfc2b088d911e27072a923251203cac8/cyberbattle/_env/flatten_wrapper.py.
My goal is to change the 'action_space' from a 'spaces.MultiDiscrete' to a 'spaces.Discrete'. I do this because in stable-baselines3, value-based RL algorithms such as DQN only support a 'spaces.Discrete' action_space. Currently I have simply mapped every possible action to a Discrete value, but when I try to train a DQN agent, it cannot learn from the environment properly. I'm still trying to figure out what went wrong.
Here is my modification:

class FlattenActionWrapper(ActionWrapper):
    """
    Flatten all nested dictionaries and tuples from the
    action space of a CyberBattleSim environment `CyberBattleEnv`.
    The resulting action space is a `Discrete`.
    """

    def __init__(self, env: CyberBattleEnv):
        ActionWrapper.__init__(self, env)
        self.env = env

        # local: source x local_attacks_count; remote: source x target x remote_attacks_count;
        # connect: source x target x port x credentials
        self.action_space = spaces.Discrete(env.bounds.maximum_node_count*env.bounds.local_attacks_count + \
            env.bounds.maximum_node_count*env.bounds.maximum_node_count*env.bounds.remote_attacks_count + \
            env.bounds.maximum_node_count*env.bounds.maximum_node_count*env.bounds.port_count*env.bounds.maximum_total_credentials)
        

    def action(self, action: np.int64) -> Action:
        n_nodes = self.env.bounds.maximum_node_count
        n_local_attacks = self.env.bounds.local_attacks_count
        n_remote_attacks = self.env.bounds.remote_attacks_count
        n_port = self.env.bounds.port_count
        n_credentials = self.env.bounds.maximum_total_credentials

        if action<n_nodes*n_local_attacks:
            source_node = action//n_local_attacks
            local_vulnerability = action%n_local_attacks
            return {'local_vulnerability': np.array([source_node, local_vulnerability])}

        action -= n_nodes*n_local_attacks
        if action < n_nodes*n_nodes*n_remote_attacks:
            source_node = action//(n_remote_attacks*n_nodes)
            target_node = (action//n_remote_attacks)%n_nodes
            remote_vulnerability = action%n_remote_attacks
            return {'remote_vulnerability': np.array([source_node, target_node, remote_vulnerability])}

        action -= n_nodes*n_nodes*n_remote_attacks
        if action < n_nodes*n_nodes*n_port*n_credentials:
            source_node = action//(n_nodes*n_port*n_credentials)
            target_node = (action//(n_port*n_credentials))%n_nodes
            port = (action//(n_credentials))%n_port
            credential = action%n_credentials
            return {'connect': np.array([source_node,target_node,port,credential])}

        raise NotSupportedError(f'Unsupported action: {action}')

    def reverse_action(self, action):
        raise NotImplementedError

I'm not an experienced programmer, so feel free to point out any problems; I would appreciate it.
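One way to check a flattening like this is to write the inverse mapping and verify that decoding and re-encoding every index gives the same index back. The sketch below does this in plain Python with small, hypothetical bounds; `decode` mirrors the arithmetic of the `action` method above but returns tuples instead of CyberBattleSim's `Action` dictionaries, and `decode`/`encode` are illustrative names, not part of the library.

```python
from typing import Tuple

# Hypothetical bounds, chosen small so the exhaustive check is fast.
N_NODES, N_LOCAL, N_REMOTE, N_PORT, N_CRED = 3, 2, 4, 2, 5

LOCAL_SIZE = N_NODES * N_LOCAL
REMOTE_SIZE = N_NODES * N_NODES * N_REMOTE
CONNECT_SIZE = N_NODES * N_NODES * N_PORT * N_CRED
TOTAL = LOCAL_SIZE + REMOTE_SIZE + CONNECT_SIZE


def decode(action: int) -> Tuple[str, Tuple[int, ...]]:
    """Same index arithmetic as FlattenActionWrapper.action above."""
    if action < LOCAL_SIZE:
        return "local_vulnerability", (action // N_LOCAL, action % N_LOCAL)
    action -= LOCAL_SIZE
    if action < REMOTE_SIZE:
        return "remote_vulnerability", (
            action // (N_REMOTE * N_NODES),
            (action // N_REMOTE) % N_NODES,
            action % N_REMOTE,
        )
    action -= REMOTE_SIZE
    if action < CONNECT_SIZE:
        return "connect", (
            action // (N_NODES * N_PORT * N_CRED),
            (action // (N_PORT * N_CRED)) % N_NODES,
            (action // N_CRED) % N_PORT,
            action % N_CRED,
        )
    raise ValueError(f"Unsupported action: {action}")


def encode(kind: str, args: Tuple[int, ...]) -> int:
    """Inverse of decode: the index a reverse_action could return."""
    if kind == "local_vulnerability":
        source, vuln = args
        return source * N_LOCAL + vuln
    if kind == "remote_vulnerability":
        source, target, vuln = args
        return LOCAL_SIZE + (source * N_NODES + target) * N_REMOTE + vuln
    if kind == "connect":
        source, target, port, cred = args
        offset = LOCAL_SIZE + REMOTE_SIZE
        return offset + ((source * N_NODES + target) * N_PORT + port) * N_CRED + cred
    raise ValueError(f"Unknown action kind: {kind}")


# Every index must survive a decode/encode round trip.
assert all(encode(*decode(i)) == i for i in range(TOTAL))
print(TOTAL)  # 132
```

If the same check passes with the real bounds, the index arithmetic itself is unlikely to be the reason DQN fails to learn.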

forrestmckee commented on May 23, 2024

@Screamer-Y, did you get the stable-baselines example script to work? For me it runs, but the agent never learns anything using A2C or PPO.

Screamer-Y commented on May 23, 2024

Hi @forrestmckee,
Yes, it works properly on my side. I just ran the code in https://github.com/microsoft/CyberBattleSim/blob/main/notebooks/stable-baselines-agent.py without any modification.

forrestmckee commented on May 23, 2024

@Screamer-Y are you using Linux, WSL, or Docker?

I can get the script you referenced to run, but the agent never makes it off of the foothold node regardless of the number of time steps I set. I'm also getting warnings that the agent is trying to access an invalid index.

Screamer-Y commented on May 23, 2024

Hi @forrestmckee,
I'm using Ubuntu Server 20.04 LTS. I ran the script again just now, and the agent made only one successful connect action in 10000 time steps. I think the problem lies in how 'action_space' is defined in 'flatten_wrapper': it contains all attacks, even invalid ones, which is also why you keep getting warnings.
I have the same problem after turning the 'action_space' into a 'spaces.Discrete'. One possible solution is to reduce the dimensions of 'action_space', the way it is done in agent_wrapper: https://github.com/microsoft/CyberBattleSim/blob/main/cyberbattle/agents/baseline/agent_wrapper.py
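A minimal sketch of what reducing the dimensions could look like, assuming the agent only indexes nodes it currently owns (as sources) and has discovered (as targets): a small index is decoded relative to those lists and then mapped back to global node ids. The function name and the `owned`/`discovered` lists are hypothetical, and this is not taken from agent_wrapper; it only illustrates the idea for remote attacks.

```python
from typing import List, Tuple


def decode_reduced_remote(
    index: int, owned: List[int], discovered: List[int], n_remote: int
) -> Tuple[int, int, int]:
    """Hypothetical: map an index in [0, len(owned) * len(discovered) * n_remote)
    to (source_node, target_node, vulnerability) using global node ids."""
    size = len(owned) * len(discovered) * n_remote
    if not 0 <= index < size:
        raise ValueError(f"index {index} out of range for reduced size {size}")
    # Same mixed-radix layout as the flattened wrapper, but over list positions.
    source_pos, rest = divmod(index, len(discovered) * n_remote)
    target_pos, vulnerability = divmod(rest, n_remote)
    return owned[source_pos], discovered[target_pos], vulnerability


# Assumed state: node 0 is owned; nodes 0, 4 and 7 are discovered; 2 remote attacks.
print(decode_reduced_remote(5, owned=[0], discovered=[0, 4, 7], n_remote=2))  # (0, 7, 1)
```

One catch with this design is that the owned and discovered lists grow during an episode, while a gym `Discrete` space has a fixed size, so in practice it would still need padding up to some maximum plus a mask for the unused indices.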

Gabriel0402 commented on May 23, 2024

@forrestmckee I ran into the same issue. But I noticed another interesting thing: during training, although we get warnings that the agent is trying to access an invalid index, the number of nodes discovered so far keeps increasing. I think this means that A2C or PPO is actually working, since they do discover new nodes. What I don't understand is why, when the trained model is used for action prediction, it never discovers new nodes.

@Screamer-Y I don't quite understand why the warnings indicate a problem. I think you will also see the warnings if you set the logging level accordingly. We still have to discover new nodes, and the number discovered so far is less than the maximum node count, so as long as some nodes remain undiscovered we will always get warnings. I also don't understand why we have to reduce the dimensions of the action space.

forrestmckee commented on May 23, 2024

@blumu Is there a planned sample_valid_action equivalent for flattened environments/Stable-Baselines3? I believe what others and I have discovered is that the entire observation and action spaces are "fair game" for the agent to sample from at any given time. Doesn't this mean that an agent can attempt an action both to and from a node that it hasn't discovered yet? This seems to greatly increase the number of time steps required for an agent to learn.
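To put a rough number on that, the sketch below evaluates the size formula of the `Discrete` space from the `FlattenActionWrapper` posted earlier in this thread, once for all node slots and once for only the nodes discovered so far. The bounds are made-up assumptions, not CyberBattleSim's actual toy-ctf values. Referencing only discovered nodes is necessary but not sufficient for an action to be valid, so the ratio is only an upper bound on the fraction of valid uniformly sampled actions.

```python
def flattened_size(n_nodes: int, n_local: int, n_remote: int, n_port: int, n_cred: int) -> int:
    """Size of the Discrete space built in FlattenActionWrapper.__init__."""
    local = n_nodes * n_local
    remote = n_nodes * n_nodes * n_remote
    connect = n_nodes * n_nodes * n_port * n_cred
    return local + remote + connect


# Assumed bounds: 10 node slots, 3 local attacks, 5 remote attacks, 4 ports, 8 credentials.
full = flattened_size(10, 3, 5, 4, 8)
# Actions whose node indices all fall among the 2 nodes discovered so far.
usable = flattened_size(2, 3, 5, 4, 8)
print(full, usable, f"{usable / full:.3f}")  # 3730 154 0.041
```

Under these assumptions, fewer than 5% of uniformly sampled actions even refer only to known nodes, and because the remote and connect terms grow with the square of the node count, the gap widens as the maximum node count increases.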

@Screamer-Y were you able to reduce the dims of the action space like you mentioned?

Screamer-Y commented on May 23, 2024

Thanks for all the suggestions!
@Gabriel0402 I think you are right about the warnings, I didn't have a good understanding of the code at the time.

Regarding the second question, I had hoped to speed up learning by reducing the size of action_space, but after trying it I found that it made no significant difference.

@forrestmckee @blumu So I still have a question: is there a significant performance difference between the A2C or PPO methods implemented in Stable-Baselines3 and the DQN method implemented in agent_dql?

With the same settings (iteration_count=1500 and episode_count=20), I observed in toy-ctf that A2C gets an average return of no more than 40 per episode, far from the average return of about 450 in the benchmark. I would be very grateful for any suggestions on improving the performance of Stable-Baselines3.

Screamer-Y commented on May 23, 2024

@forrestmckee Thank you so much for the speedy reply. I think I've understood what you've mentioned in this comment and the previous one :)
