
objax's Introduction

Objax

Tutorials | Install | Documentation | Philosophy

This is not an officially supported Google product.

Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name comes from the contraction of Object and JAX -- a popular high-performance framework. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.

This is the developer repository of Objax. There is very little user documentation here; for the full documentation, go to objax.readthedocs.io.

You can find READMEs in the subdirectories of this project, for example:

User installation guide

You install Objax using pip as follows:

pip install --upgrade objax

Objax supports GPUs but assumes that you already have some version of CUDA installed. Here are the extra steps required to install CUDA-enabled jaxlib (jaxlib releases require CUDA 11.2 or newer):

RELEASE_URL="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
JAX_VERSION=`python3 -c 'import jax; print(jax.__version__)'`
pip uninstall -y jaxlib
pip install -f $RELEASE_URL jax[cuda]==$JAX_VERSION

For more installation options, see https://github.com/google/jax#pip-installation-gpu-cuda

Useful environment configurations

Here are a few useful options:

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false

Testing your installation

You can test your installation by running the code below:

import jax
import objax

print(f'Number of GPUs {jax.device_count()}')

x = objax.random.normal(shape=(100, 4))
m = objax.nn.Linear(nin=4, nout=5)
print('Matrix product shape', m(x).shape)  # (100, 5)

x = objax.random.normal(shape=(100, 3, 32, 32))
m = objax.nn.Conv2D(nin=3, nout=4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)

If you get errors running this with CUDA, it typically means your CUDA or cuDNN installation has issues.

Running code examples

Clone the code repository:

git clone https://github.com/google/objax.git
cd objax/examples

Citing Objax

To cite this repository:

@software{objax2020github,
  author = {{Objax Developers}},
  title = {{Objax}},
  url = {https://github.com/google/objax},
  version = {1.2.0},
  year = {2020},
}

Developer documentation

Here is information about development setup and a guide on adding new code.

objax's People

Contributors

aakashkumarnain, alexeykurakin, anukaal, aterzis-google, carlini, cyugao, david-berthelot, iamharsha1999, jakevdp, joaogui1, kashif, kihyuks, lberrada, matpalm, naruto-raj, naveen-takvaviya, npapernot, ntt123, parmarsuraj99, peterjliu, sathish-a, schien1729, seungjaeryanlee, shs037, yechengxi


objax's Issues

Clarification on loss function names

Hi,

I'm assuming cross_entropy_logits() is meant to be used on the raw logits of the model and sigmoid_cross_entropy_logits() is the version that's meant to be used on logits that have been normalized to [0, 1] with a sigmoid or softmax layer?

Or is it that sigmoid_cross_entropy_logits() is meant for binary classification on the raw logits of the model, and cross_entropy_logits() is meant for multi-class classification?
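For context, a minimal NumPy sketch of the two standard definitions, assuming Objax follows the usual conventions (this is not Objax's source). Under those definitions both functions take raw, unnormalized logits: the softmax version treats the last axis as mutually exclusive classes, while the sigmoid version treats every logit as an independent binary decision.

```python
import numpy as np


def logsumexp(x, axis=-1):
    """Numerically stable log(sum(exp(x))) along an axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def softmax_cross_entropy(logits, labels):
    """Multi-class loss on raw logits; labels are one-hot (or soft) over the last axis."""
    return logsumexp(logits, axis=-1) - np.sum(labels * logits, axis=-1)


def sigmoid_cross_entropy(logits, labels):
    """Binary (or multi-label) loss on raw logits; each label is 0 or 1 independently.

    Stable form of -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))].
    """
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
```

Passing probabilities that were already squashed by a sigmoid or softmax into either function would apply the nonlinearity twice under these definitions.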

Prototype mutable tensors + question: are they desirable?

Currently we cannot write to part of a tensor using NumPy syntax:

import jax.numpy as jn

v = jn.arange(10)
v[2:4] += 1  
# TypeError: '<class 'jax.interpreters.xla.DeviceArray'>' object does not support item assignment.
# JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

However, in NumPy it works:

import numpy as np

v = np.arange(10)
v[2:4] += 1  
print(v)
# [0 1 3 4 4 5 6 7 8 9]

I'm interested in prototyping feasibility but I also wonder if it's a desirable feature generally speaking. Feedback welcome.
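For comparison, a minimal NumPy sketch of the out-of-place (functional) update that the JAX error message points toward. index_add is a hypothetical helper, not a JAX or Objax API; the jax.ops.index_* functions named in the error have since been deprecated in JAX in favor of the v.at[2:4].add(1) form, which likewise returns a new array instead of mutating v.

```python
import numpy as np


def index_add(x, idx, y):
    """Functional version of `x[idx] += y`: returns a new array and leaves x untouched."""
    out = np.array(x, copy=True)
    out[idx] += y
    return out


v = np.arange(10)
w = index_add(v, slice(2, 4), 1)
print(w)  # [0 1 3 4 4 5 6 7 8 9]
print(v)  # unchanged: [0 1 2 3 4 5 6 7 8 9]
```

A mutable-tensor prototype would essentially hide this copy-and-rebind step behind item assignment.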

Benchmarks

Hi!

Great work all! It's really surprising that the core of the library is very readable yet so extensible.

I'm curious if there are any performance benchmarks for comparison with more popular frameworks?

Random number tests fail unpredictably

There are currently tests in tests/testrandom.py that fail with small probability (and some with not-so-small probability)

For example, here

value = np.array(objax.random.normal((1000, 100), mean=0, stddev=2))
self.assertAlmostEqual(value.mean(), 0, delta=0.01)

the probability that this fails is ~11%: the mean of 1000*100 samples with mean 0 and stddev 2 should be normally distributed with mean 0 and standard deviation 2/sqrt(100000) ≈ 6.324e-3, and its absolute value exceeds 0.01 with probability ~11%.

The other tests also fail with non-negligible probability.
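A minimal sketch (standard library only) of the arithmetic behind the ~11% figure, plus one way to choose a tolerance. With n i.i.d. N(0, stddev^2) samples, the sample mean has standard error stddev/sqrt(n), so P(|mean| > delta) = erfc(delta / (se * sqrt(2))). The helper names and the 6-standard-error default are illustrative choices, not Objax's test code.

```python
import math


def mean_test_failure_prob(n, stddev, delta):
    """P(|sample mean| > delta) for n i.i.d. N(0, stddev**2) samples."""
    standard_error = stddev / math.sqrt(n)
    return math.erfc(delta / (standard_error * math.sqrt(2)))


def mean_test_tolerance(n, stddev, num_standard_errors=6.0):
    """A delta that a correct generator exceeds only with negligible probability."""
    return num_standard_errors * stddev / math.sqrt(n)


n, stddev = 1000 * 100, 2.0
print(mean_test_failure_prob(n, stddev, 0.01))   # about 0.11 with the current delta
delta = mean_test_tolerance(n, stddev)           # about 0.038
print(mean_test_failure_prob(n, stddev, delta))  # about 2e-9
```

Widening delta keeps the test meaningful (a biased generator still fails) while making spurious failures vanishingly rare; fixing the seed is the other common option.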

Error in cross_entropy_logits_sparse function

return logsumexp(logits, axis=1) - logits[jn.arange(logits.shape[0]), labels]

This line causes error with the following example:

logits = objax.random.normal([3, 5, 32])
labels = objax.random.randint([3, 5], low=0, high=32)

objax.functional.loss.cross_entropy_logits_sparse(logits, labels)

Error trace

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-69-e9d3e6a2a977> in <module>()
      2 labels = objax.random.randint([3, 5], 0, 31)
      3 
----> 4 objax.functional.loss.cross_entropy_logits_sparse(logits, labels)

5 frames
/usr/local/lib/python3.6/dist-packages/objax/functional/loss.py in cross_entropy_logits_sparse(logits, labels)
     46         (batch,) tensor of the cross-entropies for each entry.
     47     """
---> 48     return logsumexp(logits, axis=-1) - logits[jn.arange(logits.shape[0]), labels]
     49 
     50 

/usr/local/lib/python3.6/dist-packages/jax/numpy/lax_numpy.py in _rewriting_take(arr, idx)
   3584   arr = asarray(arr)
   3585   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
-> 3586   return _gather(arr, treedef, static_idx, dynamic_idx)
   3587 
   3588 # TODO(phawkins): re-enable jit after fixing excessive recompilation for

/usr/local/lib/python3.6/dist-packages/jax/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx)
   3591 def _gather(arr, treedef, static_idx, dynamic_idx):
   3592   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 3593   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   3594   y = arr
   3595 

/usr/local/lib/python3.6/dist-packages/jax/numpy/lax_numpy.py in _index_to_gather(x_shape, idx)
   3737         (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
   3738          not advanced_axes_are_contiguous and idx_pos == 0)):
-> 3739       advanced_indexes = broadcast_arrays(*advanced_indexes)
   3740       shape = advanced_indexes[0].shape
   3741       ndim = len(shape)

/usr/local/lib/python3.6/dist-packages/jax/numpy/lax_numpy.py in broadcast_arrays(*args)
   1416     return [arg if isinstance(arg, ndarray) or isscalar(arg) else array(arg)
   1417             for arg in args]
-> 1418   result_shape = lax.broadcast_shapes(*shapes)
   1419   return [broadcast_to(arg, result_shape) for arg in args]
   1420 

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in broadcast_shapes(*shapes)
     78   if result_shape is None:
     79     raise ValueError("Incompatible shapes for broadcasting: {}"
---> 80                      .format(tuple(map(tuple, shapes))))
     81   return result_shape
     82 

ValueError: Incompatible shapes for broadcasting: ((1, 3), (3, 5))
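A minimal NumPy sketch of a version that works for any number of leading batch dimensions, assuming logits has shape labels.shape + (nclass,). It selects each label's logit with take_along_axis instead of indexing with arange(logits.shape[0]), which only lines up with labels when the input is 2D. jax.numpy also provides take_along_axis, but this sketch is not necessarily the fix Objax adopted.

```python
import numpy as np


def logsumexp(x, axis=-1):
    """Numerically stable log(sum(exp(x))) along an axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def cross_entropy_logits_sparse(logits, labels):
    """Cross-entropy for integer labels of any batch shape.

    logits: (..., nclass) raw scores; labels: (...) integer class ids.
    Returns an array with the same shape as labels.
    """
    # Gather logits[..., labels] along the class axis for every batch position.
    picked = np.take_along_axis(logits, labels[..., None], axis=-1)[..., 0]
    return logsumexp(logits, axis=-1) - picked
```

With the shapes from the report, logits of shape (3, 5, 32) and labels of shape (3, 5) give a (3, 5) array of per-position losses.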

Write guide on how to contribute for new users

Points to cover with an example (could be an example pull request):

  • Write code (style recommendations)
  • Write unit test (in tests/)

When adding/removing APIs:

  • Update the docs/source files so that the changes show up in the documentation.
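A minimal sketch of what a unit test might look like, using Python's standard unittest with NumPy. clip_by_value is a hypothetical function standing in for the code under test; the exact file naming and layout under tests/ are not shown here.

```python
import unittest

import numpy as np


def clip_by_value(x, low, high):
    """Hypothetical function under test: clamp every element of x to [low, high]."""
    return np.minimum(np.maximum(x, low), high)


class TestClipByValue(unittest.TestCase):
    def test_clips_to_range(self):
        x = np.array([-2.0, 0.5, 3.0])
        np.testing.assert_allclose(clip_by_value(x, -1.0, 1.0), [-1.0, 0.5, 1.0])

    def test_keeps_shape(self):
        x = np.zeros((4, 3))
        self.assertEqual(clip_by_value(x, -1.0, 1.0).shape, (4, 3))
```

Placed in a file under tests/, a class like this is picked up by a runner such as python -m unittest.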

Adding super_convergence example

Hi team, first, I'm excited about this project. I was recently playing with Objax and implemented a learning rate finder and a cyclic learning rate in my pet project. Should these be added as part of the library, or would an example be sufficient? Let me know if I can contribute so that I can work on it.
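A minimal sketch of both pieces, standard library only: an exponentially increasing learning rate for a learning-rate range test ("LR finder"), and Leslie Smith's triangular cyclical schedule. The function names and default ranges are assumptions; the returned value could be passed each step as the lr argument of an Objax optimizer, as in opt(lr=lr, grads=g).

```python
import math


def lr_range_test_schedule(step, num_steps, min_lr=1e-7, max_lr=10.0):
    """LR finder: grow the learning rate exponentially from min_lr to max_lr over num_steps."""
    return min_lr * (max_lr / min_lr) ** (step / (num_steps - 1))


def triangular_lr(step, base_lr, max_lr, step_size):
    """Triangular cyclical LR: base_lr -> max_lr -> base_lr every 2 * step_size steps."""
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)  # 1 at cycle edges, 0 at the peak
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)
```

In a range test one records the loss at each step and picks a learning rate somewhat below the point where the loss starts to diverge; that value is a natural max_lr for the cyclical schedule.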

Padding for pooling layers

Conv2D now allows explicit padding, but pooling layers do not. We need consistent padding types for both convolution and pooling layers.
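A minimal NumPy sketch of max pooling with explicit ((top, bottom), (left, right)) padding on NCHW input, assuming the input is floating point. Padding with -inf means padded cells never win the max. The function name and signature are hypothetical; in JAX, lax.reduce_window also accepts explicit (low, high) padding pairs.

```python
import numpy as np


def max_pool2d(x, size, stride, padding):
    """Max pooling on a float NCHW array with explicit ((top, bottom), (left, right)) padding."""
    (top, bottom), (left, right) = padding
    # Pad only the spatial dims; -inf never wins the max, so the padding cannot leak into outputs.
    x = np.pad(x, ((0, 0), (0, 0), (top, bottom), (left, right)), constant_values=-np.inf)
    n, c, h, w = x.shape
    out_h = (h - size) // stride + 1
    out_w = (w - size) // stride + 1
    out = np.empty((n, c, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, :, i * stride:i * stride + size, j * stride:j * stride + size]
            out[:, :, i, j] = window.max(axis=(2, 3))
    return out
```

Asymmetric pairs such as ((0, 1), (0, 1)) reproduce what SAME padding does on odd-sized inputs, while symmetric pairs match the PyTorch-style convention.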

Explicit padding mode

It looks like Objax currently limits padding to either VALID or SAME. This rules out explicit padding and prevents compatibility with models from PyTorch and Gluon, which only support explicit (symmetric) padding, unless extra Pad layers are added to the model.

It'd be nice to at least support TF-style explicit padding (specifying both sides of every dim); the underlying JAX conv implementation can accept a [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]] spec like other low-level TF convs.

Even nicer would be simplified, per-spatial-dim symmetric values like PyTorch and Gluon use: [pad_h, pad_w] or just pad. My default for most 2D convnets in PyTorch is pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2, which results in a 'same-ish' padding value. This can always be built on top of the full low/high padding sequence above.

Some TF models explicitly work around the limitations of SAME padding. By limitations, I mean that you end up with input-dependent padding that can be asymmetric and shift your feature maps relative to each other in a manner that varies as you change your input size.
https://github.com/tensorflow/models/blob/146a37c6663e4a249e02d3dff0087b576e3dc3a1/research/deeplab/core/xception.py#L81-L201
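A minimal sketch (standard library only) of the per-dimension padding each convention produces, using the standard TF 'SAME' rule (output size = ceil(input / stride)) and the PyTorch-style formula quoted above. The function names are hypothetical. It illustrates the input-size dependence: the same 3x3, stride-2 layer gets (0, 1) padding on a 224-pixel input but (1, 1) on a 225-pixel one.

```python
import math


def tf_same_padding(in_size, kernel_size, stride=1, dilation=1):
    """(pad_beg, pad_end) that TF-style 'SAME' padding applies to one spatial dim."""
    effective_k = dilation * (kernel_size - 1) + 1
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + effective_k - in_size, 0)
    pad_beg = pad_total // 2  # any odd leftover goes to the end
    return pad_beg, pad_total - pad_beg


def symmetric_padding(kernel_size, stride=1, dilation=1):
    """PyTorch-style 'same-ish' symmetric padding, independent of the input size."""
    pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return pad, pad


print(tf_same_padding(224, 3, stride=2))  # (0, 1): asymmetric
print(tf_same_padding(225, 3, stride=2))  # (1, 1)
print(symmetric_padding(3, stride=2))     # (1, 1) for any input size
```

Either result is a (low, high) pair per spatial dim, so both could be expressed through a single explicit-padding interface.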

Possible interfaces:

  • padding : Union[ConvPadding, Sequence[Tuple[int, int]]] (like conv_general_dilated but with the enum for valid/same)

  • Add more modes to the enum, and associated values for those that need them via a dataclass

import enum
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple, Union


class PaddingType(enum.Enum):
    """An Enum holding the possible padding values for convolution modules."""
    SAME = 'SAME'
    VALID = 'VALID'
    RAW = 'RAW'  # specify padding as seq of (low, high) tuples
    SYM = 'SYM'  # specify symmetric padding for spatial dims as tuple for H, W or single int


@dataclass
class Padding:
    type: PaddingType = PaddingType.SAME
    value: Optional[Union[Sequence[Tuple[int, int]], Tuple[int, int], int]] = None

    @classmethod
    def same(cls):
        return cls(PaddingType.SAME)

    @classmethod
    def valid(cls):
        return cls(PaddingType.VALID)

    @classmethod
    def raw(cls, value: Sequence[Tuple[int, int]]):
        return cls(PaddingType.RAW, value=value)

    @classmethod
    def sym(cls, value: Union[Tuple[int, int], int]):
        return cls(PaddingType.SYM, value=value)

Add type and shape check to var assign

I made a few mistakes moving weights from PyTorch with vc.assign and ended up clobbering all of the model's weights with the wrong shapes. There were no errors until I tried to use the model.

I noticed the assign fn just uses var = tensor ... there's no copy option since JAX arrays are immutable, but wouldn't an assert isinstance(tensor, JaxArray) and self.var.shape == tensor.shape be appropriate? Or possibly an attempt to convert the type to JaxArray and broadcast the shape...

Am I missing use cases where you'd want to change type away from JaxArray or use a different shape than the original on assign()?
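A minimal sketch of the proposed check, using NumPy arrays in place of JAX arrays. CheckedVar is a hypothetical stand-in, not Objax's variable class. It raises on a shape mismatch and casts the dtype rather than asserting on it, which is one possible middle ground between the two options above.

```python
import numpy as np


class CheckedVar:
    """Hypothetical variable whose assign() rejects shape mismatches (not Objax's class)."""

    def __init__(self, value):
        self.value = np.asarray(value)

    def assign(self, tensor):
        tensor = np.asarray(tensor)  # accept lists, NumPy arrays, etc.
        if tensor.shape != self.value.shape:
            raise ValueError(f'Shape mismatch on assign: expected {self.value.shape}, got {tensor.shape}')
        # Cast instead of failing on dtype, e.g. float64 weights into a float32 variable.
        self.value = tensor.astype(self.value.dtype)
```

With a check like this, a transposed or misordered PyTorch weight fails at assign time instead of at the first forward pass.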

Remove need for .value when referring to internal param values.

When referring to a Module's internal variables in __call__, one needs to use self.x.value instead of simply self.x. It'd be nice to enable this syntactic sugar to improve the readability of complex math expressions. For example, tf.Module allows this:

import tensorflow as tf

class Dense(tf.Module):
    def __init__(self, in_features, out_features, name=None):
        super(Dense, self).__init__(name=name)
        self.w = tf.Variable(
            tf.random.normal([in_features, out_features]), name='w')
        self.b = tf.Variable(tf.zeros([out_features]), name='b')

    def __call__(self, x):
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)
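A minimal NumPy sketch of the requested sugar: a hypothetical Var wrapper that implements __array__ and a few operators, so that x @ self.w + self.b works without .value. This illustrates operator forwarding under those assumptions; it is not how Objax or TensorFlow actually implement their variables.

```python
import numpy as np


class Var:
    """Hypothetical variable that forwards arithmetic to its .value."""

    def __init__(self, value):
        self.value = np.asarray(value)

    def __array__(self, dtype=None, copy=None):
        # Lets NumPy functions and ndarray operators consume a Var directly.
        arr = self.value if dtype is None else self.value.astype(dtype)
        return arr.copy() if copy else arr

    def __add__(self, other):
        return self.value + other

    __radd__ = __add__  # addition is commutative

    def __matmul__(self, other):
        return self.value @ other

    def __rmatmul__(self, other):
        return other @ self.value


x = np.ones((2, 3))
w, b = Var(np.ones((3, 4))), Var(np.full(4, 0.5))
y = x @ w + b  # no .value needed; y is a plain array of 3.5s
```

A real implementation would need the full operator set (and JAX's own array-conversion hook) to cover every expression.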

Could you outline how to write a simplest RNN Module?

I'm looking to write a basic RNN that computes f(Ax+b) at each time step.

What would be the best way to go about it? Could you outline some code to give an idea?

Can one apply JIT over the entire (unrolled) network for training/inference?
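A minimal NumPy sketch of the simplest (Elman) RNN, assuming f = tanh and that "Ax + b" also includes the previous hidden state, i.e. h_t = tanh(x_t W_x + h_{t-1} W_h + b). The class is hypothetical and not an Objax module; in an Objax version the weights would be TrainVars. When JAX's jit traces a Python loop like this one, the loop is unrolled, so the whole unrolled network can be compiled; jax.lax.scan is the usual way to avoid unrolling long sequences.

```python
import numpy as np


class SimpleRNN:
    """Minimal Elman RNN: h_t = tanh(x_t @ w_x + h_{t-1} @ w_h + b)."""

    def __init__(self, nin, nhidden, seed=0):
        rng = np.random.default_rng(seed)
        self.w_x = rng.normal(scale=1 / np.sqrt(nin), size=(nin, nhidden))
        self.w_h = rng.normal(scale=1 / np.sqrt(nhidden), size=(nhidden, nhidden))
        self.b = np.zeros(nhidden)

    def __call__(self, xs, h0=None):
        """xs: (time, batch, nin). Returns all hidden states, shape (time, batch, nhidden)."""
        h = np.zeros((xs.shape[1], self.b.shape[0])) if h0 is None else h0
        hs = []
        for x in xs:  # one step per time index
            h = np.tanh(x @ self.w_x + h @ self.w_h + self.b)
            hs.append(h)
        return np.stack(hs)
```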

Multiple iterations per training loop step?

Hi,

In TF TPU usage, a common pattern is to train for multiple steps per train call: https://www.tensorflow.org/guide/tpu#improving_performance_by_multiple_steps_within_tffunction

I was wondering how we can achieve a similar pattern with Objax. Do we need a custom version of the Jit class?

Along these lines it might be helpful to expand the docs on how these types of things work. Looking at the code, there seems to be some dance with temporarily replacing the module variables with their traced counterparts, and using side effects rather than return values to convey computation results.

Thanks!

Tracer error when using a random variable

Hi, I'm working on implementing dcgan in objax and I'm running into this error:

UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: Different traces at same level: Traced<ShapedArray(float32[1,128,1,1])>with<JVPTrace(level=2/1)>
  with primal = Traced<ShapedArray(float32[1,128,1,1])>with<DynamicJaxprTrace(level=0/1)>
       tangent = Traced<ShapedArray(float32[1,128,1,1]):JaxprTrace(level=1/1)>, JVPTrace(level=2/1).

It looks like this is being triggered by the running mean and stddev avgs of batchnorm layer in the discriminator during the generator update step.

Here's a colab link to a colab notebook to reproduce this: https://colab.research.google.com/drive/1gG0naJz_JbFHQwNxL9jTKeifwzVne6KE?usp=sharing

Also, I'm not sure if this is a good place to look for help on this bug. Would posting this in the Jax discussions page be more appropriate?

Thanks,

Bilal

Document how to set random seed for default generator

Improve documentation around random numbers. By default, objax random functions use the DEFAULT_GENERATOR, which has initial seed set to 0.

It looks like you can set the seed with random.DEFAULT_GENERATOR.seed(1234).

Create example on gradient accumulation

Gradient accumulation is a technique used to simulate large batches that would not fit in hardware. Write an example in examples to demonstrate how to do it.
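A minimal NumPy sketch of the idea on a linear least-squares model whose gradient is written by hand. Averaging the gradients of equal-sized micro-batches gives exactly the full-batch gradient of a mean loss, so one optimizer step on the average emulates one large batch. The helper names are hypothetical; an Objax example would obtain the per-micro-batch gradients from objax.GradValues instead.

```python
import numpy as np


def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(x)


def accumulated_grad(w, x, y, num_micro_batches):
    """Average per-micro-batch gradients; only the accumulator, never the full batch, is held at once."""
    grads = np.zeros_like(w)
    for xb, yb in zip(np.array_split(x, num_micro_batches), np.array_split(y, num_micro_batches)):
        grads += mse_grad(w, xb, yb)
    return grads / num_micro_batches


rng = np.random.default_rng(0)
x, y, w = rng.normal(size=(64, 5)), rng.normal(size=64), rng.normal(size=5)
w_new = w - 0.1 * accumulated_grad(w, x, y, num_micro_batches=4)  # single SGD step
```

If the micro-batches have unequal sizes, each gradient should be weighted by its batch size rather than simply averaged.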

Explore Objax to TensorFlow conversion

Some other JAX frameworks provide APIs to convert JAX models into TensorFlow:

  1. https://trax-ml.readthedocs.io/en/latest/notebooks/tf_numpy_and_keras.html#2.-Convert-Trax-to-Keras
  2. https://source.corp.google.com/piper///depot/google3/third_party/py/jax/experimental/jax2tf/examples/stax_to_tf_module.py
  3. https://github.com/google/jax/tree/master/jax/experimental/jax2tf

Such conversion might be useful because TensorFlow allows saving trained models in the SavedModel format (which contains both weights and network architecture) for later use in production settings.

Computing gradients for a generator and discriminator

Hi,

I'm trying to get a working DCGAN implementation in Objax. Since the discriminator and generator need to be optimized separately, I'm taking the gradients with respect to both modules in two different functions like this:

    def d_loss(x, z):
        d_loss_real = objax.functional.loss.sigmoid_cross_entropy_logits(
            discriminator(x, training=True), jnp.ones([x.shape[0], 1])).mean()

        fake_img = generator(z, training=False)
#        fake_img = generator(z, training=True)
        d_loss_fake = objax.functional.loss.sigmoid_cross_entropy_logits(
            discriminator(fake_img, training=True), jnp.zeros([x.shape[0], 1])).mean()

        d_loss = d_loss_real + d_loss_fake

        return d_loss

    def g_loss(x, z):
        fake_img = generator(z, training=True)
        return objax.functional.loss.sigmoid_cross_entropy_logits(discriminator(fake_img, training=False), jnp.ones([x.shape[0], 1])).mean()
#        return objax.functional.loss.sigmoid_cross_entropy_logits(discriminator(fake_img, training=True), jnp.ones([x.shape[0], 1])).mean()

    d_gv = objax.GradValues(d_loss, discriminator.vars())
    g_gv = objax.GradValues(g_loss, generator.vars())

Would this be the preferred way of doing this, or is there a way of returning two values (the loss for both the generator and discriminator) in one function and then computing the gradients of the generator and discriminator losses separately (e.g. d_loss with respect to discriminator.vars())?

The code runs, and both parts of the GAN seem to train and their losses go down, but the discriminator's loss quickly converges to 0. I'm guessing this is caused by having to comment out the lines above in order to use training=False (which prevents batchnorm from using the current batch's stats and might be why the discriminator converges so quickly), since training=True causes an IndexError: tuple index out of range error.

colab notebook: https://colab.research.google.com/drive/1WTBKHqZWAg-TpXJmZWOn_7VCVOZsDP2F?usp=sharing

Naming Suggestions

This library will likely be quite popular among those used to PyTorch in the research community, as there is a demand for a standard NN interface to jax. Already the module/optimizer design is very familiar to those who have used pytorch, and the developer clearly was inspired by this design structure.

However, there is one aspect that is challenging to understand. In PyTorch (and in Autodiff), variable is synonymous with "value that you take a gradient wrt". In Objax, variable has a different meaning both in TrainVar (Parameter in pytorch) and StateVar (Buffer in pytorch). TrainVar is a variable, whereas StateVar is not. Furthermore Train is confusing as terminology, since it only applies during model fitting, as opposed to during inference.

Given this similarity, it seems like a natural change would be to refer to TrainVar => Parameter and StateVar => State or Buffer. These names avoid the ambiguity of variable, dictate what they do not how they are used, and make the library even more familiar and easy to use for those used to PyTorch.

Naming of the `GradValues` function

If I understand right, GradValues essentially does two things: computing gradients and computing model final values.

So why not split it into two functions? Or, if we keep the current form, could we name it GradAndValuesFn? I'm just thinking this is a prominent function and we want to keep it as easy as possible for people beginning to use the framework. Easy names such as fit() and predict() helped make scikit-learn successful.

Is it desirable to have an objax.Function module to wrap functions and their vars?

When passing a function to Jit, Grad, etc., we also need to pass the variables the function uses. While we can keep this design pattern, we could also propose an alternative: a decorator that turns a function into a module.

Example:

# Current design
def some_function(x):
    return objax.functional.softmax(model(x, training=False))

jit_func = objax.Jit(some_function, model.vars())

Alternate design proposal (not a replacement)

@objax.Function(model.vars())
def some_function(x):
    return objax.functional.softmax(model(x, training=False))

jit_func = objax.Jit(some_function)

I like a certain elegance about this proposal but I would like more opinions on whether it is a desirable addition.
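A minimal plain-Python sketch of the proposed decorator, assuming it only needs to remember the variable collection and expose it through a vars() method. Function and jit_like are hypothetical names; jit_like only shows how a consumer such as Jit could fall back to fn.vars() when no variables are passed explicitly.

```python
import functools


class Function:
    """Hypothetical decorator: @Function(variables) bundles a function with its variables."""

    def __init__(self, variables):
        self.variables = variables

    def __call__(self, f):
        return _BoundFunction(f, self.variables)


class _BoundFunction:
    def __init__(self, f, variables):
        functools.update_wrapper(self, f)  # keep f's name and docstring
        self._f = f
        self._variables = variables

    def __call__(self, *args, **kwargs):
        return self._f(*args, **kwargs)

    def vars(self):
        return self._variables


def jit_like(fn, variables=None):
    """Stand-in for a transform such as Jit: use fn.vars() when no variables are given."""
    if variables is None:
        variables = fn.vars()
    return fn, variables


params = {'w': 2.0}


@Function(params)
def scale(x):
    return params['w'] * x
```

Because the bound function still behaves like a plain callable, the current design (passing variables explicitly) keeps working alongside it.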

losses and metrics

Are we looking to add losses and metrics? If yes, I can start working on it

Many layer conventions are different from both Tensorflow/Keras and PyTorch

I'm not sure if this was by design or just the way it worked out, but Objax forges its own path for many parameter/variable and state naming/ordering conventions, default arguments, etc.

While it may seem trivial if one is just using Objax, it adds some cognitive overhead when you're working with multiple frameworks, moving weights around, or writing components that work with multiple modeling interfaces.

Some examples:

Objax BatchNorm eps/momentum defaults are diff from both TF/Keras and PyTorch. Momentum isn't that important unless training, but eps impacts existing weight compatibility if not matched.

objax: eps=1e-6, momentum=.999
tf/keras: eps=1e-3, momentum=.99 or .999
pytorch: eps=1e-5, momentum=.1 (.9)

Lots of layer variable names and their creation orderings are different.

Conv2d

  • Objax .b and .w
  • TF/keras .kernel and .bias
  • Pytorch .weight and .bias

BatchNorm

  • Objax - .running_mean, .running_var .beta, .gamma
  • TF/keras .moving_mean, .moving_variance, .gamma, .beta
  • Pytorch .weight, .bias, running_mean, running_var

Regarding the ordering of variables: when iterating in creation order, Objax often puts the bias first, whereas almost every other framework I'm used to creates the weight (or its gamma/scale/etc. equivalent) first and then the bias, and thus usually iterates in that order.

I'm not sure if any of this is still in flux; if so, it could help to align more conventions with one of the existing options.

Training state as a Module attribute

As mentioned in a Twitter thread, I am curious about the decision to propagate training state through the call() chain. From my perspective, this approach adds more boilerplate code and more chances of making a mistake (e.g. not propagating the state to a few instances of a module with a BN or dropout layer). If the state changed every call like the input data, it would make more sense to pass it with every forward, but I can't think of cases where that is common. For small models it doesn't make much difference, but as models grow in depth and breadth of submodules, the extra args become more noticeable.

I feel one of the major benefits of an OO abstraction for NNs is being able to push attributes like this into the class structure instead of forcing them to be forwarded through every call in a functional manner. I sit in the (pragmatic) middle ground between OO and functional. Hidden state can be problematic, but it's worth it if it keeps interfaces clean.

Besides TF/Keras, most DL libs manage training state as a module attribute or some sort of context.

It should be noted that Swift for TF started out Keras- and Objax-like, with the training state passed through call().

Disclaimer: I like PyTorch, and I do quite a bit of work with that framework. It's not perfect, but I feel they did a really good job in terms of interface, usability, and evolution of the API. I've read some other comments here and acknowledge the 'we don't want to be like framework/lib X or Y just because; if you disagree, go fork yourself' stance. Understood; any suggestions I make are not just to be like X, but to bring elements of X that work really well to improve this library.

I currently maintain some PyTorch model collections, for example https://github.com/rwightman/pytorch-image-models and https://github.com/rwightman/efficientdet-pytorch. I'm running into a cost ($$) wall with GPU experiments supporting my open-source work, and TPU costs are starting to look far more attractive. PyTorch XLA is not proving to be a great option, but JAX with a productive interface looks like it could be a winning solution with even more flexibility.

I'm willing to contribute code for changes like this, but at this point it's matter of design philosophy :)

Question: difference between Flax and Objax

This is just a question, as I came across both Objax and Flax today; both are Google frameworks and both claim to be the deep learning framework for JAX. I noticed that the APIs are slightly different... Could you share some thoughts on the relationship between these frameworks, or their eventual cooperation in a single user project?
Thank you 🐰

"objax.variable.VarCollection is not a valid JAX type" when creating a custom optimizer

Hi, I wish to create a custom optimizer to replace the
opt(lr=lr, grads=g)
line in the example https://github.com/google/objax/blob/master/examples/classify/img/cifar10_simple.py

Instead, I replaced it with

for grad, p in zip(g, model_vars):
      p.value -= lr * grad   

and then supplied model.vars() as an argument to train_op. However, I received an error: objax.variable.VarCollection is not a valid JAX type. Can someone help me with this issue? Here is a minimal working example which reproduces the error.

import random
import numpy as np
import objax
import tensorflow as tf
from objax.zoo.wide_resnet import WideResNet

# Data
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.transpose(0, 3, 1, 2) / 255.0
X_test = X_test.transpose(0, 3, 1, 2) / 255.0

# Model
model = WideResNet(nin=3, nclass=10, depth=28, width=2)
#opt = objax.optimizer.Adam(model.vars())
predict = objax.Jit(lambda x: objax.functional.softmax(model(x, training=False)),
                    model.vars())
# Losses
def loss(x, label):
    logit = model(x, training=True)
    return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()

gv = objax.GradValues(loss, model.vars())

def train_op(x, y, model_vars, lr):
    g, v = gv(x, y)
    for grad, p in zip(g, model_vars):
      p.value -= lr * grad   
    return v


# gv.vars() contains the model variables.
train_op = objax.Jit(train_op, gv.vars()) #I deleted opt.vars()

for epoch in range(30):
    # Train
    loss = []
    sel = np.arange(len(X_train))
    np.random.shuffle(sel)
    for it in range(0, X_train.shape[0], 64):
        loss.append(train_op(X_train[sel[it:it + 64]], Y_train[sel[it:it + 64]].flatten(), model.vars(), 4e-3 if epoch < 20 else 4e-4)) #I added model.vars() 

Add Transformer to objax.zoo

Transformer is very popular network and it would be great to have one in objax.zoo.
We should probably also add an attention layer to objax.nn too.
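A minimal NumPy sketch of scaled dot-product attention, the core operation such a layer would wrap: softmax(q k^T / sqrt(d)) v, with an optional boolean mask that is True where attending is allowed. The names and the mask convention are assumptions for illustration, not a proposed objax.nn API.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(q, k, v, mask=None):
    """q: (..., tq, d), k: (..., tk, d), v: (..., tk, dv). Returns (output, weights)."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # disallowed positions get ~zero weight
    weights = softmax(scores, axis=-1)
    return weights @ v, weights
```

A multi-head layer would add learned projections for q, k, v and the output around this function, and a Transformer block would combine it with a feed-forward layer, residual connections, and normalization.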

Easy support for per-module train/eval state

This is a tracking issue to improve support for per-module train/eval state in Objax.
This issue originates from discussion of PyTorch vs Objax-style of propagating training/eval mode #29

PyTorch style allows easy way to specify per-module train/eval mode, like in the following example:

model = Resnet50(nclasses=1000)
…
# Here is example how to set most of the network,
# except few modules into training mode
model.train()
model.block_1.bn_1.eval()
model.block_2.bn_2.eval()

There is no clean way to achieve the same thing in Objax right now. One possibility is to use functools.partial:

model = Resnet50(nclasses=1000)
…
# Here is example how to force certain batch norms into eval mode
model.block_1.bn_1 = functools.partial(model.block_1.bn_1, training=False)
model.block_2.bn_2 = functools.partial(model.block_2.bn_2, training=False)

# The following line calls the model in training mode, except for block_1.bn_1 and block_2.bn_2
y = model(x, training=True)

However there are some problems with functools.partial:

  • it converts everything into a function (thus vars() are not propagated).
  • if bn_eval = functools.partial(bn, training=False) and the caller tries to pass a training argument to bn_eval, it causes a run-time error
  • there is no easy way to undo functools.partial after it applied to a module

Thus we need a better solution to do per-module train/eval state
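A minimal NumPy sketch of one alternative to functools.partial, assuming each module with train/eval behaviour carries an optional training_override attribute: None means "follow the caller", and True/False pins the mode. Because the module stays a module (its variables are untouched), callers can still pass training=..., and undoing the override is just resetting it to None. Dropout here is a hypothetical stand-in, not objax.nn.Dropout.

```python
from typing import Optional

import numpy as np


class Dropout:
    """Hypothetical dropout module with an optional per-module training override."""

    def __init__(self, rate, seed=0):
        self.rate = rate
        self.training_override: Optional[bool] = None  # None: follow the caller's flag
        self.rng = np.random.default_rng(seed)

    def __call__(self, x, training):
        if self.training_override is not None:
            training = self.training_override
        if not training:
            return x
        keep = self.rng.random(x.shape) >= self.rate
        return np.where(keep, x / (1 - self.rate), 0.0)  # inverted dropout scaling


drop = Dropout(rate=0.5)
drop.training_override = False  # pin to eval mode, like model.block_1.bn_1.eval()
x = np.ones(10)
y_pinned = drop(x, training=True)  # returns x unchanged
drop.training_override = None  # undo: follow the caller again
y_train = drop(x, training=True)
```

A model-level helper could walk the module tree and set or clear the override on selected submodules, which would give the PyTorch-style model.block_1.bn_1.eval() ergonomics while keeping training as a call argument.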

More control over var/module namespace.

I got my first 'hello world' model experiment working w/ Objax. I adapted my PyTorch EfficientNet impl. Overall pretty smooth, currently wrapping Conv2d so I can get the padding I want.

One thing that stuck out after inspecting the model: the var namespace is a mess. An aspect of modelling that I value highly is the ability to have sensible checkpoint/var maps to work with. I often end up dealing with conversions between frameworks and exports for mobile or embedded targets, where having your vars (parameters) sensibly named, and often being able to control those names in the originating framework, is important.

Any thoughts on improving this? The current name/scoping mechanism forces the inclusion of the Module class names, is that necessary? Shouldn't attr names through the tree be enough for uniqueness?

Also, there is no ability to specify names for modules in sequential containers. I use this quite often in frameworks that have it. Sometimes I don't care much (for a long list of block repeats, 0..n is fine), but for finer-grained blocks I like to know which conv is which by looking at the var names. '0.b', '0.w', etc. isn't very useful.

I'll post an example of the var keys below, and a comparison point for PyTorch.
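A minimal sketch of the naming scheme suggested above, assuming module attributes are either parameters or sub-modules: names are built from attribute paths alone (no class names), and a Sequential container takes caller-chosen names as keyword arguments. All classes here are hypothetical, not Objax's Module or VarCollection.

```python
import numpy as np


class Module:
    """Hypothetical minimal base class: attributes are parameters or sub-modules."""


class Conv(Module):
    def __init__(self, nin, nout, k):
        self.w = np.zeros((k, k, nin, nout))  # layout chosen arbitrarily for the sketch
        self.b = np.zeros(nout)


class Sequential(Module):
    """Container whose children are named by the caller via keyword arguments."""

    def __init__(self, **named_layers):
        for name, layer in named_layers.items():
            setattr(self, name, layer)


def named_params(module, prefix=''):
    """Yield (dotted_name, value) pairs built from attribute names only."""
    for name, child in vars(module).items():
        path = f'{prefix}.{name}' if prefix else name
        if isinstance(child, Module):
            yield from named_params(child, path)
        else:
            yield path, child


model = Sequential(stem=Conv(3, 8, 3),
                   block1=Sequential(conv_dw=Conv(8, 8, 3), conv_pw=Conv(8, 16, 1)))
names = [name for name, _ in named_params(model)]
# ['stem.w', 'stem.b', 'block1.conv_dw.w', 'block1.conv_dw.b', 'block1.conv_pw.w', 'block1.conv_pw.b']
```

Attribute paths are already unique within a tree, so class names are not needed for disambiguation in this scheme; an unnamed list container could fall back to integer indices.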
