brax's Issues

Passive rigid objects?

Hi,

I noticed that the config parser complains when a body in the protobuf environment definition has no joints or actuators. However, I would like to spawn physically based spheres in the environment that the character can interact with but that have no joints of their own. Is this currently supported?

What I tried: I created a ball object with this protobuf string:

bodies {
  name: "$ BallBody"
  colliders {
    sphere {
      radius: 1
    }
  }
  inertia { x: 1.0 y: 1.0 z: 1.0 }
  mass: 10
}

bodies {
  name: "Aux 1"
  colliders {
    sphere {
      radius: 1
    }
  }
  inertia { x: 1.0 y: 1.0 z: 1.0 }
  mass: 0.001
}

bodies {
  name: "Ground"
  colliders {
    plane {}
  }
  inertia { x: 1.0 y: 1.0 z: 1.0 }
  mass: 1
  frozen { all: true }
}

joints {
  name: "$ BallBody_Aux 1"
  parent_offset { x: 0 y: 0 }
  child_offset { x: 0 y: 0 }
  parent: "$ BallBody"
  child: "Aux 1"
  rotation { x: 0 y: 0 z: 0 }
  angle_limit { min: -10 max: 10 }
}

actuators {
  name: "$ BallBody_Aux 1"
  joint: "$ BallBody_Aux 1"
  strength: 300.0
  torque {}
}

friction: 0.6
gravity { z: -9.8 }
angular_damping: -0.05
baumgarte_erp: 0.1

collide_include {
  first: "$ BallBody"
  second: "Ground"
}

dt: 0.0167
substeps: 10

It does spawn a ball, albeit at an odd initial position ([0, 0, 1]), from which it then falls. However, the joint is limited to between -10 and 10, so the ball is not truly a free object.

Any thoughts would be appreciated. Thanks!

URDF

Hi, any plans for urdf support (converter)? Regards

Confusion about rotation

Hi~
Is there an easy way to rotate all the rigid bodies in brax? For example, the humanoid faces along the positive x-axis by default, and I want it to face along the y-axis. How can I do that? I tried rotating all the joints by 90 degrees in the opposite direction, but the visualization results were very weird :(
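Rotating individual joints changes how the bodies sit relative to each other, which may explain the odd visualization. Turning the whole model instead means applying one rigid rotation to every body: rotate each position about the origin and pre-multiply each orientation quaternion by the same rotation. The sketch below assumes quaternions stored as (w, x, y, z) and per-body `pos` [n, 3] and `rot` [n, 4] arrays, as in the qp objects mentioned elsewhere on this page; the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def quat_mul(a, b):
    """Hamilton product a * b of quaternions stored as (w, x, y, z)."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return jnp.array([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ])


def rotate_about_z(pos, rot, degrees):
    """Rigidly rotate all bodies (pos: [n, 3], rot: [n, 4]) about the world z-axis."""
    angle = jnp.deg2rad(degrees)
    q = jnp.array([jnp.cos(angle / 2), 0.0, 0.0, jnp.sin(angle / 2)])
    c, s = jnp.cos(angle), jnp.sin(angle)
    rot_matrix = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    new_pos = pos @ rot_matrix.T                        # rotate every position about the origin
    new_rot = jax.vmap(lambda r: quat_mul(q, r))(rot)   # pre-multiply: rotation in the world frame
    return new_pos, new_rot
```

A +90 degree rotation about z maps +x onto +y, so for the humanoid example something like `pos, rot = rotate_about_z(qp.pos, qp.rot, 90.0)` followed by `qp = qp.replace(pos=pos, rot=rot)` would be the idea (field names assumed). Nonzero initial velocities would need the same rotation.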

Impulse vs XPBD; and assorted thoughts

Hi, I just found this project. I don't really have a concrete issue, but there are some design questions I'd like to discuss.

I've recently started work on a similar project in JAX that tries to use XPBD rather than impulse-based joints. It's too early to tell exactly how that will work out, keeping in mind things like the time-of-impact issues raised in the difftaichi paper. Not that this is a concern for the specific application I have in mind, but I think it does point to a more general issue: the details of the simulator matter for its effective differentiability. So we will see how that works out; hopefully I will know in a week or two.

The reason I went with this approach rather than a spring- or impulse-based one is that it should provide quite high stiffness with unconditionally stable simulations. At the same time, it is not restricted to infinitely stiff joints but allows compliance where required, and it has a rather simple implementation using an explicit integrator that maps well to JAX, and to differentiability in general, I think. The only tension is that the XPBD authors recommend Gauss-Seidel relaxation because it converges faster than Jacobi, whereas in JAX I'd rather have the increased parallelism of a Jacobi iteration for solving the constraints. I suppose some kind of graph-partitioned block Gauss-Seidel would be optimal on an accelerator, though for simulations with a handful of constraints it might not matter. One thing to optimize along the way, I suppose.
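For readers unfamiliar with XPBD, here is a minimal sketch of its core update for a single distance constraint between two particles, following the published XPBD formulation (compliance alpha scaled as alpha / dt^2, and an accumulated Lagrange multiplier per constraint). It is not Brax code and not the poster's implementation; the function name is hypothetical.

```python
import jax.numpy as jnp


def xpbd_distance_step(x1, x2, w1, w2, rest_length, lam, compliance, dt):
    """One XPBD projection of the constraint C = |x1 - x2| - rest_length.

    w1, w2 are inverse masses; lam is the constraint's accumulated multiplier.
    compliance = 0 gives an infinitely stiff constraint. Assumes x1 != x2.
    """
    d = x1 - x2
    dist = jnp.linalg.norm(d)
    n = d / dist                       # gradient of C w.r.t. x1 (its negative w.r.t. x2)
    c = dist - rest_length             # constraint violation C(x)
    alpha_tilde = compliance / dt ** 2
    dlam = (-c - alpha_tilde * lam) / (w1 + w2 + alpha_tilde)
    x1 = x1 + w1 * dlam * n
    x2 = x2 - w2 * dlam * n
    return x1, x2, lam + dlam
```

In a Jacobi-style variant, every constraint would compute its correction from the same positions in parallel (corrections to a shared particle are then typically averaged), whereas Gauss-Seidel applies them one after another, which is the parallelism trade-off described above.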

Is XPBD a simulation paradigm you have considered, and do you have any thoughts on it? Perhaps you have already tried it and found it wanting in some way, in which case I can save myself the trouble. I just noticed that the second author of the above paper is also the author of the recent GradSim. They make no mention of the XPBD paradigm, though; given that they write their simulator as handwritten CUDA kernels wrapped in PyTorch, it's also a very different approach from Brax in general. They seem to treat collision resolution in a purely spring-like manner, fully resolving the contact dynamics, which I guess isn't great for timestep size; but if you are biting the bullet of using an implicit solver anyway, I suppose it is very clean in terms of differentiability. No code is available yet, so it's hard to tell exactly what is going on.

Given that there is no shortage of physics simulation paradigms, it might be useful to generalize the physics integrator: make it swappable within brax, so that the merits of different paradigms can easily be compared for various applications. Come to think of it, it probably makes sense to port my code to a fork of brax to avoid reinventing a few wheels (the overall design is similar anyway). If I manage to pull off that level of generality in a backwards-compatible manner, I could make a PR here, if there is any interest in that?

You mention in your paper that a simple optimization that differentiates through the simulation does not result in trainable walking policies. I wonder what the crucial difference is with difftaichi's walker, or the one in GradSim. Could it be some subtlety relating to time-of-impact collision dynamics, as in difftaichi, or some other subtlety of the simulation? Again, I think a swappable physics backend would be useful here. Implementing a simple, purely mass-spring, explicitly integrated non-rigid-body engine should be trivial (I haven't done anything with sparse matrices in JAX/TPU before; perhaps the scatter/gather approach you use for joints would also scale fine to a large collection of springs? It should be fine for academic purposes in any case), although the interfaces need to be sufficiently generic, given that you'd have a somewhat different data structure. Do you think that level of generality is feasible, so that we could just plug in a mass-spring backend while retaining a reasonable degree of backwards compatibility? I haven't looked at the Brax code enough yet to tell.
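To make the scatter/gather idea concrete, here is a minimal sketch of a mass-spring step in JAX that avoids sparse matrices entirely: spring endpoints are gathered by index, and per-spring forces are scattered back onto particles with `jax.ops.segment_sum`. It illustrates the general technique, not how Brax implements joints; all names are hypothetical.

```python
import jax
import jax.numpy as jnp


def spring_forces(pos, vel, edges, rest_len, k, damping):
    """Accumulate damped spring forces on particles.

    pos, vel: [n, 3]; edges: [m, 2] integer particle indices; rest_len: [m].
    """
    i, j = edges[:, 0], edges[:, 1]
    d = pos[j] - pos[i]                                    # gather endpoint positions
    length = jnp.linalg.norm(d, axis=-1, keepdims=True)
    direction = d / jnp.maximum(length, 1e-9)              # unit vector from i to j
    rel_vel = jnp.sum((vel[j] - vel[i]) * direction, axis=-1, keepdims=True)
    # Force on particle i; particle j receives the opposite force.
    f = (k * (length - rest_len[:, None]) + damping * rel_vel) * direction
    n = pos.shape[0]
    # Scatter-add the per-spring forces back onto the particles.
    return (jax.ops.segment_sum(f, i, num_segments=n)
            + jax.ops.segment_sum(-f, j, num_segments=n))


def step(pos, vel, mass, edges, rest_len, k, damping, gravity, dt):
    """Semi-implicit (symplectic) Euler: update velocities first, then positions."""
    force = spring_forces(pos, vel, edges, rest_len, k, damping) + mass[:, None] * gravity
    vel = vel + dt * force / mass[:, None]
    pos = pos + dt * vel
    return pos, vel
```

Because spring forces come in equal and opposite pairs, total momentum is conserved in the absence of gravity, which makes a handy sanity check.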

to_tf_model returns: NotEncodableError

Hi,
When I try to save it as a TF model with:

from brax.io import export
export.to_tf_model( "/content/model/", inference_fn, params, state.obs, state.rng)

The error then is:

NotEncodableError                         Traceback (most recent call last)

<ipython-input-26-74bd7d32a550> in <module>()
      1 get_ipython().system('pip install tensorflow')
      2 from brax.io import export
----> 3 export.to_tf_model( "/content/model/", jit_inference_fn, params, state.obs, state.rng)

14 frames

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py in _map_structure(self, pyobj, coders)
     84         return do(pyobj, recursion_fn)
     85     raise NotEncodableError(
---> 86         "No encoder for object [%s] of type [%s]." % (str(pyobj), type(pyobj)))
     87 
     88   def encode_structure(self, nested_structure):

NotEncodableError: No encoder for object [FrozenDict({
    params: {
        hidden_0: {
            bias: TensorSpec(shape=(32,), dtype=tf.float32, name='args_0/1/params/hidden_0/bias'),
            kernel: TensorSpec(shape=(15, 32), dtype=tf.float32, name='args_0/1/params/hidden_0/kernel'),
        },
        hidden_1: {
            bias: TensorSpec(shape=(32,), dtype=tf.float32, name='args_0/1/params/hidden_1/bias'),
            kernel: TensorSpec(shape=(32, 32), dtype=tf.float32, name='args_0/1/params/hidden_1/kernel'),
        },
        hidden_2: {
            bias: TensorSpec(shape=(32,), dtype=tf.float32, name='args_0/1/params/hidden_2/bias'),
            kernel: TensorSpec(shape=(32, 32), dtype=tf.float32, name='args_0/1/params/hidden_2/kernel'),
        },
        hidden_3: {
            bias: TensorSpec(shape=(32,), dtype=tf.float32, name='args_0/1/params/hidden_3/bias'),
            kernel: TensorSpec(shape=(32, 32), dtype=tf.float32, name='args_0/1/params/hidden_3/kernel'),
        },
        hidden_4: {
            bias: TensorSpec(shape=(14,), dtype=tf.float32, name='args_0/1/params/hidden_4/bias'),
            kernel: TensorSpec(shape=(32, 14), dtype=tf.float32, name='args_0/1/params/hidden_4/kernel'),
        },
    },
})] of type [<class 'flax.core.frozen_dict.FrozenDict'>].

What could be wrong?

Thanks

Add gyroscopic term to physics integrator

Looking at the code in integrator.py, I could not find the correct equation for the evolution of the angular velocity (ang in the code).

According to Featherstone (as well as my own derivation and my university's textbook), the correct equation of motion should be:

$$I\dot{\omega} = \tau - \omega \times (I\omega)$$

With tau being the external torques, omega the angular velocity of the body, and I the inertia matrix in the local frame. The omega x (I omega) part is the gyroscopic term.

I guess it does not make much of a difference when most degrees of freedom are constrained, but correct math is preferable.

Are there any plans to include this version of the equations in Brax? If not, can you give me some pointers on how one might go about implementing it?
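As a starting point, here is a minimal sketch of Euler's rotation equation, I dω/dt = τ - ω × (Iω), and a forward-Euler update of ω that includes the gyroscopic term. It is standalone and not wired into Brax's integrator; the function names are hypothetical. The same form also holds in the world frame if I is replaced by the world-frame inertia R I Rᵀ, which matters if the simulator stores angular velocity in world coordinates.

```python
import jax.numpy as jnp


def angular_acceleration(omega, torque, inertia):
    """Solve I * domega/dt = tau - omega x (I * omega) for domega/dt.

    omega, torque: [3]; inertia: [3, 3], expressed in the same frame as omega.
    """
    gyroscopic = jnp.cross(omega, inertia @ omega)
    return jnp.linalg.solve(inertia, torque - gyroscopic)


def integrate_angular_velocity(omega, torque, inertia, dt):
    """Explicit (forward) Euler update of omega including the gyroscopic term."""
    return omega + dt * angular_acceleration(omega, torque, inertia)
```

For a body with equal principal moments the gyroscopic term vanishes (ω × ω = 0), which may be part of why omitting it goes unnoticed for many models. With an explicit step, fast-spinning asymmetric bodies can gain energy over time, so a smaller dt or an implicit treatment of this term might be worth considering.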

Differentiating n-link arm produces NaNs

I took the double pendulum from the basic tutorial as a starting point for an n-link, planar robot arm. I made the amount of links customisable and added torque-based actuators to the links.

This worked well for 2 and 3 links, but with 4 links the gradient with respect to the actuation becomes NaN.

Here is a colab notebook demonstrating the problem. The code can also be found below for completeness. At the very bottom there is also the config.

At the moment, I am not sure if it is my limited understanding of dynamics or if there is a bug in brax, jax or even xla. Someone looking at it would greatly help.

On a side note, the compilation times also appear quite long. I have tested the case on CPU only so far, though.

Edit: Fixed two tiny bugs in the script, message stays the same.

import functools

import brax
import jax
import jax.numpy as jnp


def make_nlink_arm_system(n_links=2):
    arm = brax.Config(dt=0.01, substeps=5)

    body_names = [
        "anchor",
        *[f"middle-{i}" for i in range(n_links - 1)],
        "end",
    ]

    bodies = []
    for body_name in body_names:
        body = arm.bodies.add(name=body_name, mass=1.0)
        bodies.append(body)
        body.inertia.x, body.inertia.y, body.inertia.z = 1, 1, 1
    bodies[0].frozen.all = True

    for body in bodies[1:-1]:
        cap = body.colliders.add().capsule
        cap.radius, cap.length = 0.5, 1

    for i, (parent, child) in enumerate(zip(body_names[:-1], body_names[1:])):
        joint = arm.joints.add(
            name=f"joint-{i}",
            parent=parent,
            child=child,
            stiffness=10000,
            angular_damping=10,
        )
        joint.angle_limit.add(min=-180, max=180)
        joint.child_offset.z = 1.5
        joint.rotation.z = 90

        arm.actuators.add(
            name=f"actuator-{i}", joint=f"joint-{i}", strength=100
        ).torque.SetInParent()

    return brax.System(arm)


@functools.partial(jax.jit, static_argnums=(0,))
def perform_rollout(f_step, initial_state, plan):
    def f(state, control):
        next_state = f_step(state, control)[0]
        return next_state, next_state

    _, states = jax.lax.scan(f, initial_state, plan)

    return states


def nans(sys, n_steps):
    initial_state = sys.default_qp()
    plan = jnp.ones((n_steps, len(sys.config.actuators)))
    df = jax.grad(lambda *args: perform_rollout(*args).pos.sum(), argnums=2)

    return df(sys.step, initial_state, plan)


if __name__ == "__main__":
    N_STEPS = 16
    sys = make_nlink_arm_system(4)

    res = nans(sys, N_STEPS)   # contains only NaNs for me

The resulting config:
bodies {
  name: "anchor"
  inertia {
    x: 1.0
    y: 1.0
    z: 1.0
  }
  mass: 1.0
  frozen {
    position {
      x: 1.0
      y: 1.0
      z: 1.0
    }
    rotation {
      x: 1.0
      y: 1.0
      z: 1.0
    }
    all: true
  }
}
bodies {
  name: "middle-0"
  colliders {
    capsule {
      radius: 0.5
      length: 1.0
    }
  }
  inertia {
    x: 1.0
    y: 1.0
    z: 1.0
  }
  mass: 1.0
  frozen {
    position {
    }
    rotation {
    }
  }
}
bodies {
  name: "middle-1"
  colliders {
    capsule {
      radius: 0.5
      length: 1.0
    }
  }
  inertia {
    x: 1.0
    y: 1.0
    z: 1.0
  }
  mass: 1.0
  frozen {
    position {
    }
    rotation {
    }
  }
}
bodies {
  name: "middle-2"
  colliders {
    capsule {
      radius: 0.5
      length: 1.0
    }
  }
  inertia {
    x: 1.0
    y: 1.0
    z: 1.0
  }
  mass: 1.0
  frozen {
    position {
    }
    rotation {
    }
  }
}
bodies {
  name: "end"
  inertia {
    x: 1.0
    y: 1.0
    z: 1.0
  }
  mass: 1.0
  frozen {
    position {
    }
    rotation {
    }
  }
}
joints {
  name: "joint-0"
  stiffness: 10000.0
  parent: "anchor"
  child: "middle-0"
  child_offset {
    z: 1.5
  }
  rotation {
    z: 90.0
  }
  angular_damping: 10.0
  angle_limit {
    min: -180.0
    max: 180.0
  }
}
joints {
  name: "joint-1"
  stiffness: 10000.0
  parent: "middle-0"
  child: "middle-1"
  child_offset {
    z: 1.5
  }
  rotation {
    z: 90.0
  }
  angular_damping: 10.0
  angle_limit {
    min: -180.0
    max: 180.0
  }
}
joints {
  name: "joint-2"
  stiffness: 10000.0
  parent: "middle-1"
  child: "middle-2"
  child_offset {
    z: 1.5
  }
  rotation {
    z: 90.0
  }
  angular_damping: 10.0
  angle_limit {
    min: -180.0
    max: 180.0
  }
}
joints {
  name: "joint-3"
  stiffness: 10000.0
  parent: "middle-2"
  child: "end"
  child_offset {
    z: 1.5
  }
  rotation {
    z: 90.0
  }
  angular_damping: 10.0
  angle_limit {
    min: -180.0
    max: 180.0
  }
}
actuators {
  name: "actuator-0"
  joint: "joint-0"
  strength: 100.0
  torque {
  }
}
actuators {
  name: "actuator-1"
  joint: "joint-1"
  strength: 100.0
  torque {
  }
}
actuators {
  name: "actuator-2"
  joint: "joint-2"
  strength: 100.0
  torque {
  }
}
actuators {
  name: "actuator-3"
  joint: "joint-3"
  strength: 100.0
  torque {
  }
}
dt: 0.009999999776482582
substeps: 5
frozen {
}
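One common source of NaN gradients in JAX physics code (not confirmed as the cause here) is a norm or square root evaluated at exactly zero, for example when normalizing a zero-length vector. The forward value is fine, but the derivative 0.5 / sqrt(0) is infinite, and inf * 0 gives NaN. The sketch below reproduces this in isolation and shows the "double where" workaround; `naive_norm` and `safe_norm` are illustrative names. Enabling `jax.config.update("jax_debug_nans", True)` before running the original script can help locate the first operation that produces a NaN.

```python
import jax
import jax.numpy as jnp

# jax.config.update("jax_debug_nans", True)  # raises at the first NaN-producing op


def naive_norm(x):
    return jnp.sqrt(jnp.sum(x ** 2))


def safe_norm(x, eps=1e-12):
    """Norm whose gradient is finite (zero) at x == 0."""
    sq = jnp.sum(x ** 2)
    is_zero = sq < eps
    # Inner where keeps sqrt away from 0 so its derivative stays finite;
    # outer where restores the correct forward value.
    safe_sq = jnp.where(is_zero, 1.0, sq)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))


zero = jnp.zeros(3)
g_naive = jax.grad(naive_norm)(zero)   # NaN in every component
g_safe = jax.grad(safe_norm)(zero)     # zeros
```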

Gym API?

Hey, noob question here:

Is there an example showing how to interface with brax the same way you would with a gym Env?

If not, then why? Is there any particular reason why you make your own Env class rather than create a new gym.Env subclass?
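One plausible reason for a custom Env class (an interpretation, not a statement from the Brax authors) is that pure functions of an explicit state can be jit-compiled, vmapped across many copies, and differentiated, whereas `gym.Env` hides state inside a mutable object. A thin stateful facade can still give a gym-like `reset()`/`step()` interface. The sketch below uses a toy functional environment (`fn_reset`, `fn_step`, `GymStyleEnv` are all hypothetical) and does not depend on gym itself.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical functional environment state (a JAX pytree)."""
    pos: jnp.ndarray
    vel: jnp.ndarray
    steps: jnp.ndarray


def fn_reset(rng):
    pos = jax.random.uniform(rng, (1,), minval=-1.0, maxval=1.0)
    return State(pos, jnp.zeros(1), jnp.zeros((), jnp.int32))


def fn_step(state, action):
    vel = state.vel + 0.1 * action
    pos = state.pos + 0.1 * vel
    return State(pos, vel, state.steps + 1)


class GymStyleEnv:
    """Stateful reset()/step() facade over pure functions, mimicking the gym API."""

    def __init__(self, seed=0, max_steps=100):
        self._rng = jax.random.PRNGKey(seed)
        self._reset = jax.jit(fn_reset)
        self._step = jax.jit(fn_step)
        self._max_steps = max_steps
        self._state = None

    def reset(self):
        self._rng, key = jax.random.split(self._rng)
        self._state = self._reset(key)
        return self._obs()

    def step(self, action):
        self._state = self._step(self._state, jnp.asarray(action, dtype=jnp.float32))
        reward = -float(jnp.abs(self._state.pos[0]))
        done = int(self._state.steps) >= self._max_steps
        return self._obs(), reward, done, {}

    def _obs(self):
        return jnp.concatenate([self._state.pos, self._state.vel])
```

The trade-off is that each `step()` call crosses the Python/device boundary, so a facade like this gives up much of the throughput that batched, jitted rollouts provide.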

Applying external forces

Hi,

I think Brax currently doesn't support applying external force/torque to the bodies. I'm getting around this by setting linear and angular velocities with qp.replace(). It would be nice to have an explicit function like PyBullet's applyExternalForce. Any plan to implement something like this?
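The velocity workaround can be made physically consistent by converting a force and torque held for one step into velocity changes: Δv = F dt / m and Δω = I⁻¹ τ dt, where a force applied away from the center of mass also contributes a torque r × F. The helper below is a hypothetical sketch, not a Brax API, and assumes world-frame inverse inertia tensors.

```python
import jax.numpy as jnp


def apply_external_wrench(vel, ang, mass, inv_inertia, force, torque, offset, dt):
    """Hypothetical helper: turn a force/torque held for one step into velocity changes.

    Shapes for n bodies: vel, ang, force, torque, offset: [n, 3]; mass: [n];
    inv_inertia: [n, 3, 3] (world frame). `offset` is the point of application
    relative to each body's center of mass.
    """
    total_torque = torque + jnp.cross(offset, force)   # off-center forces also spin the body
    vel = vel + dt * force / mass[:, None]
    ang = ang + dt * jnp.einsum('nij,nj->ni', inv_inertia, total_torque)
    return vel, ang
```

Assuming the qp fields are named `vel` and `ang`, the result would then be written back with `qp = qp.replace(vel=vel, ang=ang)` before the next step.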

Differentiation wrt system parameters

Hey,

I can see how brax can be used to differentiate with respect to the system state. I wonder if there is a nice way to also differentiate with respect to, e.g., the mass of a body. Taking the basic tutorial as an example, I would like to have something like step(sys, ..., ball_mass=10.) that is a pure JAX function.

I have found brax.physics.bodies.Body, which could be adapted via a .replace() call. However, I don't see a) how to get to it starting from a System instance, and b) how to update that instance with a replaced version without potentially breaking things.
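The general pattern, assuming the physical parameters live in an immutable pytree (which the `.replace()` method on Body suggests), is to pass the parameter in as a function argument, build an updated copy of the pytree inside the function, and differentiate that function. The sketch uses a `NamedTuple` and a toy 1-D simulator as stand-ins; `BodyParams`, `step`, and `final_position` are hypothetical, and locating Brax's actual Body inside a System is not shown.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BodyParams(NamedTuple):
    """Hypothetical stand-in for an immutable pytree of physical parameters."""
    mass: jnp.ndarray


def step(params, pos, vel, force, dt=0.01):
    """Toy 1-D dynamics: a constant force pushes a frictionless point mass."""
    vel = vel + dt * force / params.mass
    pos = pos + dt * vel
    return pos, vel


def final_position(ball_mass, params, n_steps=100):
    # _replace returns a new pytree and leaves `params` untouched, so this is
    # a pure function of ball_mass (flax struct dataclasses offer .replace()).
    params = params._replace(mass=ball_mass)

    def body(carry, _):
        pos, vel = carry
        return step(params, pos, vel, force=1.0), None

    (pos, _), _ = jax.lax.scan(body, (jnp.zeros(()), jnp.zeros(())), None, length=n_steps)
    return pos


params = BodyParams(mass=jnp.array(1.0))
d_pos_d_mass = jax.grad(final_position)(jnp.array(10.0), params)
```

Here the final position is 0.505 / m analytically, so the gradient at m = 10 should be close to -0.00505.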

Unstable / weird ball rolling and colliding behaviour

Hi,

For context, I'm working on making a billiard simulator with Brax. I'll need good rolling / collision simulation for that. However, I'm observing strange behaviours when it comes to having two capsules roll on the ground (plane) and running into each other. See this video for example:

https://www.dropbox.com/s/zhppu5sksben1ex/Screencast%202021-09-13%2014%3A06%3A51.mp4?dl=0

The rolling ball seems to be spinning in place sometimes, and the collision doesn't seem to transfer the momentum from the first ball to the other. Is this due to something not being implemented yet, or do I have a wrong configuration somewhere? This is my config: https://pastebin.com/QPhVD8vS

Thanks.

endless visualization of trajectory

Hi,
Is it possible to add a feature that displays the trajectory of the learned inference function in an endless loop?

At the moment it seems to be fixed to "20"; I think that is seconds?

thx

Inverse dynamics?

Brax (obviously) supports forward dynamics, but I was wondering if it has any support for inverse dynamics as well? I'm looking for something akin to MuJoCo's http://www.mujoco.org/book/APIreference.html#mj_inverse. Looking through the paper and code, I have not been able to find anything yet.

Assuming this is not yet implemented, what do you think implementing it would entail?
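One differentiable-simulator approach to inverse dynamics (different from MuJoCo's analytic mj_inverse, and only a sketch) is to invert the forward step numerically: linearize the next state with respect to the action using `jax.jacfwd` and solve for the action that reaches a target state. For dynamics that are affine in the action, one solve is exact; otherwise this is a single Gauss-Newton step that would need iterating. The toy forward step below (a force-driven point mass) and all names are hypothetical.

```python
import jax
import jax.numpy as jnp

DT = 0.01
MASS = 2.0
GRAVITY = jnp.array([0.0, 0.0, -9.81])


def forward_step(vel, action):
    """Toy forward dynamics: a point mass driven by a force `action`."""
    return vel + DT * (action / MASS + GRAVITY)


def inverse_step(vel, target_vel):
    """Recover the action that takes vel to target_vel in one step.

    Exact when forward_step is affine in the action (as here); for nonlinear
    dynamics this is one Gauss-Newton step and would need iterating.
    """
    zero = jnp.zeros(3)
    jac = jax.jacfwd(forward_step, argnums=1)(vel, zero)   # d next_vel / d action
    residual = target_vel - forward_step(vel, zero)
    return jnp.linalg.pinv(jac) @ residual
```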

Training notebook not working due to JAX updates

JAX pushed a new update yesterday, after which the Brax training script no longer works; it fails with the following message. I am not sure whether this can be resolved on the Brax side or whether we need help from the JAX team.

[Screenshot of the error message, 2021-09-24]

Rendering to the HTML doesn't work

When I try to play a saved HTML visualization on my local computer, I always get this error in the HTML page:
Uncaught TypeError: Cannot read property '0' of undefined
at demo.html:191
at Array.map ()
at demo.html:191
at Array.forEach ()
at createAnimationClip (demo.html:189)
at demo.html:314

Everything works fine in Google Colab.

physics_test.py fails

Congrats, this is a very nice repository!

Running physics_test.py results in a few failing tests.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "D:\miniconda3\lib\site-packages\absl\testing\parameterized.py", line 316, in bound_param_test
    return test_method(self, *testcase_params)
  File "D:\dev\brax_official\brax\tests\physics_test.py", line 54, in test_pendulum_period
    config.dt = 2 * jnp.pi * jnp.sqrt((.4 * radius**2 + 1.) / 9.8)
  File "D:\miniconda3\lib\site-packages\google\protobuf\internal\python_message.py", line 715, in field_setter
    raise TypeError(
TypeError: Cannot set brax.Config.dt to DeviceArray(2.0086575, dtype=float32): DeviceArray(2.0086575, dtype=float32) has type <class 'jaxlib.xla_extension.DeviceArray'>, but expected one of: numbers.Real

Using a hard-coded config.dt = 1.0066762 makes the test pass.
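Instead of hard-coding the value, a likely fix is to convert the 0-d JAX array to a plain Python float before assigning it, since the protobuf setter only accepts `numbers.Real` values. In the sketch, `radius = 0.5` is a placeholder (the real test is parameterized) and the assignment to `config.dt` is shown only as a comment.

```python
import numbers

import jax.numpy as jnp

radius = 0.5  # placeholder; physics_test.py parameterizes over radius
dt = 2 * jnp.pi * jnp.sqrt((.4 * radius**2 + 1.) / 9.8)  # a 0-d JAX array

# float() turns the 0-d array into a Python float, which protobuf accepts:
# config.dt = float(2 * jnp.pi * jnp.sqrt((.4 * radius**2 + 1.) / 9.8))
dt_value = float(dt)
is_real = isinstance(dt_value, numbers.Real)  # True for a Python float
```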

This is on Windows 10, using latest JAX master on CUDA 11.0, cudnn 8.1.1 (but same for older JAX 0.1.61).

By the way, I had some fun training the Ant using a local runtime on Windows; attached are the timings (RTX 2080).
I also created some precompiled JAX wheels of the latest master branch for other poor Windows souls, see
https://github.com/erwincoumans/jax/releases/tag/jax-v0.1.68_windows
[Image: ant_trained]

Demo request: optimizing for mass/system identification

Dear Brax team,

Since Brax is fully differentiable, I thought it'd be possible to use it like DiffTaichi or GradSim for system identification (e.g. determining the mass of an object from a trajectory and known force) but I couldn't find any example for this.
Do you happen to have any demo or tips for this?

From the top of my head I would do something like this:
Let's say the task is to estimate the mass of a cube that received a push. The size of the cube is known, same as the friction coefficient, and the applied force.

  1. Record a rollout of the positions of that cube after the push.
  2. Reset the cube to its starting position and set the mass to a random value.
  3. (a) Generate a rollout, (b) measure the MSE between the newly observed positions over time and the GT positions, and (c) calculate the gradient with respect to the cube's mass property and apply it to the mass.
  4. Repeat (3) until the loss no longer decreases.
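The steps above can be sketched end to end with a toy differentiable rollout standing in for Brax: a frictionless 1-D cube pushed by a known force. Optimizing log(mass) rather than mass keeps the estimate positive, and a fixed iteration count replaces the stopping rule for brevity. The constants, learning rate, and function names are illustrative assumptions; with a real Brax system, the rollout would come from stepping the system with the mass swapped in.

```python
import jax
import jax.numpy as jnp

DT = 0.01
N_STEPS = 100
PUSH_FORCE = 5.0   # known force (N), applied during the first PUSH_STEPS steps
PUSH_STEPS = 10


def rollout(mass):
    """Toy stand-in for a simulator rollout: a frictionless 1-D cube pushed by a known force."""
    def step(carry, t):
        x, v = carry
        f = jnp.where(t < PUSH_STEPS, PUSH_FORCE, 0.0)
        v = v + DT * f / mass
        x = x + DT * v
        return (x, v), x

    _, xs = jax.lax.scan(step, (jnp.zeros(()), jnp.zeros(())), jnp.arange(N_STEPS))
    return xs


# Step 1: record a "ground-truth" trajectory with the (unknown) mass of 2.0.
gt_positions = rollout(jnp.array(2.0))


def loss(log_mass):
    # MSE between simulated and recorded positions over time.
    return jnp.mean((rollout(jnp.exp(log_mass)) - gt_positions) ** 2)


grad_fn = jax.jit(jax.grad(loss))

# Step 2: start from a wrong guess (mass = 1.0).
log_mass = jnp.array(0.0)
# Step 3: plain gradient descent on the MSE.
for _ in range(200):
    log_mass = log_mass - 5.0 * grad_fn(log_mass)

estimated_mass = float(jnp.exp(log_mass))
```

With contact and friction in the loop, the loss landscape is usually less benign than in this frictionless toy, so a smaller step size or an adaptive optimizer may be needed.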

Best,
Florian

Config format, Denavit-Hartenberg, UR5e

Hey,

I am currently experimenting with implementing robot arms that are specified in the Denavit-Hartenberg convention according to Craig. My plan was to write a function that takes such a sequence of joint specifications and assembles a config in which the corresponding arm is implemented, with everything nicely connected by capsules.

I got a little confused along the road, as I started out with the official UR5e DH specification (found on this page) and tried to match it to the config of the UR5e environment. Specifically, "joint2" in the UR DH parameters has an "a" value of -0.425, while the second joint in the config has a parent_offset.y = 13.8. From my view, these values should be the same, up to units.

Another potential explanation is me not understanding how Brax uses configs to build kinematic chains.

A thing that I struggle with is how a joint's axis of rotation maps to the transformation between the two connected bodies' respective frames. In my experience, these rotations are often given as the angle between the z-axes of two consecutive frames, measured around the x-axis of the first frame.

I am sorry that I don't have very clear questions that allow yes/no answers, but at this point I struggle quite a bit with finding how brax relates two links through a joint. I'd very much appreciate a brief explanation of this.
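For reference on the DH side (how Brax maps these onto parent_offset/child_offset is not shown here), Craig's modified convention builds the transform from frame i-1 to frame i as Rot_x(alpha_{i-1}) · Trans_x(a_{i-1}) · Rot_z(theta_i) · Trans_z(d_i), which is exactly the "angle between consecutive z-axes around the previous x-axis" described above. Tables in the classic (distal) convention apply the elementary transforms in a different order. A minimal sketch with hypothetical function names:

```python
import jax.numpy as jnp


def craig_dh_transform(alpha_prev, a_prev, theta, d):
    """Transform from frame i-1 to frame i in Craig's (modified) DH convention:
    Rot_x(alpha_{i-1}) @ Trans_x(a_{i-1}) @ Rot_z(theta_i) @ Trans_z(d_i)."""
    ca, sa = jnp.cos(alpha_prev), jnp.sin(alpha_prev)
    ct, st = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([
        [ct, -st, 0.0, a_prev],
        [st * ca, ct * ca, -sa, -sa * d],
        [st * sa, ct * sa, ca, ca * d],
        [0.0, 0.0, 0.0, 1.0],
    ])


def forward_kinematics(dh_rows, joint_angles):
    """Chain the per-joint transforms; each row is (alpha_prev, a_prev, theta_offset, d)."""
    transform = jnp.eye(4)
    for (alpha_prev, a_prev, theta_offset, d), q in zip(dh_rows, joint_angles):
        transform = transform @ craig_dh_transform(alpha_prev, a_prev, theta_offset + q, d)
    return transform
```

Comparing the body positions this produces against those in a Brax rollout of the same arm (after accounting for units) could help pin down how Brax's offsets relate to a and d.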

Manually setting pose

Is it currently possible to set object poses manually? I'm trying to use qp.replace() within env.reset() but whatever change I make seems to be ignored. For example, I'm trying to make the 0th object 1 meter higher:

    def reset(self, rng: jnp.ndarray) -> env.State:
        """Resets the environment to an initial state."""
        qp = self.sys.default_qp()
        info = self.sys.info(qp)
        reward, done, steps = jnp.zeros(3)
        metrics = {}

        jax.ops.index_add(qp.pos, jax.ops.index[0, :], jnp.array([0, 0, 1]))
        qp, info = self.sys.step(qp, jnp.zeros(self.action_size).reshape(1, -1))
        obs = self._get_obs(qp, info)
        new_state = env.State(rng, qp, info, obs, reward, done, steps, metrics)
        return new_state

I suspected something inside system.py might support this, but I couldn't find it.
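In the reset() above, the return value of `jax.ops.index_add` is discarded. JAX arrays are immutable, so the call returns a new array and leaves `qp.pos` unchanged; the result has to be assigned back into a new qp. Newer JAX versions replace `jax.ops.index_add` with the `.at[...].add(...)` syntax. The sketch uses a hypothetical frozen-dataclass stand-in for qp with a `.replace()` method, mirroring the `qp.replace()` already used in this issue.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class QP:
    """Hypothetical stand-in for brax's qp: an immutable container with .replace()."""
    pos: jnp.ndarray
    rot: jnp.ndarray

    def replace(self, **kwargs):
        return dataclasses.replace(self, **kwargs)


qp = QP(pos=jnp.zeros((3, 3)), rot=jnp.tile(jnp.array([1., 0., 0., 0.]), (3, 1)))

# .at[...].add(...) returns a NEW array; keep it and build a new qp from it.
new_pos = qp.pos.at[0].add(jnp.array([0., 0., 1.]))
qp = qp.replace(pos=new_pos)
```

In the reset() method, the equivalent would be `qp = qp.replace(pos=qp.pos.at[0].add(jnp.array([0, 0, 1])))` before calling `self.sys.step`.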

bug with "mass"

Hi,
with this config:

_SYSTEM_CONFIG = """
bodies {
  name: "box_1"
  colliders {
    box {
      halfsize {
          x: 0.5
          y: 0.5
          z: 0.5
        }
    }
  }
  inertia {
    x: 1
    y: 1
    z: 1
  }
  mass: 0.03
}

bodies {
  name: "box_2"
  colliders {
    box {
      halfsize {
          x: 0.25
          y: 0.25
          z: 0.25
        }
    }
  }
  inertia {
    x: 1
    y: 1
    z: 1
  }
  mass: 0.03
}
bodies {
  name: "Ground"
  colliders {
    plane {
    }
  }
  frozen {
    all: true
  }
}
bodies {
  name: "target"
  colliders {
    sphere {
      radius: 0.009
    }
  }
  frozen { all: true }
}

joints {
  name: "joint0"
  angle_limit {
      min: -60
      max: 60
  }
  rotation {
    z: 90
  }
  
  parent_offset {
    y: -0.6
  }
  child_offset {
      z: 0.1
  }
  
  parent: "box_1"
  child: "box_2"
  stiffness: 5000.0
  angular_damping: 35
}

actuators {
  name: "joint0"
  joint: "joint0"
  strength: 300.0
  torque {
  }
}

friction: 0.6
gravity {
  z: -9.81
}
baumgarte_erp: 0.1
angular_damping: -0.05

collide_include {
  first: "box_1"
  second: "Ground"
}
collide_include {
  first: "box_2"
  second: "Ground"
}


dt: 0.02
substeps: 4
"""

You can see the two boxes in training.ipynb with html.render, but after training, if you visualize a trajectory, there is only a grey image.

If you stop and reset the trajectory video and then play it, the x, y, z, etc. values shown in the controls quickly go negative until they reach NaN.

If you change the "mass" values of both boxes to 1.0, everything looks OK.
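One thing worth checking (an assumption, not a confirmed diagnosis) is that the config pairs a mass of 0.03 with moments of inertia of 1, which is far from what a uniform solid box of that size and mass would have, while a strength-300 actuator drives these very light bodies. For a solid box with full side lengths 2·hx, 2·hy, 2·hz, the textbook I_x = m((2hy)² + (2hz)²)/12 simplifies to m(hy² + hz²)/3. A small helper (hypothetical name) to derive consistent values from the halfsizes:

```python
def solid_box_inertia(mass, halfsize):
    """Principal moments of inertia of a uniform solid box, given its half extents.

    With full side lengths 2*hx, 2*hy, 2*hz, the formula
    I_x = m * (side_y**2 + side_z**2) / 12 reduces to m * (hy**2 + hz**2) / 3.
    """
    hx, hy, hz = halfsize
    return (
        mass * (hy ** 2 + hz ** 2) / 3.0,
        mass * (hx ** 2 + hz ** 2) / 3.0,
        mass * (hx ** 2 + hy ** 2) / 3.0,
    )


box_1_inertia = solid_box_inertia(0.03, (0.5, 0.5, 0.5))
box_2_inertia = solid_box_inertia(0.03, (0.25, 0.25, 0.25))
```

For box_1 this gives 0.005 per axis rather than 1; whether Brax's `inertia` field expects exactly these values is an assumption here.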

Performance nitpick

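Little nitpick regarding def inv_quat(q), but this version:

import jax.numpy as jnp
S = jnp.array([1., -1., -1., -1.])  # conjugate of (w, x, y, z): negate the vector part
def inv_quat(q):
    return q * S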

benchmarks as 10% faster on my laptop CPU, at least, and I suspect the gain would be larger on architectures that are more aggressively tuned for vectorization. I don't have any experience with TPUs and their compilers, but I imagine this formulation would also make it easier for a GPU compiler to arrive at optimal code.
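A quick way to check that both formulations agree and to time them on your own hardware is sketched below. The element-by-element version is a hypothetical baseline for comparison, and timings will vary by backend; under `jax.jit` XLA may well compile both to similar code, so any difference could show up mostly in non-jitted calls.

```python
import timeit

import jax
import jax.numpy as jnp

S = jnp.array([1., -1., -1., -1.])


def inv_quat_stack(q):
    """Conjugate built element by element (hypothetical baseline for comparison)."""
    return jnp.array([q[0], -q[1], -q[2], -q[3]])


def inv_quat_mul(q):
    """Conjugate as one element-wise multiply by a sign vector."""
    return q * S


def bench(fn, q, number=200):
    fn(q).block_until_ready()  # warm-up (also triggers compilation for jitted fns)
    return timeit.timeit(lambda: fn(q).block_until_ready(), number=number)


q = jnp.array([0.5, 0.5, -0.5, 0.5])
for name, fn in [("stack", inv_quat_stack), ("mul", inv_quat_mul),
                 ("stack+jit", jax.jit(inv_quat_stack)), ("mul+jit", jax.jit(inv_quat_mul))]:
    print(name, bench(fn, q))
```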

Effect of batch_shape argument in Envs.create_env

Hello, I'm working with version 0.0.3. I have a question about the batch_shape argument: how exactly does it work?

Regardless of the batch_size arg, it seems like it just looks at the first dim of the PRNGKey. For example, with a single key it gives a state batch shape of 2:

import jax
import brax

from brax import envs

rng = jax.random.PRNGKey(0)

env = envs.create(env_name='halfcheetah', batch_size=100)
state = env.reset(rng=rng)

print(rng.shape)
print(state.obs.shape)

> (2,)
> (2, 23)

If i instead split a key in 50 i get the following:

import jax
import brax

from brax import envs

rng = jax.random.split(jax.random.PRNGKey(0), 50)

env = envs.create(env_name='halfcheetah', batch_size=100)
state = env.reset(rng=rng)

print(rng.shape)
print(state.obs.shape)

> (50,)
> (50, 23)

In both cases the batch_shape argument had no effect. How should I properly create multiple environments on a single device?

Replacing gym's Mujoco envs with brax envs

Had a conversation with @jkterry1 on openai/gym#2366, and it appears brax would also be a great alternative for the mujoco envs replacement.

To help with this transition, I tried out brax with PyTorch. Here is a basic report: https://wandb.ai/costa-huang/brax/reports/Brax-as-Pybullet-replacement--Vmlldzo5ODI4MDk. The source code is here: https://github.com/vwxyzjn/cleanrl/blob/mybranch/cleanrl/brax/readme.md

One of the biggest issues with brax adoption is env normalization.

I think going forward, probably the best way to fix this is to refactor the brax training side's normalization to the environment side. This in the future will also help throughput with the JaxToTorchWrapper. Otherwise, the observation will go from GPU to CPU for gym or sb3's normalization wrapper, then GPU again for torch, which just doesn't make sense.
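Moving normalization into the environment would mean keeping running observation statistics on device. A minimal sketch, assuming batches of observations shaped [n, dim], is a pytree of (count, mean, sum of squared deviations) merged with the standard parallel-variance formula (Chan et al.); `RunningStats` and the helpers are hypothetical names, not an existing Brax API.

```python
from typing import NamedTuple

import jax.numpy as jnp


class RunningStats(NamedTuple):
    count: jnp.ndarray
    mean: jnp.ndarray
    m2: jnp.ndarray  # sum of squared deviations from the mean


def init_stats(dim):
    return RunningStats(jnp.zeros(()), jnp.zeros(dim), jnp.zeros(dim))


def update_stats(stats, batch):
    """Merge a batch of observations [n, dim] into the running statistics."""
    n_b = batch.shape[0]
    mean_b = batch.mean(axis=0)
    m2_b = ((batch - mean_b) ** 2).sum(axis=0)
    n = stats.count + n_b
    delta = mean_b - stats.mean
    mean = stats.mean + delta * n_b / n
    m2 = stats.m2 + m2_b + delta ** 2 * stats.count * n_b / n
    return RunningStats(n, mean, m2)


def normalize(stats, obs, eps=1e-8):
    var = stats.m2 / jnp.maximum(stats.count, 1.0)
    return (obs - stats.mean) / jnp.sqrt(var + eps)
```

Since everything is a pure function of a pytree, the statistics can be carried inside the environment state and updated under `jax.jit`, so observations never leave the accelerator for normalization.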

One small thing is that given the brax environment directly produces the vector env, there is also no way to inject a ClipActionsWrapper(env), which may or may not have a performance impact. That said, this can be implemented in the training side with ease.

autoreset batch environments when done=True

Hi, and thanks for this amazing work.

I am wondering if there is a way to automatically reset an environment when reaching a terminal state when executing a batch of environments. It is unclear to me how to do this with BRAX when episodes can be of varying lengths.
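A common pattern for batched environments with varying episode lengths (sketched here with a toy environment, not Brax's own wrappers) is to step every environment, compute a fresh reset state for every environment, and then select per environment with `jnp.where` on the done mask. Every environment does the same amount of work, so the computation stays jittable. `reset_fn`, `step_fn`, and the other names are hypothetical.

```python
import jax
import jax.numpy as jnp


def reset_fn(rng):
    """Hypothetical single-environment reset: a scalar state near zero."""
    return jax.random.uniform(rng, (), minval=-0.1, maxval=0.1)


def step_fn(state, action):
    """Hypothetical single-environment step; the episode ends once |state| > 1."""
    state = state + action
    return state, jnp.abs(state) > 1.0


def where_done(done, fresh, old):
    """Per-environment select over arbitrary state pytrees with a leading batch axis."""
    def select(a, b):
        mask = done.reshape(done.shape + (1,) * (a.ndim - 1))
        return jnp.where(mask, a, b)
    return jax.tree_util.tree_map(select, fresh, old)


@jax.jit
def auto_reset_step(states, rngs, actions):
    states, dones = jax.vmap(step_fn)(states, actions)
    keys = jax.vmap(jax.random.split)(rngs)        # one key pair per environment
    rngs, reset_keys = keys[:, 0], keys[:, 1]
    fresh = jax.vmap(reset_fn)(reset_keys)
    states = where_done(dones, fresh, states)       # only finished envs restart
    return states, rngs, dones
```

The returned `dones` still marks which environments just finished, but their terminal states are replaced; if the learner needs the final observation, it has to be saved before the select.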

`learn` not working on Google Cloud TPU

On a Google Cloud TPU VM, I installed JAX with
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html, then cloned this repo and installed it with pip install -e .

Running learn throws the following error:


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mandi_zhao/miniconda3/envs/brax/bin/learn", line 7, in <module>
    exec(compile(f.read(), __file__, 'exec'))
  File "/home/mandi_zhao/brax/bin/learn", line 7, in <module>
    app.run(learner.main)
  File "/home/mandi_zhao/miniconda3/envs/brax/lib/python3.7/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/home/mandi_zhao/miniconda3/envs/brax/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/mandi_zhao/brax/brax/training/learner.py", line 176, in main
    progress_fn=writer.write_scalars)
  File "/home/mandi_zhao/brax/brax/training/ppo.py", line 465, in train
    (training_state, state), losses = minimize_loop(training_state, state)
  File "/home/mandi_zhao/brax/brax/training/ppo.py", line 402, in _minimize_loop
    length=num_epochs // log_frequency)
  File "/home/mandi_zhao/brax/brax/training/ppo.py", line 387, in run_epoch
    length=num_update_epochs)
  File "/home/mandi_zhao/brax/brax/training/ppo.py", line 361, in minimize_epoch
    length=num_minibatches)
  File "/home/mandi_zhao/brax/brax/training/ppo.py", line 338, in update_model
    loss_grad, metrics = grad_loss(params, data, key_loss)
  File "/home/mandi_zhao/brax/brax/training/ppo.py", line 150, in compute_ppo_loss
    policy_logits, data.actions)
  File "/home/mandi_zhao/brax/brax/training/distribution.py", line 77, in log_prob
    log_probs = dist.log_prob(actions)
  File "/home/mandi_zhao/miniconda3/envs/brax/lib/python3.7/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1315, in log_prob
    return self._call_log_prob(value, name, **kwargs)
  File "/home/mandi_zhao/miniconda3/envs/brax/lib/python3.7/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1297, in _call_log_prob
    return self._log_prob(value, **kwargs)
  File "/home/mandi_zhao/miniconda3/envs/brax/lib/python3.7/site-packages/tensorflow_probability/substrates/jax/distributions/normal.py", line 190, in _log_prob
    x / scale, self.loc / scale)
  File "/home/mandi_zhao/miniconda3/envs/brax/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 6553, in deferring_binary_op
    return binary_op(self, other)
  File "/home/mandi_zhao/miniconda3/envs/brax/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 839, in true_divide
    return lax.div(x1, x2)
TypeError: div got incompatible shapes for broadcasting: (30, 87), (30, 8).

Friction per body?

Hi,

I'm trying to create multiple bodies in the environment with differing friction coefficients. However, I couldn't find anywhere how to do this. It seems that most examples have a global friction coefficient of 0.6 but no difference across bodies.

I'd appreciate any thoughts on this. Thanks.

Support for height maps and factor of 2.0 for grip

Hi,

After close inspection of the collider code, I found this interesting piece of code in brax/physics/colliders.py:

  # factor of 2.0 here empirically helps object grip
  # TODO: expose friction physics parameters in config
  return dp_n * colliding_n + dp_d * colliding_d * 2.0

This factor of two is not physical at all, and in my tests it only avoids breaking the physics when gravity is collinear with the collision normal.

Can someone give some pointers as to why it was put there, how much more "grip" we get with such a hack, and whether removing it would hurt the simulator's performance?

Thanks.

PS: I am implementing support for height maps and am almost ready to open a merge request: the same functionality as planes, with the added benefit of terrain altitude. Because such terrain has non-vertical normals, this factor of 2.0 conflicts with the height map support.
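For the height-map case, the two quantities a plane-like collider needs are the terrain height under a point and the local surface normal. A minimal sketch, assuming a regular grid with heights sampled at (i * cell_size, j * cell_size), uses bilinear interpolation and takes the normal from the autodiff gradient of the interpolant. It is independent of the pending merge request, and all names are hypothetical.

```python
import jax
import jax.numpy as jnp


def bilinear_height(heights, cell_size, x, y):
    """Bilinearly interpolate a grid where heights[i, j] is the height at (i * cell_size, j * cell_size)."""
    n_x, n_y = heights.shape
    # Clamp to the grid so that cell (i, j) and its +1 neighbours are always valid.
    gx = jnp.clip(x / cell_size, 0.0, n_x - 1 - 1e-4)
    gy = jnp.clip(y / cell_size, 0.0, n_y - 1 - 1e-4)
    i = jnp.floor(gx).astype(jnp.int32)
    j = jnp.floor(gy).astype(jnp.int32)
    tx, ty = gx - i, gy - j
    return (heights[i, j] * (1 - tx) * (1 - ty)
            + heights[i + 1, j] * tx * (1 - ty)
            + heights[i, j + 1] * (1 - tx) * ty
            + heights[i + 1, j + 1] * tx * ty)


def heightfield_contact(heights, cell_size, point):
    """Return (vertical penetration, unit surface normal) for a point against the terrain."""
    x, y, z = point
    h = bilinear_height(heights, cell_size, x, y)
    # Surface slope via autodiff of the interpolant.
    dh_dx, dh_dy = jax.grad(bilinear_height, argnums=(2, 3))(heights, cell_size, x, y)
    normal = jnp.array([-dh_dx, -dh_dy, 1.0])
    normal = normal / jnp.linalg.norm(normal)
    # Positive when the point is below the surface. This is measured vertically;
    # the depth along the normal is approximately penetration * normal[2].
    penetration = h - z
    return penetration, normal
```

On sloped terrain the normal has horizontal components, so any friction or "grip" term that implicitly assumes a vertical normal would push bodies sideways, consistent with the conflict described above.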

Multi-Agent Environments

Hello,

Are you planning to create any multi-agent environments, such as crowd simulation?

Is it also possible to have non-uniform terrain, walls, etc. in each environment, so that each agent can be initialized at a random location to vary its experience?

(Without that, I don't see a major advantage of the engine's parallel simulation capability.)

Sincerely,
Kamer

MBRL

Hi, any plans for MBRL example (like PETS)? Regards

wire, rope, force-sensor implementations?

Hi,
Are there plans to integrate wire or rope as a body? Or does it need to be built from many rigid bodies and joints?
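Without a dedicated rope primitive, a rope is typically approximated as a chain of short rigid segments connected by joints. The helper below only generates protobuf text in the style of the configs on this page; `rope_config` is a hypothetical name, and the field names (`capsule`, `radius`, `length`, `parent_offset`, `child_offset`, `stiffness`, `angle_limit`) as well as the assumptions that capsules lie along the body's z axis and that three `angle_limit` entries give three rotational degrees of freedom should all be checked against brax's config schema before use.

```python
def rope_config(num_segments, segment_length=0.2, radius=0.02, mass=0.05, stiffness=5000.0):
    """Protobuf text for a hanging chain of capsule segments joined end to end.

    Assumes capsules lie along each body's local z axis and that a joint with
    three angle_limit entries can rotate (within limits) about all three axes.
    """
    half = segment_length / 2
    # Solid-cylinder approximation of each segment's principal moments of inertia.
    i_xy = mass * (3 * radius**2 + segment_length**2) / 12
    i_z = mass * radius**2 / 2
    parts = []
    for k in range(num_segments):
        parts.append(
            f'bodies {{\n'
            f'  name: "Seg{k}" mass: {mass}\n'
            f'  colliders {{ capsule {{ radius: {radius} length: {segment_length} }} }}\n'
            f'  inertia {{ x: {i_xy:.6g} y: {i_xy:.6g} z: {i_z:.6g} }}\n'
            f'}}')
    limits = "  angle_limit { min: -30 max: 30 }\n" * 3
    for k in range(num_segments - 1):
        # The bottom end of segment k is pinned to the top end of segment k + 1.
        parts.append(
            f'joints {{\n'
            f'  name: "Seg{k}_Seg{k + 1}" parent: "Seg{k}" child: "Seg{k + 1}"\n'
            f'  parent_offset {{ z: {-half} }} child_offset {{ z: {half} }}\n'
            f'  stiffness: {stiffness}\n'
            f'{limits}'
            f'}}')
    return "\n".join(parts)

print(rope_config(3))
```

More segments give a smoother rope at the cost of more bodies and joints to simulate.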

Is there a way to measure the force exerted on a rigid body by another rigid body (a force sensor)? For example, could the humanoid demo have force sensors at its feet to measure how much force it applies to the ground while running, or at its hands as another way to detect falling?
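Without a built-in sensor, one way to estimate the contact force on a body is from its momentum change over a step: if only gravity and contacts act on it, Newton's second law gives F_contact ≈ m (v_next - v_prev) / dt - m g. This plain-Python sketch uses a hypothetical `estimate_contact_force` helper and two velocity samples you would read from consecutive states; for a foot attached to a leg it lumps joint reaction forces in with the ground force.

```python
def estimate_contact_force(mass, vel_prev, vel_next, dt, gravity=(0.0, 0.0, -9.8)):
    """Estimate the net non-gravitational force on a body from two velocity samples.

    Assumes gravity and contact are the only forces acting during the step.
    """
    return tuple(mass * (vn - vp) / dt - mass * g
                 for vp, vn, g in zip(vel_prev, vel_next, gravity))

# A 2 kg body resting on the ground: its velocity stays zero, so the ground pushes up by m * 9.8.
print(estimate_contact_force(2.0, (0.0, 0.0, 0.0), (0.0, 0.0, 0.0), 0.01))  # (0.0, 0.0, 19.6)
```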

thanks

Walls added to environment disappear during training

Hi,
I built a custom env (https://github.com/flobotics/brax/blob/main/brax/envs/simtoreal.py), which is mostly the reacher with walls around it. If I use html.render(), it shows everything, but when I then train and look at the trajectory output, the walls start to sink into the ground and disappear. I want the target to bounce off the walls, but that can't happen if they are not there. What am I doing wrong?

Here are images from the start of the trajectory video output:

[image]

Then the walls start to disappear:

[image]

until they are completely gone:

[image]
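One thing worth checking, assuming the walls are meant to be static scenery: a body without `frozen` set is simulated like any other body, so it can be moved by gravity and contacts once the simulation runs. The Ground bodies in the configs on this page are declared with `frozen { all: true }`; a wall declared the same way might look like this (the name, halfsize, mass, and inertia values are placeholders):

```textproto
bodies {
  name: "WallNorth"
  colliders { box { halfsize { x: 2.0 y: 0.05 z: 0.25 } } }
  inertia { x: 1.0 y: 1.0 z: 1.0 }
  mass: 1.0
  frozen { all: true }
}
```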

little documentation about joints

Hi,
I tried a lot and I don't get it; perhaps it's a bug. I don't know where the joint is placed or in which direction it can turn. For testing I use two boxes:

[image]

The test code is here: https://github.com/flobotics/brax/blob/main/brax/envs/test_two_box_x_2.py

With a joint that has no position{} values, I assume the joint is in the middle of both boxes. I can only rotate the boxes with rotation { y: -90 }; no other value, like z: 90, y: 0, or x: 90, works. Shouldn't every value work?

What does "rotation" mean?

I know the example envs like halfcheetah use z: 90, but if I try to use that in my own config, it does not work.

It also seems that the values of child_offset and parent_offset simply get added to the position values of the child only, so the joint origin is always in the middle of the parent. Is that correct?

A little documentation about how to use joints would be nice.
Thanks
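For readers with the same question, here is the usual convention in rigid-body engines as a plain-Python sketch: each offset is expressed in its own body's local frame, and a joint is satisfied when the world-space anchors computed from the parent and from the child coincide. Whether brax applies `parent_offset`/`child_offset` exactly this way, and how its `rotation` field orients the joint axes, is an assumption here; `quat_rotate` and `joint_anchor` are hypothetical helpers. Quaternions use the (w, x, y, z) order seen in the `rot` arrays elsewhere on this page.

```python
import math

def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])

def quat_rotate(q, v):
    """Rotate vector v by unit quaternion q = (w, x, y, z)."""
    w, u = q[0], q[1:]
    t = tuple(2.0 * c for c in cross(u, v))
    ut = cross(u, t)
    return tuple(v[k] + w * t[k] + ut[k] for k in range(3))

def joint_anchor(body_pos, body_rot, offset):
    """World position of a joint anchor, given a body's pose and a local-frame offset."""
    r = quat_rotate(body_rot, offset)
    return tuple(p + d for p, d in zip(body_pos, r))

# Parent at the origin, rotated 90 degrees about z; its anchor sits 0.5 m along its local x.
half_angle = math.radians(90) / 2
parent_rot = (math.cos(half_angle), 0.0, 0.0, math.sin(half_angle))
print(joint_anchor((0.0, 0.0, 0.0), parent_rot, (0.5, 0.0, 0.0)))  # approximately (0, 0.5, 0)
```

Under this convention, rotating the parent body moves where its offset lands in the world, so the same `parent_offset` can point in different world directions depending on the body's orientation.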

spring behaviour of joints

Hi,
I built a crawler (https://github.com/flobotics/brax/blob/custom_envs/brax/envs/crawler5.py), but when I train it and then visualize the trajectory, the joints move "forward-backward" like a spring.

Here is a gif (https://github.com/flobotics/brax/blob/custom_envs/crawler5-1.gif) showing how it moves.

I tried different settings in the joint description (spring_damping, etc.), but nothing changed this behaviour.

What can I do? It also seems that the joints do not rotate through the full -60/+60 degrees they should allow.

Thanks
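As general mass-spring background rather than a description of brax's joint solver: a joint modelled as a spring with stiffness k and damping c acting on mass m keeps oscillating unless c approaches the critical value 2·sqrt(k·m). The plain-Python sketch below (hypothetical `simulate_spring`, semi-implicit Euler) compares a lightly damped spring with a critically damped one; if joints behave like the first case, raising damping relative to stiffness is one thing to experiment with.

```python
import math

def critical_damping(k, m):
    """Damping coefficient at which a mass-spring system stops oscillating."""
    return 2.0 * math.sqrt(k * m)

def simulate_spring(k, c, m, x0=1.0, dt=1e-3, steps=1000):
    """Integrate m x'' = -k x - c x' with semi-implicit Euler; return positions."""
    x, v = x0, 0.0
    xs = [x]
    for _ in range(steps):
        v += dt * (-k * x - c * v) / m  # update velocity from spring and damper forces
        x += dt * v                     # then position from the new velocity
        xs.append(x)
    return xs

k, m = 100.0, 1.0
light = simulate_spring(k, 0.5, m)
critical = simulate_spring(k, critical_damping(k, m), m)
print(min(light) < 0)     # True: overshoots past the rest position
print(min(critical) > 0)  # True: approaches rest from one side
```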

Rewards Not Zero after Done

Hello,

I have extended the PyTorch example with an Augmented Random Search implementation:
https://github.com/kayuksel/braxars/blob/main/braxars_multi.py

However, I have noticed that the reward of a batch member is not zero after it is done.
What values are returned for "done" members? Should I treat their reward as zero?
For now I reset the whole environment when that happens, since I couldn't find how to reset only the done members.
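If finished episodes should contribute nothing, one common approach is to mask each step's reward by whether the episode had already ended before that step. A NumPy sketch over a (time, batch) rollout; `mask_rewards_after_done` is a hypothetical name, and it assumes `dones[t, b] == 1` marks the step on which episode b ended, so that step's reward still counts.

```python
import numpy as np

def mask_rewards_after_done(rewards, dones):
    """Zero out rewards that arrive after an episode has finished.

    rewards, dones: arrays of shape (T, B). dones[t, b] == 1 marks the step on
    which episode b terminated; the reward at that step is kept.
    """
    dones = np.asarray(dones, dtype=bool)
    # finished_before[t, b] is True if episode b terminated at some step < t.
    finished_before = np.zeros_like(dones)
    finished_before[1:] = np.logical_or.accumulate(dones, axis=0)[:-1]
    return np.where(finished_before, 0.0, rewards)

rewards = np.ones((4, 2))
dones = np.array([[0, 0], [1, 0], [0, 0], [0, 1]])
masked = mask_rewards_after_done(rewards, dones)
print(masked.sum(axis=0))  # per-env returns: 2.0 and 4.0
```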

Another question is about rendering. Is rendering only possible inside a notebook?
Is it possible to render a selected (e.g. best-performing) batch member, or all members in the same render?

Changing `const system` declaration in `html.py` to `var system`?

Hello from the Princeton office! We are excited about brax and have started playing around with it.

One snag in our workflow: the generated JavaScript for html.py:render declares system as a const, so re-evaluating a Jupyter notebook cell that calls html.render fails. We've forked and changed the const to var, but hesitated to submit a PR since there may have been an important reason we missed for choosing const.

Anyway, just wanted to check!
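Until the template changes, a notebook-side workaround is to rewrite the declaration in the generated string before displaying it. This sketch assumes, as the issue describes, that the HTML returned by `html.render` contains the literal text `const system`; `make_rerunnable` is a hypothetical helper, and the single, exact string replacement is deliberately narrow so nothing else in the page is touched.

```python
def make_rerunnable(rendered_html: str) -> str:
    """Swap the first `const system` declaration for `var system`.

    `var` may be redeclared when a cell runs again in the same page, whereas
    redeclaring a top-level `const` with the same name raises a SyntaxError.
    """
    return rendered_html.replace("const system", "var system", 1)

# Usage in a notebook (not executed here):
#   from IPython.display import HTML
#   HTML(make_rerunnable(html.render(sys, qps)))
```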

Analytical policy gradient demo

Hi, Brax developers, thanks for developing Brax. I saw that an analytical policy gradient training function (apg.py) has been added and updated in the past few days. I have tried to use it to train on some of the given envs (e.g. ant), but it doesn't give the expected results. Could you provide a demo setting for using apg in some of the envs?

support for new environments

Hi,

Thanks for the awesome work, I really think this is a game-changer!
Do you have any potential robotics manipulation environments on your roadmap, e.g. Meta-World?
I think it would be a great service to the community (including me!) to add one of those.
Maybe it's a daunting task, I can't tell.

Thanks anyways,
Massimo Caccia

HTML renderer view problems with millimeter-sized bodies

Hi,
When using HTML(html.render(sys, qps)) with bodies of size 0.0045 or so, the renderer view cannot zoom in on these small bodies. It starts showing the space under the plane before you can really see the bodies.

The bodies are there, but the view is "too big" ?

ValueError: Incompatible shapes for broadcasting

The error occurs while running
inference_fn, params, _ = train_fn(environment_fn=env_fn, progress_fn=progress)

For testing:
The Colab notebook is here: https://github.com/flobotics/brax/blob/main/notebooks/crawler-colab.ipynb

And the env is here https://github.com/flobotics/brax/blob/main/brax/envs/crawler.py

---------------------------------------------------------------------------

UnfilteredStackTrace                      Traceback (most recent call last)

<ipython-input-6-0369bdad0cb4> in <module>()
     82 
---> 83 inference_fn, params, _ = train_fn(environment_fn=env_fn, progress_fn=progress)
     84 

101 frames

UnfilteredStackTrace: ValueError: Incompatible shapes for broadcasting: ((1, 50, 32), (50, 32, 4))

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------


The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _promote_shapes(fun_name, *args)
    245       if config.jax_numpy_rank_promotion != "allow":
    246         _rank_promotion_warning_or_error(fun_name, shapes)
--> 247       result_rank = len(lax.broadcast_shapes(*shapes))
    248       return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
    249               for arg, shp in zip(args, shapes)]

ValueError: Incompatible shapes for broadcasting: ((1, 50, 32), (50, 32, 4))
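Both shape errors on this page (`(30, 87)` vs `(30, 8)` and `(1, 50, 32)` vs `(50, 32, 4)`) come from the same NumPy/JAX broadcasting rule: shapes are aligned from the trailing dimension, and each aligned pair must be equal or contain a 1. The plain-Python sketch below (`broadcast_shape` is a hypothetical helper following the same rule as `numpy.broadcast_shapes`) makes the rule explicit. Here the trailing dimensions 32 and 4 already disagree, which may indicate that an observation or action size in the custom env does not match what the surrounding code expects.

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of a and b, or None if they are incompatible."""
    result = []
    # Walk both shapes from the last dimension, padding the shorter one with 1s.
    for k in range(1, max(len(a), len(b)) + 1):
        da = a[-k] if k <= len(a) else 1
        db = b[-k] if k <= len(b) else 1
        if da != db and 1 not in (da, db):
            return None  # e.g. 32 vs 4: neither is 1, so they cannot broadcast
        result.append(db if da == 1 else da)
    return tuple(reversed(result))

print(broadcast_shape((1, 50, 32), (50, 32, 4)))  # None
print(broadcast_shape((30, 87), (30, 8)))         # None
print(broadcast_shape((1, 50, 32), (50, 32)))     # (1, 50, 32)
```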

how to get boxes to collide / rest on top of each other

Hi,

I'm having some trouble understanding how to correctly specify colliders. In the test below, I'm placing 1m boxes at z positions 1 and 3, and expect them to fall to z positions .5 and 1.5 (with Box2 resting on top of Box1), but Box2 passes through Box1. Is there something more I should specify to make the boxes collide?

import brax
from google.protobuf import text_format
from jax import numpy as jnp

_CONFIG = """
  dt: 1.5 substeps: 1000 friction: 0.6 baumgarte_erp: .1
  gravity { z: -9.8 }    
  bodies { name: "Ground" frozen: { all: true } colliders { plane {}}}
  bodies {    
    name: "Box1" mass: 1
    colliders { box { halfsize { x: 0.5 y: 0.5 z: 0.5 }}}
    inertia { x: 1 y: 1 z: 1 }
  }
  bodies {
    name: "Box2" mass: 1
    colliders { box { halfsize { x: 0.5 y: 0.5 z: 0.5 }}}
    inertia { x: 1 y: 1 z: 1 }
  }
  
"""

def test_boxes_fall_on_one_another():
  """Box1 falls onto the ground and stops, Box2 falls and rests on top of Box1."""
  sys = brax.System(text_format.Parse(_CONFIG, brax.Config()))
  qp = brax.QP(
      pos=jnp.array([[0, 0, 0],[0., 0., 1.],[0., 0., 3.]]),
      rot=jnp.array([[1., 0., 0., 0.], [1., 0., 0., 0.], [1., 0., 0., 0.]]),
      vel=jnp.array([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]),
      ang=jnp.array([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]))
  qp, _ = sys.step(qp, jnp.array([]))  
  return sys,qp

sys,qp = test_boxes_fall_on_one_another()
print(qp.pos)

# observed
#[[0.         0.         0.        ]
# [0.         0.         0.49988964]
# [0.         0.         0.49988964]]

# expected
#[[0.         0.         0.        ]
# [0.         0.         0.5]
# [0.         0.         1.5]]
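The expected values follow from simple stacking arithmetic: a resting box's center sits one halfsize above whatever supports it, so with 0.5 m halfsizes the centers should end up at 0.5 and 1.5. The observed output puts both centers at about 0.5, meaning the two boxes occupy almost the same space. A plain-Python check of that arithmetic (the helpers are hypothetical and only do axis-aligned box math, not brax's collision code):

```python
def stacked_centers(halfsizes_z, ground_z=0.0):
    """Resting z-centers of axis-aligned boxes stacked bottom-up on a ground plane."""
    centers, top = [], ground_z
    for h in halfsizes_z:
        centers.append(top + h)  # each box rests on the current top surface
        top += 2 * h             # and raises the top by its full height
    return centers

def z_overlap(center_a, center_b, half_a, half_b):
    """Interpenetration depth along z of two axis-aligned boxes (0 if separated)."""
    return max(0.0, half_a + half_b - abs(center_a - center_b))

print(stacked_centers([0.5, 0.5]))                  # [0.5, 1.5]
print(z_overlap(0.49988964, 0.49988964, 0.5, 0.5))  # 1.0: the boxes fully overlap
```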
