axon's Issues

Add Recurrent Layers

Requires a solution for managing layers that maintain state and/or return multiple outputs.

  • gru
  • lstm
  • conv_lstm

Support custom layers

Modularity of layers is easy because we can use regular Elixir functions, but we need a solution for specifying and using custom trainable parameters in layers.

Re-use subgraphs in compilation of combinators

Currently combinators like add, concatenate, etc. traverse back up entire subgraphs and treat them as completely different parts of the computation graph. For example:

x
|> dense(128)
|> add(x)

x's entire subgraph will appear in the resulting expression twice, even though it's actually the same computation. This will lead to extremely large expressions in complex models, among other possible issues. Additionally, for #28, we would end up returning multiple independent graphs, even if the base of the model shares the same subgraph.
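A minimal sketch of one way to avoid the duplication, in plain Elixir with no Nx: evaluate the graph with a cache keyed on a unique node id, so a subgraph reached along two paths is computed once. `GraphNode`, `GraphCompiler`, and the toy ops are hypothetical stand-ins, not Axon's internals; `{:scale, 2}` plays the role of `dense(128)`.

```elixir
defmodule GraphNode do
  # Hypothetical graph node: a unique id, an operation, and the ids of its inputs.
  defstruct [:id, :op, inputs: []]
end

defmodule GraphCompiler do
  # Evaluates node `id` of `graph` (a map of id => %GraphNode{}), caching each
  # node's result by id so a shared subgraph is computed only once.
  # Returns {value, number_of_nodes_evaluated}.
  def compile(graph, id, input) do
    {value, cache} = eval(graph, id, input, %{})
    {value, map_size(cache)}
  end

  defp eval(graph, id, input, cache) do
    case cache do
      # Already computed: reuse the cached result instead of re-traversing.
      %{^id => value} ->
        {value, cache}

      _ ->
        node = Map.fetch!(graph, id)

        {args, cache} =
          Enum.map_reduce(node.inputs, cache, fn parent, acc ->
            eval(graph, parent, input, acc)
          end)

        value = apply_op(node.op, args, input)
        {value, Map.put(cache, id, value)}
    end
  end

  # Toy operations standing in for real layers.
  defp apply_op(:input, [], input), do: input
  defp apply_op({:scale, k}, [x], _input), do: x * k
  defp apply_op(:add, [a, b], _input), do: a + b
end

# x |> dense(128) |> add(x), with x itself derived from the input.
graph = %{
  input: %GraphNode{id: :input, op: :input},
  x: %GraphNode{id: :x, op: {:scale, 3}, inputs: [:input]},
  dense: %GraphNode{id: :dense, op: {:scale, 2}, inputs: [:x]},
  out: %GraphNode{id: :out, op: :add, inputs: [:dense, :x]}
}

GraphCompiler.compile(graph, :out, 5)
# {45, 4}: :x is evaluated once even though :out reaches it along two paths
```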

Support for non-DL models and unsupervised learning?

Is Axon only going to be focused on DL approaches to machine learning, or should it also include non-DL supervised learning approaches that leverage labeled datasets and could leverage Nx like SVM, decision trees, random forests, ensembles, etc?

Along the same lines, what about unsupervised learning approaches?

I guess what it comes down to is whether Axon is strictly positioned against DL frameworks like TF, Keras, PyTorch, etc., or whether we want it to encompass other statistical ML approaches like Scikit-learn or Shogun ML?

Add shape / type assertions

We should add shape / type assertions to layers to provide possibly more specific error messages than falling back on what Nx may give
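A small sketch of the kind of assertion meant here, for example inside a loss function comparing `y_true` and `y_pred`. `ShapeAssert.assert_shape!/3` is a hypothetical helper; shapes are plain tuples so the example runs without Nx.

```elixir
defmodule ShapeAssert do
  # Hypothetical helper: raise an error that names the calling layer or loss
  # and both shapes, instead of surfacing a lower-level Nx error.
  def assert_shape!(caller, {name_a, shape_a}, {name_b, shape_b}) do
    if shape_a != shape_b do
      raise ArgumentError,
            "#{caller}: expected #{name_a} and #{name_b} to have the same shape, " <>
              "got #{inspect(shape_a)} and #{inspect(shape_b)}"
    end

    :ok
  end
end

ShapeAssert.assert_shape!("categorical_cross_entropy", {"y_true", {32, 10}}, {"y_pred", {32, 10}})
# :ok
```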

Functionality Roadmap

An issue to track some baseline functionality:

Activations

  • celu
  • elu
  • exp
  • gelu
  • hard_tanh
  • hard_sigmoid
  • hard_silu/hard_swish
  • leaky_relu
  • log_sigmoid
  • relu
  • relu6
  • selu
  • sigmoid
  • silu
  • softmax
  • softplus
  • softsign
  • tanh

Initializers

  • glorot_uniform
  • glorot_normal
  • he_normal
  • he_uniform
  • lecun_uniform
  • lecun_normal
  • normal
  • ones
  • orthogonal - requires elixir-nx/nx#174
  • uniform
  • zeros

Loss Functions

  • binary_crossentropy
  • categorical_crossentropy
  • categorical_hinge
  • cosine_similarity - requires elixir-nx/nx#174
  • ctc
  • hinge
  • kl_divergence
  • log_cosh
  • margin_ranking
  • mean_absolute_error
  • mean_squared_error
  • poisson
  • soft_margin

Metrics

  • accuracy
  • mean_squared_error - requires defndelegate
  • mean_absolute_error - requires defndelegate
  • precision
  • recall
  • sensitivity
  • specificity


Optax style transformations:

  • scale
  • scale_by_adam
  • scale_by_rss
  • scale_by_belief
  • scale_by_rms
  • trace
  • clip
  • clip_by_global_norm
  • centralize
  • scale_by_trust_ratio
  • scale_by_schedule
  • scale_by_radam
  • scale_by_stddev

Schedules

  • polynomial_schedule
  • exponential_decay_schedule
  • cosine_decay_schedule
  • constant_schedule
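A sketch of the schedules above as closures from a step count to a learning rate, assuming optax-style formulas (for example `init_value * decay_rate ^ (step / transition_steps)` for exponential decay, without the staircase option). The `Schedules` module and argument orders are assumptions for illustration.

```elixir
defmodule Schedules do
  # Each schedule returns fn step -> learning_rate end (optax-style semantics assumed).

  def constant_schedule(value), do: fn _step -> value end

  # lr(step) = init_value * decay_rate ^ (step / transition_steps)
  def exponential_decay_schedule(init_value, decay_rate, transition_steps) do
    fn step -> init_value * :math.pow(decay_rate, step / transition_steps) end
  end

  # Decays from init_value to end_value over transition_steps, then stays at end_value.
  def polynomial_schedule(init_value, end_value, power, transition_steps) do
    fn step ->
      frac = 1 - min(step, transition_steps) / transition_steps
      (init_value - end_value) * :math.pow(frac, power) + end_value
    end
  end

  # Half-cosine decay from init_value down to alpha * init_value over decay_steps.
  def cosine_decay_schedule(init_value, decay_steps, alpha \\ 0.0) do
    fn step ->
      progress = min(step, decay_steps) / decay_steps
      cosine = 0.5 * (1 + :math.cos(:math.pi() * progress))
      init_value * ((1 - alpha) * cosine + alpha)
    end
  end
end
```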


For now, just functional implementations resembling torch.nn.functional or tf.nn:

Linear Layers

Convolutional Layers

  • conv
  • conv_transpose
  • depthwise_conv
  • separable_conv2d
  • separable_conv3d

Pooling Layers

  • avg_pool
  • max_pool
  • lp_pool
  • adaptive_avg_pool
  • adaptive_max_pool
  • adaptive_lp_pool
  • global_avg_pool
  • global_max_pool
  • global_lp_pool

Normalization Layers

  • batch_norm
  • group_norm
  • instance_norm
  • layer_norm

Dropout Layers

  • dropout
  • alpha_dropout
  • feature_alpha_dropout
  • spatial_dropout

Attention Layers

  • dot_product_attention - requires elixir-nx/nx#182
  • additive_attention - requires repeat/gather on Nx

Visual Layers

  • resize

We can drop off dimensional suffixes in favor of generic implementations too.

Introduce a high-level layer API

The next step after #1 is to implement higher-level constructs on top of the lower-level functional implementations. The goal of the higher-level API is to provide abstractions for building neural networks that:

  • limit the overhead of writing a neural network from scratch
  • are easy to understand, especially for beginners
  • are flexible enough for more complicated architectures (ResNets, GANs, etc.)
  • can be represented as composition of Nx functions for low-level JIT/AOT compilation OR can be represented as higher-level constructs for compilation using specific NN compilers (ONNX, TFLite, etc.)

For simplicity, we'll leave the discussion of efficient handling of network state to a later issue. This issue will only focus on the architecture/representation of a network.

Axon Struct

We will introduce an %Axon{} struct that represents a constructed network. For now, the struct will have the following attributes:

  • :input - input to this layer/model/etc., or in the case of a literal input layer some metadata
  • :shape - shape of the layer's parameters, can be inferred from :input, we can also allow input layer shapes to have nil batch dimensions to represent arbitrary sized batches
  • :transformation - how does this layer transform the input, can be an atom (like :dense) which resolves to an already implemented layer, or a numerical definition for arbitrarily complex transformations
  • other - there's a lot of other metadata that should be included in some way: :initializer and initializer options, :activation and activation options, layer specific options, possibly constraints, callbacks, etc.

The struct would be built up with calls to high-level functions in the root Axon namespace. For example, MNIST:

model =
  Axon.input({nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
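A minimal sketch of how such a struct and its builder functions could fit together, assuming `:shape` holds each layer's output shape (with `nil` for an arbitrary batch size) inferred from its `:input`. `MiniAxon` is a hypothetical stand-in, not the real `%Axon{}`.

```elixir
defmodule MiniAxon do
  # Hypothetical struct mirroring the proposed attributes. :input is the parent
  # layer (nil for an input layer); :shape is the output shape.
  defstruct [:input, :shape, :transformation, opts: []]

  def input(shape), do: %MiniAxon{shape: shape, transformation: :input}

  # The output shape is inferred from the parent's shape: the last dimension
  # becomes `units`, and the (possibly nil) batch dimension is carried through.
  def dense(%MiniAxon{shape: shape} = parent, units, opts \\ []) do
    out_shape = put_elem(shape, tuple_size(shape) - 1, units)
    %MiniAxon{input: parent, shape: out_shape, transformation: :dense, opts: opts}
  end

  # Walk back to the input layer, listing each layer's transformation and shape.
  def layers(nil), do: []

  def layers(%MiniAxon{input: parent} = layer),
    do: layers(parent) ++ [{layer.transformation, layer.shape}]
end

model =
  MiniAxon.input({nil, 784})
  |> MiniAxon.dense(128, activation: :relu)
  |> MiniAxon.dense(10, activation: :softmax)

MiniAxon.layers(model)
# [{:input, {nil, 784}}, {:dense, {nil, 128}}, {:dense, {nil, 10}}]
```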

Then a model would be compiled, perhaps like:

compiled_model = Axon.compile(model, options)

At a minimum, Axon.compile initializes parameters and returns a compiled function using whatever backend the user specifies. It can also be abstracted away in some higher-level training logic, but that's a discussion for another issue.

Structuring the network in this way makes arbitrary compilation easy, the API is simple and easy to understand, and flexible enough for complex models. For example, a GAN:

generator =
  Axon.input({nil, 128})
  |> Axon.dense(256, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({28, 28})

discriminator =
  Axon.input({nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

combined =
  generator
  |> Axon.compose(discriminator)


High-level layers are tied directly to their functional implementations in Axon.Layers. Some of them have layer specific options which can be passed during layer creation.


We will use combinators similar to: to represent more complex relationships between layers. At a minimum we'd have:

  • compose - function composition
  • add - adds layers
  • concat - concats layers
  • residual - residual output
  • parallel/split - something to represent multiple-model outputs

Implement model inspection

For those familiar with Keras model.summary(), it can be useful to see what your compiled model looks like in terms of shape at each layer, number of trainable parameters, layer names, etc. Because a model is just an Axon struct, we can implement something similar through the Inspect protocol. I personally like Keras style model summaries; however, I'm open to other ideas about how to render a model summary during inspection.
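A sketch of the Inspect-protocol idea, assuming a hypothetical flattened `FlatModel` struct that already carries each layer's name, op, output shape, and parameter count; a real implementation would compute these by walking the `%Axon{}` graph.

```elixir
defmodule FlatModel do
  # Hypothetical, flattened stand-in for a model: one map per layer.
  defstruct layers: []
end

defimpl Inspect, for: FlatModel do
  # Render a Keras-style summary table whenever the struct is inspected.
  def inspect(%FlatModel{layers: layers}, _opts) do
    header = row("Layer", "Shape", "Parameters")

    rows =
      Enum.map(layers, fn layer ->
        row("#{layer.name} ( #{layer.op} )", Kernel.inspect(layer.shape), "#{layer.params}")
      end)

    total = layers |> Enum.map(& &1.params) |> Enum.sum()
    Enum.join([header | rows] ++ ["Total parameters: #{total}"], "\n")
  end

  defp row(name, shape, params),
    do: String.pad_trailing(name, 25) <> String.pad_trailing(shape, 13) <> params
end

model = %FlatModel{
  layers: [
    %{name: "input_0", op: :input, shape: {nil, 784}, params: 0},
    %{name: "dense_1", op: :dense, shape: {nil, 10}, params: 7850}
  ]
}

IO.puts(inspect(model))
```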

Implement a high level optimization API

The CIFAR example is updated and demonstrates the usage of the low-level constructs in updates.ex to create more advanced optimizers. What's required is essentially the same as what we've already implemented with the layer API. We need to construct optimizer combinators that:

  1. Accept a model (and possibly hyperparameters)
  2. Initialize state w.r.t each parameter in the model
  3. Apply updates according to some transformations defined in updates.ex

I propose we follow an approach similar to the one taken for the layer API and have the following macros in an Axon.Optimizer namespace:

  • init(optimizer, model, opts \\ []) - initializes the optimizer with state (e.g. first and second moment of an update)
  • apply_updates(optimizer, optimizer_state, gradients, params) / apply_gradients / step - applies updates and returns new parameters and updated optimizer state

With this approach, we can implement common optimizers as regular Elixir functions and then users can apply them trivially from within defn and def. I think this leaves us with room to play. It may be that in the future optimizers get taken out of Axon and placed in a separate library similar to optax.
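A dependency-free sketch of the init / apply_updates split, assuming parameters, gradients, and optimizer state are maps of name => float and an optimizer is a pair `{init_fn, update_fn}`. `SketchOptimizer` and `sgd_momentum/2` are hypothetical names; the momentum rule shown is the classic `v = momentum * v + g`, `update = -lr * v`.

```elixir
defmodule SketchOptimizer do
  # SGD with momentum as a pair of plain functions: {init_fn, update_fn}.
  def sgd_momentum(learning_rate, momentum \\ 0.9) do
    # One zero-initialized velocity per parameter.
    init_fn = fn params -> Map.new(params, fn {name, _value} -> {name, 0.0} end) end

    update_fn = fn gradients, velocity ->
      velocity = Map.new(velocity, fn {name, v} -> {name, momentum * v + gradients[name]} end)
      updates = Map.new(velocity, fn {name, v} -> {name, -learning_rate * v} end)
      {updates, velocity}
    end

    {init_fn, update_fn}
  end

  # Initialize optimizer state for the given parameters.
  def init({init_fn, _update_fn}, params), do: init_fn.(params)

  # Compute updates, apply them, and return {new_params, new_optimizer_state}.
  def apply_updates({_init_fn, update_fn}, optimizer_state, gradients, params) do
    {updates, optimizer_state} = update_fn.(gradients, optimizer_state)
    {Map.new(params, fn {name, p} -> {name, p + updates[name]} end), optimizer_state}
  end
end

optimizer = SketchOptimizer.sgd_momentum(0.1)
params = %{w: 1.0}
state = SketchOptimizer.init(optimizer, params)
{params, state} = SketchOptimizer.apply_updates(optimizer, state, %{w: 2.0}, params)
```

Because both pieces are ordinary functions, the same pair could be called from `def` or, with tensor maps, from inside `defn`.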

An alternative approach is to implement optimizers as behaviours, although this falls a bit short because we still need to pattern match on inputs, and then we'd have to implement macros for each "behaviour".

Add penalties for parameter regularization

The lower level should be supported by an Axon.Regularization or Axon.Penalty module, with L1/L2 penalties implemented as defn. Penalties can be added in custom objective functions passed to step, or through a high-level interface in step.
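A sketch of the penalties and of adding one to a custom objective. The `Penalty` module name and the `lambda * sum(...)` convention (no 1/2 factor) are assumptions; plain lists of floats stand in for tensors so the example runs without Nx.

```elixir
defmodule Penalty do
  # Penalties over a map of parameter name => list of floats.

  # L1: lambda * sum(|w|)
  def l1(params, lambda), do: lambda * sum_over(params, &abs/1)

  # L2: lambda * sum(w^2)
  def l2(params, lambda), do: lambda * sum_over(params, fn w -> w * w end)

  # A custom objective: data loss plus a weighted L2 penalty on all parameters.
  def penalized_loss(data_loss, params, lambda), do: data_loss + l2(params, lambda)

  defp sum_over(params, fun) do
    params
    |> Map.values()
    |> List.flatten()
    |> Enum.map(fun)
    |> Enum.sum()
  end
end

params = %{w: [1.0, -2.0], b: [0.5]}
Penalty.penalized_loss(1.0, params, 0.5)
```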

Unify dropout layers

As of now we have 6 (soon to be 7) dropout layers:

  • dropout - general dropout
  • spatial_dropout1d, spatial_dropout2d, spatial_dropout3d - spatial dropout, which masks across entire input feature channels
  • alpha_dropout - rather than 0 masking, computes a mask that maintains the mean and standard deviation of the input
  • feature_alpha_dropout - pretty much like spatial dropout, but masks with negative selu instead of 0
  • dropblock - computes a mask with contiguous regions across feature channels (e.g. this will mask large chunks of pixels in an image rather than random pixels)

We can extract the following pattern from each of these dropout layers:

  • rng_state - not maintained in our current implementation, discussion for another issue
  • noise_shape - spatial layers compute the noise shape such that masked layers will be implicitly broadcasted across feature channels
  • mask - how to compute the value of the mask
  • shift_and_scale - regular dropout "scales" by (1 / (1 - rate)), alpha dropout shifts and scales to maintain mean/variance

Knowing this, I propose the following generalized dropout method:

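defn dropout(input, opts \\ []) do
  opts = keyword!(opts, [:rate, noise_shape: Nx.shape(input), mask_values: 0.0, gamma: 1.0, beta: 0.0])
  # keep each element with probability 1 - rate, sampled over noise_shape
  mask = Nx.less(Nx.random_uniform(opts[:noise_shape]), 1 - opts[:rate])
  mask = Nx.broadcast(mask, Nx.shape(input))
  x = Nx.select(mask, input, opts[:mask_values])
  # shift and scale, e.g. gamma: 1 / (1 - rate), beta: 0.0 for regular dropout
  x * opts[:gamma] + opts[:beta]
end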

Dropblock will more than likely need to be considered separately as it requires some advanced indexing, but this will simplify the layer API and allow users to explore custom dropout layers.

Validate input shape as part of model compilation

Invalid input shapes can lead to confusing/surprising error messages. Trivial example:

model = Axon.input({nil, 1, 32}) |> Axon.max_pool()
input = Nx.random_uniform({1, 32}) # oops, forgot a dimension

{init_fn, predict_fn} = Axon.compile(model)
predict_fn.(init_fn.(), input)

Results in:

** (ArgumentError) invalid window dimensions, rank of shape (2) does not match rank of window (3)

Which can be confusing. We should raise a clear error if the input shape/rank is incorrect.
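A sketch of such a check, assuming the declared input shape may contain `nil` wildcards (such as the batch dimension) and that validation runs before the forward pass. `InputValidation.validate_input!/2` is a hypothetical name.

```elixir
defmodule InputValidation do
  # Compare the declared input shape (nil = any size) with the actual shape.
  def validate_input!(expected, actual) do
    cond do
      tuple_size(expected) != tuple_size(actual) ->
        raise ArgumentError,
              "invalid input shape: expected rank #{tuple_size(expected)} " <>
                "(#{inspect(expected)}), got rank #{tuple_size(actual)} (#{inspect(actual)})"

      not compatible?(expected, actual) ->
        raise ArgumentError,
              "invalid input shape: expected #{inspect(expected)}, got #{inspect(actual)}"

      true ->
        :ok
    end
  end

  # Every dimension must match unless the declared dimension is nil.
  defp compatible?(expected, actual) do
    Enum.zip(Tuple.to_list(expected), Tuple.to_list(actual))
    |> Enum.all?(fn {e, a} -> is_nil(e) or e == a end)
  end
end

InputValidation.validate_input!({nil, 1, 32}, {8, 1, 32})
# :ok; passing {1, 32} instead raises a rank error naming both shapes
```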

Parameters declared out of order initialize incorrectly

This stems from our current (fragile) way of ensuring parameters are in the correct order by generating a unique ID for each parameter like:

System.unique_integer([:positive, :monotonic])

and then sorting on the ID because they are guaranteed to be ordered. Internally this isn't that much of a problem, but with the addition of custom layers, it's possible to run into the following case:

bias = Axon.param("bias", {}, initializer: :zeros)
weight = Axon.param("weight", {}, initializer: :ones)

Axon.layer(x, fn x, w, b -> Nx.add(Nx.multiply(x, w), b) end, output_shape, [weight, bias])

During initialization, weight will be initialized as bias and bias will be initialized as weight. When they are the same shapes as in this instance, I think this will lead to silent bugs. We could document, but I believe our current method is fragile and needs to be refactored altogether.

We need a better way to track parameters at each layer. This is really an issue in the sense that we can't just declare a parameter and use it in an operation like:

def dense(%Axon{output_shape: shape} = x, units) do
  w = Axon.param("weight", {elem(shape, 1), units})
  b = Axon.param("bias", {1, units})
  Nx.add(Nx.dot(x, w), b)
end

So we need a better way to ensure parameters are initialized and used in the correct places. I believe the best option is to ensure each parameter has a unique name, and then pass params as a map to defn - which would make this an upstream issue.
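A sketch of the name-keyed approach, assuming each parameter has a unique name and layers receive the full parameter map. `NamedParams` is hypothetical, and initializers return scalars (ignoring the declared shape) only to keep the example dependency-free.

```elixir
defmodule NamedParams do
  # Hypothetical parameter declaration: a unique name, a shape, and an initializer.
  def param(name, shape, initializer), do: %{name: name, shape: shape, initializer: initializer}

  # Build name => initial value. Lookup is by name, so declaration order no
  # longer determines which value ends up where.
  def init(params) do
    Map.new(params, fn %{name: name, initializer: init} -> {name, initial_value(init)} end)
  end

  # Scalar initializers for illustration; a real version would build tensors
  # of the declared shape.
  defp initial_value(:zeros), do: 0.0
  defp initial_value(:ones), do: 1.0

  # The layer function fetches its parameters by name.
  def affine(x, %{"weight" => w, "bias" => b}), do: x * w + b
end

# Declared in the "wrong" order, as in the example above.
params = [NamedParams.param("bias", {}, :zeros), NamedParams.param("weight", {}, :ones)]
state = NamedParams.init(params)
NamedParams.affine(3.0, state)
# 3.0
```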

Remove dimensional suffixes

After some deliberation I have decided that it's best to drop dimensional suffixes to simplify the API. They're not necessary with most tensor compilers (optimal kernels are generated/compiled based on shape anyway) and I believe it makes more sense to settle for a simpler API. So:

conv1d, conv2d, conv3d -> conv
...and so on...

Create mechanism for easy model composition

For now, we'll only consider how this should work in the model creation and execution API, but it will touch the training API as well.

Consider the models in a basic GAN:

generator =
  Axon.input({nil, 100})
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})

discriminator =
  Axon.input({nil, 1, 28, 28})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

In order to train, what you'd want to do is something like:

combined = compose(discriminator, generator)  # represents D(G(input))
step_d = Axon.Training.step(discriminator, :binary_cross_entropy, Axon.Optimizers.sgd(0.005))
step_g = Axon.Training.step(combined, :binary_cross_entropy, Axon.Optimizers.adam(0.01))

And then you can alternate using step_d and step_g to train on valid / fake images. Unfortunately, we currently don't support model composition in this sense - you can define functions generator and discriminator without an input block, but there's no way to cleanly determine which parameters belong to which model. Ideally, you'd be able to compose models in some way so that when you initialize, predict, train, etc. parameters are grouped:

combined = compose(discriminator, generator)
{d_params, g_params} = combined_params = Axon.init(combined)
Axon.predict(combined, combined_params)

{{d_params, g_params}, _} =
  combined
  |> Axon.Training.step(:binary_cross_entropy, Axon.Optimizers.adam(0.01))
  |> Axon.Training.train(inputs, targets)

Whatever the implementation is, it will involve adding some metadata to parameters that expresses their ownership by a given model. From an API perspective, one option is to introduce Axon.compose for composing Axon structs into a single model while preserving parameter information, although I'm not sure I love that right now.
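A sketch of ownership-preserving composition, assuming a model can be reduced to `{name, init_fn, predict_fn}`. `Composed` and the tuple layout are hypothetical; the point is that `init` returns parameters grouped by owning model, so a training step can update one group and leave the other fixed.

```elixir
defmodule Composed do
  # Composition keeps each model's name so parameters stay grouped by owner.
  def compose(outer, inner), do: {:compose, outer, inner}

  # Returns %{outer_name => outer_params, inner_name => inner_params}.
  def init({:compose, {o_name, o_init, _}, {i_name, i_init, _}}) do
    %{o_name => o_init.(), i_name => i_init.()}
  end

  # outer(inner(x)), each half reading only its own parameter group.
  def predict({:compose, {o_name, _, o_predict}, {i_name, _, i_predict}}, params, x) do
    o_predict.(params[o_name], i_predict.(params[i_name], x))
  end
end

generator = {:generator, fn -> %{w: 2.0} end, fn p, x -> p.w * x end}
discriminator = {:discriminator, fn -> %{w: 0.5, b: 1.0} end, fn p, x -> p.w * x + p.b end}

combined = Composed.compose(discriminator, generator)
params = Composed.init(combined)
Composed.predict(combined, params, 5.0)
# 6.0
```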

Model inference mode

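PyTorch has a similar API: model.eval() and model.train() switch a model between inference and training behavior.

In Keras, inference behavior is selected with the training argument when calling a model (training=False), while trainable=False is a separate attribute that freezes weights.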

Unfortunately, simply doing:

{init_fn, predict_fn} = Axon.compile(model)

is not really good enough. There are important differences in behaviors of certain layers in inference mode including:

  • Dropout is disabled
  • BatchNorm uses the EMA of mean/variance calculated over the course of training

We may want to adjust mixed precision in inference mode as well.

It'd be nice to have an API that "freezes" the parameters into the model, drops training-only behavior in the forward pass, and maybe performs some slight optimizations to the forward expression. One option is to introduce a :mode option to compile that specifies inference versus training mode, although I'm not sold on it yet.
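A sketch of what a mode switch changes for the two layers listed above, assuming a hypothetical `mode` argument of `:train` or `:inference` and plain lists of floats as inputs. `ModeSketch` is illustrative only.

```elixir
defmodule ModeSketch do
  # Dropout: identity in :inference mode; in :train mode, zero each element
  # with probability `rate` and scale the kept ones by 1 / (1 - rate).
  def dropout(x, _rate, :inference), do: x

  def dropout(x, rate, :train) do
    Enum.map(x, fn v -> if :rand.uniform() < rate, do: 0.0, else: v / (1 - rate) end)
  end

  # Batch norm: batch statistics in :train mode, running statistics in
  # :inference mode. (A real layer would also return updated running
  # statistics in :train mode; that is omitted here.)
  def batch_norm(x, %{mean: running_mean, var: running_var}, mode, eps \\ 1.0e-5) do
    {mean, var} =
      case mode do
        :train -> {batch_mean(x), batch_variance(x)}
        :inference -> {running_mean, running_var}
      end

    Enum.map(x, fn v -> (v - mean) / :math.sqrt(var + eps) end)
  end

  defp batch_mean(x), do: Enum.sum(x) / length(x)

  defp batch_variance(x) do
    m = batch_mean(x)
    Enum.sum(Enum.map(x, fn v -> (v - m) * (v - m) end)) / length(x)
  end
end
```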

One thing to note on "freezing" parameters or inlining them into the forward pass: this pattern should probably be discouraged or, at the very least, noted as possibly harmful:

{init_fn, predict_fn} = Axon.compile(model)

params = init_fn.()
inference_fn = &predict_fn.(params, &1)

While you now have a function with parameters inlined as constants, if you can't guarantee that the shape of your inputs is consistent (e.g. the batch size is the same), subsequent calls to inference_fn with different shapes or types will load a new, potentially very large executable onto your device, quickly leading to OOM. One way around this is to abstract the parameters away by placing them on the device somewhere and holding a reference to them in some global state.

Axon.Training.train gives an error

I'm having a small issue. When I try to run Axon.Training.train, it gives an error:

function is undefined or private. Did you mean one of:

  * dot/2
  * dot/6

Could you please help.


Add high-level layers

Missing so far:

  • conv_transpose
  • transpose
  • reshape (convenience to ignore batch dimensions)
  • pad (convenience to ignore batch dimensions)
  • concatenate
  • add
  • subtract
  • multiply

Other layers tracked in #1.

Add more examples

Willing to accept examples on different datasets and models to demonstrate different parts of the Axon API and to demonstrate Axon's viability in the ecosystem. The TensorFlow guides are a great place to look for different datasets and problems. If you're blocked on any specific issue feel free to comment on the relevant issue with your use case :)

Model inspection does not enforce correct layer ordering

This also has some implications for compiling layers like add, which reference entire subgraphs. The current inspection traverses the entire subgraph, and layers are displayed out of order for complex models. See examples/resnet.exs for an example.
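One way to get a stable display order is a topological sort: a depth-first post-order walk with a visited set lists every layer once, after all of its inputs. This sketch assumes the graph is available as a map of layer name => list of input layer names, which is a simplification of Axon's real structure.

```elixir
defmodule LayerOrder do
  # Returns layer names so that each layer appears after all of its inputs,
  # and layers shared by several branches (like the input of a residual
  # block) appear only once.
  def topological_order(graph, output) do
    {order, _visited} = visit(graph, output, {[], MapSet.new()})
    Enum.reverse(order)
  end

  defp visit(graph, name, {order, visited} = acc) do
    if MapSet.member?(visited, name) do
      acc
    else
      {order, visited} =
        graph
        |> Map.fetch!(name)
        |> Enum.reduce({order, MapSet.put(visited, name)}, fn input, acc ->
          visit(graph, input, acc)
        end)

      {[name | order], visited}
    end
  end
end

# A residual block: add(conv_2(conv_1(input)), input)
graph = %{
  "input" => [],
  "conv_1" => ["input"],
  "conv_2" => ["conv_1"],
  "add" => ["conv_2", "input"]
}

LayerOrder.topological_order(graph, "add")
# ["input", "conv_1", "conv_2", "add"]
```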

Integrate validation and testing into training API

More than likely this can be almost identical to the PyTorch Lightning approach, where we define validation_step and test_step. Validation can then optionally be run alongside training, and testing done with a separate Axon.test.

Add model import/export API

Need ability to serialize models to/from external formats. Model serialization is serialization of the actual computation graph. We should also have the ability to save and load model parameters, but I believe part of that discussion needs to happen upstream with a common Nx tensor serialization format. See e.g. elixir-nx/nx#354

Add additional optimizers and updates

Optimizers

  • lamb
  • yogi
  • noisy_sgd
  • fromage
  • adamw

Updates

  • scale_by_yogi
  • add_decayed_weights
  • scale_by_trust_ratio
  • add_noise

Also requires changing update functions from update(updates, state) to update(updates, params, state)
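A sketch of why the extra params argument is needed: `add_decayed_weights` (optax-style `updates + weight_decay * params`) cannot be written without the current parameters. Maps of name => float stand in for tensors, and the module name `UpdateSketch` is hypothetical.

```elixir
defmodule UpdateSketch do
  # Proposed signature: update(updates, params, state) -> {updates, state}.

  # Stateless: add weight_decay * params to each update.
  def add_decayed_weights(updates, params, state, weight_decay) do
    {Map.new(updates, fn {k, u} -> {k, u + weight_decay * params[k]} end), state}
  end

  # Stateless: multiply each update by a constant (params unused).
  def scale(updates, _params, state, step_size) do
    {Map.new(updates, fn {k, u} -> {k, u * step_size} end), state}
  end
end

params = %{w: 2.0}
{updates, state} = UpdateSketch.add_decayed_weights(%{w: 0.5}, params, :no_state, 0.01)
{updates, _state} = UpdateSketch.scale(updates, params, state, -0.1)
```

Chaining decay and then scaling by a negative learning rate, as above, is roughly the shape of an AdamW-style update once a scale_by_adam step is placed first.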

Allow options in activation functions

Some activation functions support options (e.g. LeakyReLU supports an alpha option). As of now there's no way to include this using the high-level API.
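A sketch of one possible syntax, accepting either an atom or `{atom, opts}` for `:activation`, e.g. `activation: {:leaky_relu, alpha: 0.2}`. The tuple form, the `ActivationSketch.resolve/1` name, and the 0.01 default alpha are assumptions; scalar floats stand in for tensors.

```elixir
defmodule ActivationSketch do
  # A bare atom means "no options".
  def resolve(name) when is_atom(name), do: resolve({name, []})

  def resolve({:relu, _opts}), do: fn x -> max(x, 0.0) end

  # leaky_relu(x) = x for x >= 0, alpha * x otherwise.
  def resolve({:leaky_relu, opts}) do
    alpha = Keyword.get(opts, :alpha, 0.01)
    fn x -> if x >= 0, do: x, else: alpha * x end
  end
end

ActivationSketch.resolve({:leaky_relu, alpha: 0.2}).(-1.0)
# -0.2
```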

Integrate named tensors

On the high-level API, we can integrate named tensors by specifying the expected names on input:

Axon.input(batch: nil, channels: 3, height: 224, width: 224)

Then we'll need to consider how these are transformed through the network. For other layers, we could consider adding an option to specify output features as keywords:

Axon.input(batch: nil, pixels: 784)
|> Axon.dense(features: 128, activation: :relu)
|> Axon.dense(label: 10, activation: :softmax)

Unify normalization layers

The current API has 4 normalization layers:

  • Batch Normalization
  • Instance Normalization
  • Group Normalization
  • Layer Normalization

All of these implementations are built on a fundamental formula:

defn normalize(input, mean, variance, gamma, bias, opts \\ []) do
  opts = keyword!(opts, epsilon: 1.0e-6)

  scale =
    variance
    |> Nx.add(opts[:epsilon])
    |> Nx.rsqrt()
    |> Nx.multiply(gamma)

  input
  |> Nx.subtract(mean)
  |> Nx.multiply(scale)
  |> Nx.add(bias)
end

But they differ in how they compute the mean and variance across the input:

  • Batch Normalization - calculated for each individual channel across all samples and spatial dimensions.
    • reduction_axes: [:batch, :height, :width, ...]
  • Instance Normalization - calculated for each individual channel for each individual sample across both spatial dimensions.
    • reduction_axes: [:height, :width, ...]
  • Layer Normalization - calculated for each individual sample across all channels and both spatial dimensions.
    • reduction_axes: [:channels, :height, :width, ...]
  • Group Normalization - calculated across groups of channels and both spatial dimensions for the given group size.
    • reduction_axes: [:groups, :height, :width, ...] (after some reshaping to get :groups)
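A dependency-free sketch of how only the reduction axes differ, on a 2-D batch given as a list of samples (rows) of features (columns): batch norm reduces over the batch axis (one `{mean, var}` per feature), layer norm over the feature axis (one per sample). Module and function names are hypothetical; `batch_normalize/2` applies the normalize formula with gamma = 1 and bias = 0.

```elixir
defmodule NormStats do
  # One {mean, var} per feature: reduce over the batch axis.
  def batch_norm_stats(batch) do
    batch
    |> Enum.zip()
    |> Enum.map(fn column -> column |> Tuple.to_list() |> mean_var() end)
  end

  # One {mean, var} per sample: reduce over the feature axis.
  def layer_norm_stats(batch), do: Enum.map(batch, &mean_var/1)

  # (x - mean) / sqrt(var + eps) with per-feature statistics.
  def batch_normalize(batch, eps \\ 1.0e-6) do
    stats = batch_norm_stats(batch)

    Enum.map(batch, fn row ->
      row
      |> Enum.zip(stats)
      |> Enum.map(fn {x, {m, v}} -> (x - m) / :math.sqrt(v + eps) end)
    end)
  end

  # Population mean and variance.
  defp mean_var(xs) do
    n = length(xs)
    mean = Enum.sum(xs) / n
    {mean, Enum.sum(Enum.map(xs, fn x -> (x - mean) * (x - mean) end)) / n}
  end
end

batch = [[1.0, 10.0], [3.0, 30.0]]
NormStats.batch_norm_stats(batch)
# [{2.0, 1.0}, {20.0, 100.0}]
NormStats.layer_norm_stats(batch)
# [{5.5, 20.25}, {16.5, 182.25}]
```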

Additionally, some of these layers are stateful (batch/instance norm) and some are stateless (layer/group norm). Stateful normalization layers return the transformed input and a running average mean and variance adjusted with momentum, relying on the state to compute the next iteration of normalization. Stateless normalization layers return just the transformed input.

In order to unify these normalization layers under the lower-level functional API, rather than have individualized functions for each layer we will instead have:

In the layers API:

  • normalize - see above

In a separate module:

  • batch_norm_stats(input, ra_mean, ra_var, opts \\ []) - returns {mean, var}
  • instance_norm_stats(input, ra_mean, ra_var, opts \\ []) - returns {mean, var}
  • group_norm_stats(input, opts \\ []) - returns {mean, var}
  • layer_norm_stats(input, opts \\ []) - returns {mean, var}

In a separate module (probably an updates.ex or something that has gradient/parameter transforms):

  • ema(x, momentum) - returns a scaled x, exponential moving average

I think this limits code duplication and still enables us to easily build these normalization layers into a high-level API.
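A sketch of the running-average update used by the stateful layers. It passes the running value explicitly, so the signature differs from the listed `ema(x, momentum)`. It also assumes the Keras-style convention where momentum weights the old value; PyTorch's BatchNorm momentum instead weights the new batch statistic.

```elixir
defmodule RunningStats do
  # running = momentum * running + (1 - momentum) * batch_stat
  def ema(running, batch_stat, momentum) do
    momentum * running + (1 - momentum) * batch_stat
  end

  # Update running mean and variance together after a training batch.
  def update({ra_mean, ra_var}, {batch_mean, batch_var}, momentum) do
    {ema(ra_mean, batch_mean, momentum), ema(ra_var, batch_var, momentum)}
  end
end

RunningStats.update({0.0, 1.0}, {2.0, 3.0}, 0.5)
# {1.0, 2.0}
```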

Introduce a training API

Given the core components implemented in #1, we can implement an efficient, simple, but flexible training API. I am proposing an API similar to under the Axon.Training namespace that represents a general supervised training pipeline for models.

Training Behaviour

We can consider the training loop to take the following inputs:

  • model_state - parameters, discussed in a future issue for state management and model initialization
  • optimizer - encapsulates both optimizer state, and the update step, discussed in a future issue
  • train_objective (note I'm not using Task to avoid confusion with Elixir tasks) - an objective (loss) function parameterized by the input model such that grad(model_state, objective) differentiates the model parameters w.r.t input model
  • eval_objective - metrics for evaluating model performance on validation sets, loss, accuracy, mse, mae, etc. and some associated state for monitoring training progress
  • dataset - inputs and labels
  • options - miscellaneous

and to perform the following algorithm (this is half pseudocode, half Elixir):

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]

  for i <- 1..epochs do
    for {input, target} <- dataset do
      gradients = grad(model_state, train_objective)
      update(model_state, gradients, optimizer)
    end

    evaluate(model_state, eval_objective)
  end
end

It's common to use metrics as an easy way to monitor training, so we can introduce a metrics object which encapsulates metric state and metric evaluation functions:

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]

  for i <- 1..epochs do
    for {input, target} <- dataset do
      gradients = grad(model_state, train_objective)
      update(model_state, gradients, optimizer)
      metrics(model_state, train_objective)
    end

    evaluate(model_state, eval_objective)
  end
end

We can further extend this API with before_x and after_x callbacks (writing checkpoints, plotting graphs, etc.):

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]

  for i <- 1..epochs do
    before_epoch(model_state)

    for {input, target} <- dataset do
      before_batch(model_state)
      gradients = grad(model_state, train_objective)
      update(model_state, gradients, optimizer)
      metrics(model_state, train_objective)
      after_batch(model_state)
    end

    evaluate(model_state, eval_objective)
    after_epoch(model_state)
  end
end

For more flexibility, we can extract each train step into a function; this makes writing custom training loops easier:

def train_on_batch(batch, model_state, train_objective, optimizer) do
  gradients = grad(model_state, train_objective(model_state, batch))
  update(model_state, gradients, optimizer)
  metrics(train_objective(model_state, batch))
end


def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]
  # nil means "run until the dataset is exhausted"
  steps = options[:steps]
  batches = if steps, do: Enum.take(dataset, steps), else: dataset

  for i <- 1..epochs do
    for batch <- batches do
      train_on_batch(batch, model_state, train_objective, optimizer)
    end

    evaluate(model_state, eval_objective)
  end
end
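Because Elixir data is immutable, a runnable version of these loops has to thread the state through each iteration rather than call `update` for its side effect. A toy sketch under stated assumptions: the "model" is a single weight `w` for `y = w * x`, the objective is mean squared error, the optimizer is plain gradient descent, and `TrainLoop` plus the `:after_epoch` option are hypothetical.

```elixir
defmodule TrainLoop do
  # One step: d/dw mean((w * x - y)^2) = mean(2 * (w * x - y) * x).
  def train_on_batch(batch, w, learning_rate) do
    grad = Enum.sum(Enum.map(batch, fn {x, y} -> 2 * (w * x - y) * x end)) / length(batch)
    w - learning_rate * grad
  end

  # opts: :epochs, :learning_rate, and an optional :after_epoch callback that
  # receives (epoch, w) and returns the (possibly modified) w.
  def train(w, dataset, opts) do
    epochs = Keyword.fetch!(opts, :epochs)
    learning_rate = Keyword.fetch!(opts, :learning_rate)
    after_epoch = Keyword.get(opts, :after_epoch, fn _epoch, w -> w end)

    Enum.reduce(1..epochs, w, fn epoch, w ->
      w = Enum.reduce(dataset, w, fn batch, w -> train_on_batch(batch, w, learning_rate) end)
      after_epoch.(epoch, w)
    end)
  end
end

# Two batches drawn from y = 2x.
dataset = [[{1.0, 2.0}, {2.0, 4.0}], [{3.0, 6.0}]]
w = TrainLoop.train(0.0, dataset, epochs: 50, learning_rate: 0.05)
# w converges to approximately 2.0
```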


Given this framework, the training API would have at a minimum the following callbacks:

defmodule Axon.Training do
  # Placeholder types; their concrete shape is still to be decided.
  @type model_state :: term()
  @type batch :: term()
  @type optimizer :: term()
  @type train_objective :: term()
  @type eval_objective :: term()
  @type dataset :: Enumerable.t()
  @type options :: keyword()
  @type reason :: term()

  # Runs before each epoch
  @callback before_epoch(model_state) :: {:ok, model_state} | {:error, reason}

  # Runs after each epoch
  @callback after_epoch(model_state) :: {:ok, model_state} | {:error, reason}

  # Runs before each batch
  @callback before_batch(model_state) :: {:ok, model_state} | {:error, reason}

  # Runs after each batch
  @callback after_batch(model_state) :: {:ok, model_state} | {:error, reason}

  # Runs a single train step, this can also be `defn` for working with infeed/outfeed
  @callback train_on_batch(batch, model_state, train_objective, optimizer) :: model_state

  # Runs a training loop to convergence
  @callback train(model_state, optimizer, train_objective, eval_objective, dataset, options) :: {:ok, model_state} | {:error, reason}
end

I left a lot of key pieces out because I believe this motivates discussion about how to best separate concerns between modules to provide both maximum flexibility, as well as ease-of-use.


The spirit of autograd is that writing a machine learning model is as simple as defining a differentiable objective function. I believe that is a principle we should stick to, so I've separated the idea of objective into what could be a separate module, behaviour, function, etc. Objectives need to encapsulate both evaluation objectives and training objectives. They need to be capable of supporting parameterization by a model. I think they should also contain information about associated metrics and evaluation criteria that's tracked during training. Objectives could possibly be defined as a behavior with two methods: predict and loss where loss depends on predict and predict represents a model definition. I'm not sure I really like that idea, but objectives definitely deserve a well-thought out discussion in a separate issue.

Optimizers and Updates

From a design standpoint, updates and optimizers should be included separately. However, from a performance standpoint, I think you might want to fuse gradient calculation with updates, but I believe this could be possible by silently wrapping both update and the grad(objective) in another defn somewhere because defn calls are inlined and compiled. Optimizers as separate modules is a pretty common pattern, so I would go for a behaviour here with common implementations built on the primitive updates.ex.


There is a lot of state to keep track of in the above example: model state, optimizer state, metric state, evaluation state, etc. I think it makes sense to wrap state in a common API, so stateful parameters can be handled flexibly. Another advantage of implementing this is that we can limit assumptions about actual state management solutions in practice, so users can implement their own if they choose.


The above just lists a dataset as containing batches. I would basically try to represent this as a stream that can be consumed. I don't think dataset implementations fall in this library, but I think Axon should enforce some standard for what datasets look like.


I believe this lays out a plan for integrating higher-level APIs moving forward. Obviously this is incredibly general because it inherently depends on implementation details of the aspects left out above. However, I believe starting with a training API is a sensible way to figure out how to split up the rest of the work.

Ability to disable/limit training messages

While training messages are helpful, in some cases it would be nice to be able to either limit the verbosity or disable the output completely.

Use cases:

  • When running in a Docker container using an external IDE like VSCode, the stream of per-batch messages bogs down training run output, in turn bogging down the runtime of the app (e.g. MNIST takes ~2s/epoch running directly within a Docker container, ~20-30s/epoch running as a remote container against GPUs). In this case per-epoch data would be great, per-batch not so much
  • When running something like an RL model where you're more interested in logging the completion status, state and rewards of an agent per traversal than the incremental loss/accuracy/validation, it's often desirable to disable logging completely.

Maybe being able to set a reporting level like :per_batch, :per_epoch, or :none?
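A sketch of the suggested reporting levels as a pure predicate a training loop could consult before printing. The `:log` option values come from the suggestion above; `TrainLog` and the `:batch`/`:epoch` event names are hypothetical.

```elixir
defmodule TrainLog do
  # Should an event of this kind (:batch or :epoch) be printed at this level?
  def log?(:per_batch, _event), do: true
  def log?(:per_epoch, :epoch), do: true
  def log?(:per_epoch, :batch), do: false
  def log?(:none, _event), do: false

  def maybe_log(level, event, message) do
    if log?(level, event), do: IO.puts(message)
    :ok
  end
end

TrainLog.maybe_log(:per_epoch, :batch, "batch 10: loss=0.31")
# prints nothing
TrainLog.maybe_log(:per_epoch, :epoch, "epoch 1: loss=0.29")
# prints the epoch line
```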

Implement dynamic unrolling of RNNs

We currently unroll the RNN at compile-time rather than compiling RNNs using a loop. Statically unrolling can be more efficient for short sequences at the expense of more memory consumption; however, we will need the ability to dynamically unroll. Requires elixir-nx/nx#122
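To make the distinction concrete, a plain-Elixir sketch of a dynamically unrolled RNN: one cell definition applied in a loop over however many steps the input has, instead of one copy of the cell per time step. The scalar Elman cell `h_t = tanh(w * x_t + u * h_{t-1} + b)` and `SimpleRNN` are illustrative assumptions; in `defn` this would need a compiled loop construct.

```elixir
defmodule SimpleRNN do
  # Returns {outputs, final_hidden_state}; works for sequences of any length.
  def run(sequence, %{w: w, u: u, b: b}, h0 \\ 0.0) do
    Enum.map_reduce(sequence, h0, fn x, h ->
      h = :math.tanh(w * x + u * h + b)
      {h, h}
    end)
  end
end

params = %{w: 1.0, u: 0.5, b: 0.0}
{outputs, h} = SimpleRNN.run([0.0, 1.0, 0.5], params)
{outputs5, _h5} = SimpleRNN.run(List.duplicate(0.1, 5), params)
```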

Support logging in training API

Somewhat related to #21

PyTorch Lightning supports building custom loggers for integration with third-party logging tools (like TensorBoard). We should include a similar API so training can be monitored in tools like TensorBoard.
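A sketch of a pluggable logger interface as an Elixir behaviour. `TrainingLogger`, its single callback, and `ConsoleLogger` are assumptions for illustration; a TensorBoard-backed implementation would write event files instead of accumulating strings.

```elixir
defmodule TrainingLogger do
  # Called by the training loop with the current metrics and step; returns new logger state.
  @callback log_metrics(metrics :: map(), step :: non_neg_integer(), state :: term()) :: term()
end

defmodule ConsoleLogger do
  @behaviour TrainingLogger

  # Collects one formatted line per call, with metrics sorted by name.
  @impl true
  def log_metrics(metrics, step, lines) do
    formatted =
      metrics
      |> Enum.sort()
      |> Enum.map_join(" ", fn {name, value} -> "#{name}=#{value}" end)

    lines ++ ["step #{step}: #{formatted}"]
  end
end

ConsoleLogger.log_metrics(%{loss: 0.5, accuracy: 0.9}, 1, [])
# ["step 1: accuracy=0.9 loss=0.5"]
```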

Issues running mnist examples.

Trying to evaluate examples/mnist.exs or notebooks/mnist.livemd fails on last (training) step with error.

Environment was just installed:

MacBook Pro (16-inch, 2019) Intel based.

Erlang/OTP 24 [erts-12.0.2] [source] [64-bit] [smp:16:16] [ds:16:16:10] [async-threads:1] [jit] [dtrace]
Elixir 1.12.1 (compiled with Erlang/OTP 24)
Livebook 0.1.2

examples/mnist.exs output:

 Layer                    Shape        Parameters
 input_7 ( input )        {nil, 784}   0
 dense_10 ( dense )       {nil, 128}   100480
 relu_11 ( relu )         {nil, 128}   0
 dropout_12 ( dropout )   {nil, 128}   0
 dense_15 ( dense )       {nil, 10}    1290
 softmax_16 ( softmax )   {nil, 10}    0

15:04:33.745 [info]  XLA service 0x7f855952ee20 initialized for platform Host (this does not guarantee that XLA will be used). Devices:

15:04:33.745 [info]    StreamExecutor device (0): Host, Default Version
** (ArgumentError) expected a %Nx.Tensor{} or a number, got: {32, 128}
    (nx 0.1.0-dev) lib/nx.ex:1170: Nx.to_tensor/1
    (nx 0.1.0-dev) lib/nx.ex:2561: Nx.element_wise_pred_op/3
    (axon 0.1.0-dev) lib/axon/layers.ex:1098: anonymous fn/1 in Axon.Layers."__defn:dropout__"/2
    (axon 0.1.0-dev) lib/axon/layers.ex:1095: Axon.Layers."__defn:dropout__"/2
    (axon 0.1.0-dev) lib/axon/compiler.ex:218: anonymous fn/7 in Axon.Compiler.recur_predict_fun/3
    (axon 0.1.0-dev) lib/axon/compiler.ex:196: anonymous fn/4 in Axon.Compiler.recur_predict_fun/3
    (axon 0.1.0-dev) lib/axon/training.ex:133: anonymous fn/6 in Axon.Training.step/4
    (nx 0.1.0-dev) lib/nx/defn/grad.ex:15: Nx.Defn.Grad.transform/3

notebooks/mnist.livemd output:

** (ArgumentError) expected a %Nx.Tensor{} or a number, got: {32, 10}
    (nx 0.1.0-dev) lib/nx.ex:1170: Nx.to_tensor/1
    (nx 0.1.0-dev) lib/nx.ex:2561: Nx.element_wise_pred_op/3
    (axon 0.1.0-dev) lib/axon/shared.ex:21: anonymous fn/1 in Axon.Shared."__defn:assert_shape!__"/2
    (nx 0.1.0-dev) lib/nx/defn/compiler.ex:307: Nx.Defn.Compiler.__remote__/4
    (axon 0.1.0-dev) lib/axon/losses.ex:164: Axon.Losses."__defn:categorical_cross_entropy__"/3
    (axon 0.1.0-dev) lib/axon/training.ex:134: anonymous fn/6 in Axon.Training.step/4
    (nx 0.1.0-dev) lib/nx/defn/grad.ex:15: Nx.Defn.Grad.transform/3
    (axon 0.1.0-dev) lib/axon/training.ex:80: anonymous fn/7 in Axon.Training.step/3

