marcvanzee commented on April 27, 2024

I think both your suggestions are good insights. I guess the stochastic initialize_carry is a bit less crucial (it is at most a bit more cumbersome in the current way), but the initialization problem seems like something users can run into more often. I will ask around how people usually handle this and see if it can be improved, or if we can have clearer guidelines on how best to initialize such variables. At the very least we should have an example. I'll let you know when I have an answer.

marcvanzee commented on April 27, 2024

I'm closing this issue since most important questions seem to have been answered. David: if you have any other specific concerns / questions, can you please open a new issue? Thanks!

jheek commented on April 27, 2024

Indeed there is a mistake in the example and the test is correct.
We also have a full LSTM example in the pipeline which should be merged soon.

As for your use case, which if I understand correctly is a batch size of 1:
I think you could try initializing the carry with initialize_carry(rng, (1,), 5), and it should probably work with initialize_carry(rng, (), 5) as well.
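To make the two suggested shapes concrete, a minimal sketch assuming the carry is a (cell state, hidden state) pair whose arrays have shape batch_dims + (size,), which is how I read flax.nn's LSTMCell.initialize_carry. zero_lstm_carry is a hypothetical helper, not part of Flax; the only difference between (1,) and () is whether a leading batch axis exists.

```python
import jax.numpy as jnp


def zero_lstm_carry(batch_dims, size):
    """Hypothetical helper: all-zeros LSTM carry (c, h), each of shape batch_dims + (size,)."""
    shape = tuple(batch_dims) + (size,)
    return jnp.zeros(shape), jnp.zeros(shape)


c, h = zero_lstm_carry((1,), 5)   # explicit batch of one: both arrays have shape (1, 5)
c0, h0 = zero_lstm_carry((), 5)   # no batch axis: both arrays have shape (5,)
```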

david-waterworth commented on April 27, 2024

Thanks, you're correct: both of those work for initialize_carry.

I'm having issues with scan, though. If I use the pure-Python definition of scan from the JAX docs, the code below works, but with the JAX implementation (jax.lax.scan) I get the exception quoted below.

I'm using jax 0.1.59 and jaxlib 0.1.39 with tensorflow 2.1.0

import jax
import flax
import numpy as onp
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)

x = jnp.ones((365, 96, 2))
c0 = flax.nn.recurrent.LSTMCell.initialize_carry(rng, (), 2)
print(f"c0: {c0}")

(c, y), lstm = flax.nn.recurrent.LSTMCell.create(rng, c0, x[-1][0])
print(f"c: {c}")

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

timeseries = x[0,:,:]
scan(lstm, c0, timeseries)  # pure-Python reference scan: works

scan = jax.lax.scan
scan(lstm, c0, timeseries)  # jax.lax.scan: raises the TypeError below

Traceback (most recent call last):
  File "/home/david/dev/flax/examples/rnn/train.py", line 29, in <module>
    scan(lstm, c0, timeseries)
  File "/home/david/.pyenv/versions/3.7.1/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 791, in scan
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
  File "<string>", line 2, in __hash__
TypeError: unhashable type: 'dict'

jheek commented on April 27, 2024

A Model instance currently fails when passed into a function like jit or scan. I have a fix for this but for now you could use scan(lambda c, x: lstm(c, x), c0, timeseries).

The reason this fails is that JAX tries to cache the functions it transforms, but Model holds parameter arrays, which are not hashable. I will implement __hash__ in a separate change and add a test for this.
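A framework-free sketch of the workaround, using only jax.lax.scan: TinyCell is a hypothetical stand-in for a Flax Model (a callable holding a dict of parameter arrays) and is deliberately made unhashable. Wrapping it in a lambda hands scan an ordinary function, with the cell captured by closure, so scan never has to hash the cell itself.

```python
import jax
import jax.numpy as jnp


class TinyCell:
    """Hypothetical stand-in for a Flax Model: a callable holding a dict of arrays."""

    __hash__ = None  # make instances unhashable, like the Model in the traceback

    def __init__(self, params):
        self.params = params

    def __call__(self, carry, x):
        # Linear recurrence: new carry = w * carry + x; emit the new carry as output.
        new_carry = self.params["w"] * carry + x
        return new_carry, new_carry


cell = TinyCell({"w": jnp.float32(0.5)})
xs = jnp.arange(1.0, 5.0, dtype=jnp.float32)  # [1, 2, 3, 4]

try:
    hash(cell)
    cell_is_hashable = True
except TypeError:
    cell_is_hashable = False  # the cell itself cannot be hashed

# Passing `cell` directly may fail in JAX versions that hash the scanned function;
# the lambda is a plain function, so scan only ever sees that.
final_carry, ys = jax.lax.scan(lambda c, x: cell(c, x), jnp.float32(0.0), xs)
```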

david-waterworth commented on April 27, 2024

Thanks, yes, that works.

P.S. I'm also seeing this XLA warning:

2020-02-19 08:51:51.429909: W external/org_tensorflow/tensorflow/compiler/xla/service/hlo_pass_fix.h:49] Unexpectedly high number of iterations in HLO passes, exiting fixed point loop.

It's raised from LSTMCell.apply() as far as I can tell.

It could be due to my (unexpected) upgrade to TensorFlow 2.1. TensorFlow has merged the tensorflow and tensorflow-gpu packages into a single package, so anything that depends on tensorflow now installs 2.1 with GPU support enabled. That's not a bad thing long term, but it causes a few short-term issues.

david-waterworth commented on April 27, 2024

P.S. Would you recommend using scan inside a Module instead?

avital commented on April 27, 2024

Hi @david-waterworth, can you please take a look at our seq2seq example and let us know if it helps clarify how to use LSTMCell? https://github.com/google-research/flax/blob/prerelease/examples/seq2seq/train.py

david-waterworth commented on April 27, 2024

@avital yes thanks, the use of nn.attention.scan_in_dim answers my question (although I wouldn't have expected to find it in the attention module - maybe flax.jax_utils?)
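For anyone else looking for a scan over a non-leading axis, a minimal sketch of the idea assuming a batch-major (batch, time, features) layout. scan_over_axis is a hypothetical helper, not the signature of nn.attention.scan_in_dim: it moves the time axis to the front, runs jax.lax.scan, and moves the axis back.

```python
import jax
import jax.numpy as jnp


def scan_over_axis(fn, init, xs, axis=1):
    """Hypothetical helper: run jax.lax.scan along `axis` instead of axis 0."""
    xs_t = jnp.moveaxis(xs, axis, 0)            # bring the scanned axis to the front
    carry, ys_t = jax.lax.scan(fn, init, xs_t)
    return carry, jnp.moveaxis(ys_t, 0, axis)   # restore the original layout


def running_sum(carry, x):
    # carry and x (one time step) both have shape (batch, features).
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.ones((4, 10, 3))  # (batch, time, features)
carry, ys = scan_over_axis(running_sum, jnp.zeros((4, 3)), xs, axis=1)
```

Scanning inside a Module would follow the same shape logic; only the step function changes.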

david-waterworth commented on April 27, 2024

@avital I had a few minor questions after adapting the seq2seq example for my purposes.

  1. Why are you calling train_metrics = jax.device_get(metrics)? The annotated MNIST example seems to imply this is automatic ("metrics are only retrieved from device when needed on host (like in this print statement)").

  2. The example doesn't evaluate a test loss during training. Again from the MNIST example, I discovered optimizer.target and was able to use it to do what I needed. Shouldn't the examples include this, as it's fairly standard practice?

  3. For completeness, shouldn't train_model return the trained model? You're not using it in the example, but it would still be useful, particularly as immutability means you have to return optimizer.target rather than model, which might trip someone up.

[Part of my confusion was that I deleted decode_batch(optimizer.target, 5) from train_model() when adapting it to my regression/time-series problem; that call effectively serves as a test-set evaluation.]
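The three points above can be sketched together in plain JAX, with a toy noise-free linear-regression problem standing in for seq2seq (all names, data and the learning rate are hypothetical, and this is not the seq2seq example's code): a jitted eval step reports test loss next to train loss, jax.device_get copies each metrics dict to host as NumPy values, and train_model returns the final parameters instead of discarding them.

```python
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.1  # hypothetical; plain SGD


def mse_loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, {"train_loss": loss}


@jax.jit
def eval_step(params, x, y):
    # Same loss on held-out data; no parameter update.
    return {"test_loss": mse_loss(params, x, y)}


def train_model(num_steps=200):
    rng = np.random.default_rng(0)
    true_w = np.array([2.0, -1.0], dtype=np.float32)
    x_train = rng.normal(size=(64, 2)).astype(np.float32)
    y_train = x_train @ true_w + 0.5
    x_test = rng.normal(size=(32, 2)).astype(np.float32)
    y_test = x_test @ true_w + 0.5

    params = {"w": jnp.zeros(2), "b": jnp.zeros(())}
    history = []
    for step in range(num_steps):
        params, train_metrics = train_step(params, x_train, y_train)
        if step % 50 == 0 or step == num_steps - 1:
            metrics = {**train_metrics, **eval_step(params, x_test, y_test)}
            history.append(jax.device_get(metrics))  # copy to host as NumPy values
    return params, history  # return the trained parameters, not just metrics


params, history = train_model()
```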

david-waterworth commented on April 27, 2024

One more thing that really tripped me up.

In your example, Encoder.apply() has a default argument hidden_size that you never set explicitly. I wanted to use a different value, so I made the following changes:

class Seq2seq(nn.Module):
  """Sequence-to-sequence class using encoder/decoder architecture."""

  def apply(self,
            encoder_inputs,
            decoder_inputs,
            hidden_size=128,  # <- added here (used different default to demonstrate)
            train=True,
            max_output_len=None):
    """Run the seq2seq model."""
    # inputs.shape = (batch_size, seq_length, vocab_size).
    batch_size, _, vocab_size = encoder_inputs.shape
    carry = Encoder(encoder_inputs, name='encoder', hidden_size=hidden_size)

Then I modified create_model(), i.e.:

def create_model():
  """Creates a seq2seq model."""
  vocab_size = CTABLE.vocab_size()
  _, model = Seq2seq.create_by_shape(
      nn.make_rng(), [((1, get_max_input_len(), vocab_size), jnp.float32),
                      ((1, get_max_output_len(), vocab_size), jnp.float32)], 
      hidden_size=64)   # <- modified here (again different to defaults)
  return model

The issue is that Seq2seq.create_by_shape() passes hidden_size=64, but the returned model doesn't appear to have this parameter partially applied, so every later call uses the default.

model = create_model()
model(batch['query'], batch['answer'])   # <- uses incorrect encoder hidden size?

What I had to do instead was create a partial Seq2seq:

def create_model():
  """Creates a seq2seq model."""
  vocab_size = CTABLE.vocab_size()
  model_def = Seq2seq.partial(hidden_size=64)
  _, model = model_def.create_by_shape(
      nn.make_rng(), [((1, get_max_input_len(), vocab_size), jnp.float32),
                      ((1, get_max_output_len(), vocab_size), jnp.float32)])
  return model

Should create_by_shape partially apply *args and **kwargs?
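A framework-free analog of the behavior described above, assuming (as the observation suggests) that create_by_shape forwards its kwargs to the initialization call only. seq2seq_apply and create_by_shape here are hypothetical stand-ins, not Flax's API.

```python
import functools


def seq2seq_apply(inputs, hidden_size=128):
    """Hypothetical stand-in for Seq2seq.apply; it just reports the hidden size it saw."""
    return {"num_inputs": len(inputs), "hidden_size": hidden_size}


def create_by_shape(apply_fn, inputs, **kwargs):
    """Hypothetical stand-in: kwargs are used for the init call only.

    The returned "model" is the bare apply_fn, so later calls fall back to its defaults.
    """
    init_output = apply_fn(inputs, **kwargs)
    return init_output, apply_fn


inputs = [0.0, 1.0]

# kwargs passed at creation time are not remembered by the returned model.
naive_init, naive_model = create_by_shape(seq2seq_apply, inputs, hidden_size=64)
print(naive_init["hidden_size"], naive_model(inputs)["hidden_size"])  # 64 128

# Binding the kwarg first with functools.partial makes it stick for every call.
partial_init, partial_model = create_by_shape(
    functools.partial(seq2seq_apply, hidden_size=64), inputs)
print(partial_init["hidden_size"], partial_model(inputs)["hidden_size"])  # 64 64
```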

david-waterworth commented on April 27, 2024

Finally, it seems a little awkward that initialize_carry is stochastic? By default it uses initializers.zeros, but because it calls random.split(), you cannot pass None as the rng. That means that at inference time you seem to have to do:

model = train_model()

rng = jax.random.PRNGKey(0)
with flax.nn.stochastic(rng):
    y_hat = model(x)

I simply used jnp.zeros to create the initial state instead.

marcvanzee commented on April 27, 2024

Hi David, thanks a lot for your comments! I'm also still relatively new to Flax, so having someone take a critical look at my code is definitely good for my understanding as well.

@avital yes thanks, the use of nn.attention.scan_in_dim answers my question (although I wouldn't have expected to find it in the attention module - maybe flax.jax_utils?)

I agree, we should move it to flax.jax_utils (not very high priority though)

Why are you calling train_metrics = jax.device_get(metrics)?

Hmm, good point! I basically do this because the MNIST example does it (line 146), but I tested it and it also works without that. I will update the example.

The example doesn't evaluate a test loss during training. Again from the MNIST example, I discovered optimizer.target and was able to use it to do what I needed. Shouldn't the examples include this, as it's fairly standard practice?

Indeed, it doesn't report test loss. However, the accuracy it reports is the same as the accuracy during inference (without teacher forcing), as I explain in the comment on line 214: "Computes sequence accuracy, which is the same as the accuracy during inference, since teacher forcing is irrelevant when all output are correct."

I think adding test loss is nice but not crucial, since we report accuracy and some example decodings. Let me know if you feel otherwise, or whether you think my reasoning doesn't make sense.

For completeness, shouldn't train_model return the trained model? You're not using it in the example, but it would still be useful, particularly as immutability means you have to return optimizer.target rather than model, which might trip someone up.

Yes that makes sense to me, I'll update it.

Should create_by_shape partially apply *args and **kwargs?

I think your solution of first using partial and then create_by_shape is fine; you could even chain them:

_, model = Seq2seq.partial(...).create_by_shape(...)

This is how I would do it as well, but if you find a better way, please let me know :-).

Finally it seems a little awkward that initialize_carry is stochastic?

I don't think I fully understand what you mean here. Could you clarify?

david-waterworth commented on April 27, 2024

I think your solution of first using partial and then create_by_shape is fine,

I've also started avoiding default values for apply arguments like hidden_size. That way, if I pass them to create_by_shape instead of partial, I get an error later at train time. Otherwise the model is silently created with one value but applied with the default.
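A minimal sketch of that defensive pattern in plain Python (apply here is a hypothetical stand-in, not Flax's Module.apply): making hidden_size a required keyword-only argument turns a forgotten binding into a TypeError instead of a silent fallback to a default.

```python
import functools


def apply(inputs, *, hidden_size):
    """Hypothetical stand-in for Module.apply with no default for hidden_size."""
    return {"num_inputs": len(inputs), "hidden_size": hidden_size}


error_message = None
try:
    apply([1.0, 2.0])  # forgot to bind hidden_size: fails loudly
except TypeError as err:
    error_message = str(err)

bound = functools.partial(apply, hidden_size=64)
print(bound([1.0, 2.0]))  # {'num_inputs': 2, 'hidden_size': 64}
```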

I don't think I fully understand what you mean here. Could you clarify?

What I mean by stochastic is that if you do the following:

model = train_model()
model(x)

it throws an exception. You have to create a random context, i.e.:

model = train_model()
rng = jrandom.PRNGKey(0)

with nn.stochastic(rng):
  model(x)

This is because of:

carry = nn.LSTMCell.initialize_carry(nn.make_rng(), (batch_size,),

Although the rng isn't actually used, all the initialisers require one, even the constant ones (e.g. initializers.zeros). It's not a big deal, but I'm not sure there's a need to be able to randomly initialise the state, which is why I simply replaced this line with jnp.zeros.

david-waterworth commented on April 27, 2024

P.S. If you make hidden_size a flag (which, if I understand absl correctly, turns it into a command-line parameter) and then modify the code so it's correctly passed to the Encoder, you can demonstrate the need to use partial.

jheek commented on April 27, 2024

I think it would be better to use a fixed seed to initialize the carry, given that it is just zeros anyway. It seems odd to require nn.stochastic for something that actually isn't random. I created PR #85 for this.
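A minimal sketch of the fixed-seed idea, assuming jax.nn.initializers for the init function. This initialize_carry is a hypothetical variant written for illustration, not the code in PR #85: when no rng is given it falls back to jax.random.PRNGKey(0), which is harmless because the zeros initializer ignores its key.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers


def initialize_carry(batch_dims, size, init_fn=initializers.zeros, rng=None):
    """Hypothetical initialize_carry variant: uses a fixed seed when rng is None."""
    if rng is None:
        rng = jax.random.PRNGKey(0)  # fixed seed; the zeros initializer ignores it anyway
    key_c, key_h = jax.random.split(rng)
    shape = tuple(batch_dims) + (size,)
    return init_fn(key_c, shape), init_fn(key_h, shape)


c, h = initialize_carry((2,), 4)  # no stochastic context or rng needed
```

With a non-constant init_fn the fixed seed would make the carry deterministic but not zero, which is why an explicit rng is still accepted.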

marcosrdac commented on April 27, 2024

A Model instance currently fails when passed into a function like jit or scan. I have a fix for this but for now you could use scan(lambda c, x: lstm(c, x), c0, timeseries).

I plan to use Flax in my research on RNNs, but I've been struggling for days to understand some of the ideas behind Flax's implementation. I'd really like to see an example of someone using the documented RNN cells; it would help me a lot!

Anyway, I tried the lambda trick, but it didn't work: I get the error "ValueError: Jax transforms and modules cannot be mixed." when the cell is being defined. The definitions are below (I don't know if this is how it's meant to be done, but I wrote it to run on a single sample, i.e. x is a vector):

from typing import Any

import jax
import jax.numpy as jnp
import flax.linen as nn

class LRNNCell(nn.Module):
    @nn.compact
    def __call__(self, h, x):
        nh = h.shape[0]
        Whx = nn.Dense(nh)
        Whh = nn.Dense(nh, use_bias=False)
        Wyh = nn.Dense(1)

        h = nn.tanh(Whx(x) + Whh(h))
        y = nn.tanh(Wyh(h))
        return h, y

class LRNN(nn.Module):
    ny: Any
    nh: Any

    @nn.compact
    def __call__(self, x):
        h = jnp.zeros(self.nh)
        cell = LRNNCell()

        h, y = jax.lax.scan(lambda h, x: cell(h, x), h, x)
        return y[-self.ny:]

What am I missing?
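One way to sidestep mixing a raw JAX transform with a linen Module is to write the same recurrence as a pure function of an explicit parameter dict, so that jax.lax.scan only ever sees arrays. The sketch below does that; init_params, cell and lrnn are hypothetical names, the N(0, 0.1) weight scale is an arbitrary choice, and this swaps the Module for plain functions rather than fixing the Module itself (within linen, the lifted flax.linen.scan transform is the tool designed for scanning over a module).

```python
import jax
import jax.numpy as jnp


def init_params(key, nx, nh):
    """Hypothetical initializer for a single-sample Elman RNN (weights ~ 0.1 * N(0, 1))."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "Whx": 0.1 * jax.random.normal(k1, (nx, nh)),
        "bh": jnp.zeros(nh),
        "Whh": 0.1 * jax.random.normal(k2, (nh, nh)),  # no bias, as in LRNNCell
        "Wyh": 0.1 * jax.random.normal(k3, (nh, 1)),
        "by": jnp.zeros(1),
    }


def cell(params, h, x):
    # Same recurrence as LRNNCell: h' = tanh(Whx x + b + Whh h), y = tanh(Wyh h' + b).
    h = jnp.tanh(x @ params["Whx"] + params["bh"] + h @ params["Whh"])
    y = jnp.tanh(h @ params["Wyh"] + params["by"])
    return h, y


def lrnn(params, xs, nh, ny):
    h0 = jnp.zeros(nh)
    # The lambda closes over plain arrays only, so no Module is involved in the trace.
    _, ys = jax.lax.scan(lambda h, x: cell(params, h, x), h0, xs)
    return ys[-ny:]  # keep the last ny outputs, as in LRNN


params = init_params(jax.random.PRNGKey(0), nx=3, nh=8)
xs = jnp.ones((20, 3))  # 20 time steps, 3 features each
ys = jax.jit(lrnn, static_argnums=(2, 3))(params, xs, 8, 5)
```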

marcvanzee commented on April 27, 2024

Hi @marcosrdac! I've copied your question to Github Discussion #1283 so it is easier accessible for other users. Let's continue the discussion there!
