Comments (18)
I think both your suggestions are good insights. I guess the stochastic initialize_carry is a bit less crucial (it is at most a bit more cumbersome in the current way), but the initialization problem seems like something users can run into more often. I will ask around how people usually handle this, and see if this can be improved, or if we can have clearer guidelines around how to best initialize such variables. I guess at least we should have some example. I'll let you know when I have an answer.
from flax.
I'm closing this issue since the most important questions seem to have been answered. David: if you have any other specific concerns / questions, can you please open a new issue? Thanks!
from flax.
Indeed there is a mistake in the example and the test is correct.
We also have a full LSTM example in the pipeline which should be merged soon.
As for your use case, which if I understand correctly is batch size 1:
I think you could try to initialize the carry with initialize_carry(rng, (1,), 5), and it should probably work when using initialize_carry(rng, (), 5) as well.
from flax.
Thanks, you're correct: both of those methods work for initialize_carry.
I'm having issues with scan though. If I use the pure python definition of scan from the jax docs the code below works. But if I use the jax implementation I get an exception which I've quoted below.
I'm using jax 0.1.59 and jaxlib 0.1.39 with tensorflow 2.1.0
import jax
import flax
import numpy as onp
import jax.numpy as jnp
rng = jax.random.PRNGKey(0)
x = jnp.ones((365, 96, 2))
c0 = flax.nn.recurrent.LSTMCell.initialize_carry(rng, (), 2)
print(f"c0: {c0}")
(c, y), lstm = flax.nn.recurrent.LSTMCell.create(rng, c0, x[-1][0])
print(f"c: {c}")
def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, jnp.stack(ys)
timeseries = x[0,:,:]
scan(lstm, c0, timeseries)
scan = jax.lax.scan
scan(lstm, c0, timeseries)
Traceback (most recent call last):
  File "/home/david/dev/flax/examples/rnn/train.py", line 29, in <module>
    scan(lstm, c0, timeseries)
  File "/home/david/.pyenv/versions/3.7.1/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 791, in scan
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
  File "<string>", line 2, in __hash__
TypeError: unhashable type: 'dict'
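For readers comparing the two scan variants above, here is a minimal sketch of the contract they share, assuming a current JAX install. The `step` function and the array shapes are hypothetical stand-ins chosen for illustration, not the LSTMCell from this thread; with a plain function instead of a Model instance, both variants run and agree.

```python
import jax
import jax.numpy as jnp


def python_scan(f, init, xs):
    """Reference loop with the same contract as jax.lax.scan."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)


def step(carry, x):
    """Hypothetical recurrent step: returns (new_carry, per-step output)."""
    new_carry = jnp.tanh(carry + x)
    return new_carry, new_carry


xs = jnp.ones((96, 2))   # 96 time steps, 2 features per step
init = jnp.zeros((2,))   # initial carry

ref_carry, ref_ys = python_scan(step, init, xs)
lax_carry, lax_ys = jax.lax.scan(step, init, xs)
```

Both calls return a final carry of shape (2,) and stacked outputs of shape (96, 2); the difference is that jax.lax.scan traces `step` once instead of unrolling the Python loop.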
from flax.
A Model instance currently fails when passed into a function like jit or scan. I have a fix for this but for now you could use scan(lambda c, x: lstm(c, x), c0, timeseries).
The reason this fails is that JAX tries to cache the functions it transforms, but Model has parameter arrays, which are not hashable. I will implement __hash__ in a separate change and add a test for this.
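The failure mode described here can be reproduced without JAX or Flax. The sketch below is only an analogy built from the standard library: `ToyModel` and `toy_transform` are hypothetical names, and `functools.lru_cache` stands in for whatever cache the real transformation uses. It shows why an object whose generated `__hash__` touches a dict cannot be used as a cache key, and why wrapping it in a lambda sidesteps the problem.

```python
import dataclasses
from functools import lru_cache


@dataclasses.dataclass(frozen=True)
class ToyModel:
    """Hypothetical stand-in for a model object that carries a dict of parameters."""
    params: dict

    def __call__(self, carry, x):
        new_carry = carry + self.params["w"] * x
        return new_carry, new_carry


@lru_cache(maxsize=None)
def toy_transform(fn):
    """Stand-in for a transformation that caches on the function it is given."""
    return fn


model = ToyModel(params={"w": 2.0})

# The frozen dataclass generates __hash__ from its fields, and a dict field
# cannot be hashed, so using the instance as a cache key fails.
try:
    toy_transform(model)
    direct_ok = True
    error_message = ""
except TypeError as err:
    direct_ok = False
    error_message = str(err)

# A lambda is hashed by identity, so the cache accepts it.
wrapped = toy_transform(lambda c, x: model(c, x))
carry, y = wrapped(1.0, 3.0)
```

The lambda closes over the model rather than exposing it as the cache key, which is the same shape as the workaround suggested above.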
from flax.
Thanks, yes, that works.
PS I'm also seeing this XLA warning
2020-02-19 08:51:51.429909: W external/org_tensorflow/tensorflow/compiler/xla/service/hlo_pass_fix.h:49] Unexpectedly high number of iterations in HLO passes, exiting fixed point loop.
It's raised from LSTMCell.apply() as far as I can tell.
It could be due to me upgrading (unexpectedly) to tensorflow 2.1 [they've merged tensorflow and tensorflow-gpu 2.0 into a single package and released so now anything which depended on tensorflow now installs 2.1 and enables gpu support - not a bad thing long term but causes a few short term issues]
from flax.
PS Would you recommend to use scan inside a Module instead?
from flax.
Hi @david-waterworth -- can you please take a look at our seq2seq example and let us know if that helps clarify how to use LSTMCell? https://github.com/google-research/flax/blob/prerelease/examples/seq2seq/train.py
from flax.
@avital yes thanks, the use of nn.attention.scan_in_dim answers my question (although I wouldn't have expected to find it in the attention module - maybe flax.jax_utils?)
from flax.
@avital there were a few minor questions I had after adapting the seq2seq example for my purposes.
- Why are you calling train_metrics = jax.device_get(metrics)? The annotated mnist example seems to imply this is automatic ("metrics are only retrieved from device when needed on host (like in this print statement)").
- The example doesn't evaluate a test loss during training. Again from the mnist example I discovered optimizer.target, and using this I was able to do what I needed. Ideally the examples would include this, as it's fairly standard practice?
- For completeness, should train_model not return the trained model? In the example you're not using it but it would still be useful, particularly as immutability means you have to return optimizer.target, not model, which might trap someone?
[part of my confusion was that I deleted decode_batch(optimizer.target, 5) from train_model() as I adapted it to my regression / time-series problem - that seems to effectively be equivalent to a test set evaluation.]
from flax.
One more thing that really tripped me up.
In your example the Encoder.apply() has a default argument hidden_size and you never explicitly set this. I wanted to use a different value, so I made the following changes:
class Seq2seq(nn.Module):
  """Sequence-to-sequence class using encoder/decoder architecture."""

  def apply(self,
            encoder_inputs,
            decoder_inputs,
            hidden_size=128,  # <- added here (used different default to demonstrate)
            train=True,
            max_output_len=None):
    """Run the seq2seq model."""
    # inputs.shape = (batch_size, seq_length, vocab_size).
    batch_size, _, vocab_size = encoder_inputs.shape
    carry = Encoder(encoder_inputs, name='encoder', hidden_size=hidden_size)
Then I modified create_model(), i.e.:
def create_model():
  """Creates a seq2seq model."""
  vocab_size = CTABLE.vocab_size()
  _, model = Seq2seq.create_by_shape(
      nn.make_rng(), [((1, get_max_input_len(), vocab_size), jnp.float32),
                      ((1, get_max_output_len(), vocab_size), jnp.float32)],
      hidden_size=64)  # <- modified here (again different to defaults)
  return model
The issue is Seq2seq.create_by_shape() passes hidden_size=64 but the returned model doesn't appear to have this parameter partially applied - so each time it is applied the default is used.
model = create_model()
model(batch['query'], batch['answer']) # <- uses incorrect encoder hidden size?
What I had to do instead is create a partial Seq2seq.
def create_model():
  """Creates a seq2seq model."""
  vocab_size = CTABLE.vocab_size()
  model_def = Seq2seq.partial(hidden_size=64)
  _, model = model_def.create_by_shape(
      nn.make_rng(), [((1, get_max_input_len(), vocab_size), jnp.float32),
                      ((1, get_max_output_len(), vocab_size), jnp.float32)])
  return model
Should create_by_shape partially apply *args and **kwargs?
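The pitfall described above is not specific to Flax, and a toy version may make it easier to see. The sketch below uses only the standard library; `encoder`, `create`, and `create_bound` are hypothetical names, and this is not how create_by_shape is implemented. It only mirrors the observed behaviour: keyword arguments used at creation time are forgotten unless they are bound first.

```python
from functools import partial


def encoder(inputs, hidden_size=512):
    """Toy apply function: reports the hidden size it would have used."""
    return hidden_size


def create(fn, inputs, **kwargs):
    """Hypothetical creator: runs fn once with kwargs but returns fn unbound."""
    first_output = fn(inputs, **kwargs)
    return first_output, fn


def create_bound(fn, inputs, **kwargs):
    """Variant that binds kwargs first, so later calls reuse them."""
    bound = partial(fn, **kwargs)
    return bound(inputs), bound


init_size, model = create(encoder, [0.0], hidden_size=64)
later_size = model([0.0])  # silently falls back to the default

bound_init_size, bound_model = create_bound(encoder, [0.0], hidden_size=64)
bound_later_size = bound_model([0.0])  # keeps the bound value
```

Here `init_size` is 64 while `later_size` is 512, which is the silent mismatch reported above; the bound variant returns 64 both times.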
from flax.
Finally it seems a little awkward that initialize_carry is stochastic? By default it uses initializers.zeros but because it calls random.split() you cannot pass None. That means that at inference time you seem to have to do:
model = train_model()
rng = jax.random.PRNGKey(0)
with flax.nn.stochastic(rng):
  y_hat = model(x)
I simply used jnp.zeros to create the initial state instead.
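A minimal sketch of that jnp.zeros approach, assuming the LSTM carry is a pair of arrays (cell state and hidden state) shaped as the batch dimensions followed by the hidden size. The helper name `zeros_carry` is hypothetical and not part of Flax.

```python
import jax.numpy as jnp


def zeros_carry(batch_dims, size):
    """Build an all-zero (c, h) carry without needing an RNG key."""
    shape = tuple(batch_dims) + (size,)
    return jnp.zeros(shape), jnp.zeros(shape)


# Unbatched carry, matching the batch_dims=() case discussed earlier.
c0, h0 = zeros_carry((), 5)

# Carry with an explicit batch dimension of 1.
c1, h1 = zeros_carry((1,), 5)
```

Because nothing here is random, this can be called at inference time outside any stochastic context.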
from flax.
Hi David, thanks a lot for your comments! I'm also still relatively new to FLAX, so having someone taking a critical look at my code is definitely good for my understanding as well.
@avital yes thanks, the use of nn.attention.scan_in_dim answers my question (although I wouldn't have expected to find it in the attention module - maybe flax.jax_utils?)
I agree, we should move it to flax.jax_utils (not very high priority though)
Why are you calling train_metrics = jax.device_get(metrics)?
Hmm good point! I basically do this because this is done for the MNIST example, line 146, but I tested it and without that it also works. I will update the example.
The example doesn't evaluate a test loss during training. Again from the mnist example I discovered optimizer.target and using this I was able to do what I needed, ideally the examples would include this as it's fairly standard practice?
Indeed, it doesn't report test loss. However the accuracy it reports is the same accuracy as during inference (without teacher forcing), which I explain in the comment on line 214: "Computes sequence accuracy, which is the same as the accuracy during inference, since teacher forcing is irrelevant when all output are correct."
I think adding test loss is nice but not crucial, since we report accuracy and some example decodings. Let me know if you feel otherwise, or whether you think my reasoning doesn't make sense.
For completeness should train_model not return the trained model? In the example you're not using it but it would still be useful, particularly as immutability means you have to return optimizer.target not model which might trap someone?
Yes that makes sense to me, I'll update it.
Should create_by_shape partially apply *args and **kwargs?
I think your solution of first using partial and then create_by_shape is fine; you could even chain them:
model_def = Seq2seq.partial(...).create_by_shape(...)
This is how I would do it as well, but if you find a better way, please let me know :-).
Finally it seems a little awkward that initialize_carry is stochastic?
I don't think I fully understand what you mean here. Could you clarify?
from flax.
I think your solution of first using partial and then create_by_shape is fine,
What I've also started doing is to avoid having default values for apply *args like hidden_size. This way if I pass them to create_by_shape instead of partial I'll get an error later at train time. Otherwise it just silently creates with one value but applies using the default.
I don't think I fully understand what you mean here. Could you clarify?
What I mean by stochastic is if you do the following:
model = train_model()
model(x)
It will throw an exception; you have to create a random context, i.e.
model = train_model()
rng = jrandom.PRNGKey(0)
with nn.stochastic(rng):
  model(x)
This is because of:
flax/examples/seq2seq/train.py
Line 117 in 1022b9e
Whilst it's not used, all the initialisers require an rng, even the constant ones (i.e. initializers.zeros). It's not a big deal, but I'm not sure there's a need to be able to randomly initialise the state, which is why I simply replaced this line with jnp.zeros.
from flax.
PS If you make hidden_size a flag (which, if I understand absl correctly, turns it into a command line parameter) and then modify the code so it's correctly passed to the Encoder, you can demonstrate the need to use partial.
from flax.
I think it would be better to use a fixed seed to initialize the carry given that it is just zeros anyway. It seems odd to require nn.stochastic for something that actually isn't random. I created PR #85 for this.
from flax.
A Model instance currently fails when passed into a function like jit or scan. I have a fix for this but for now you could use scan(lambda c, x: lstm(c, x), c0, timeseries).
I plan to use Flax in my research on RNNs, but I've been struggling for days to understand some of the ideas behind the Flax implementation. I really wanted to see an example of someone using the documented RNN cells; it would help me a lot!
Anyway, I tried to use the lambda trick, but it ended up not working. I get an error like this when the cell is being defined: ValueError: Jax transforms and modules cannot be mixed.
The definitions are shown below (I don't know if that is the way it is meant to be done, but I wrote it to run on a single sample, i.e. x is a vector):
from typing import Any

import jax
import jax.numpy as jnp
import flax.linen as nn

class LRNNCell(nn.Module):
  @nn.compact
  def __call__(self, h, x):
    nh = h.shape[0]
    Whx = nn.Dense(nh)
    Whh = nn.Dense(nh, use_bias=False)
    Wyh = nn.Dense(1)
    h = nn.tanh(Whx(x) + Whh(h))
    y = nn.tanh(Wyh(h))
    return h, y

class LRNN(nn.Module):
  ny: Any
  nh: Any

  @nn.compact
  def __call__(self, x):
    h = jnp.zeros(self.nh)
    cell = LRNNCell()
    # This call raises the ValueError quoted above.
    h, y = jax.lax.scan(lambda h, x: cell(h, x), h, x)
    return y[-self.ny:]
What am I missing?
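This thread does not contain the answer (the discussion moved elsewhere), so the following is not the Flax-recommended fix. It is only a sketch of the same recurrence written as a pure function with explicit parameters, which jax.lax.scan accepts because no module is created inside the scanned function. The names `init_params`, `cell`, and `run`, the parameter layout, and the input sizes are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_params(key, nx, nh):
    """Create explicit weights for a small RNN (hypothetical layout)."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "Whx": jax.random.normal(k1, (nx, nh)) * 0.1,
        "bh": jnp.zeros((nh,)),
        "Whh": jax.random.normal(k2, (nh, nh)) * 0.1,
        "Wyh": jax.random.normal(k3, (nh, 1)) * 0.1,
        "by": jnp.zeros((1,)),
    }


def cell(params, h, x):
    """One recurrent step: returns (new hidden state, output)."""
    h = jnp.tanh(x @ params["Whx"] + params["bh"] + h @ params["Whh"])
    y = jnp.tanh(h @ params["Wyh"] + params["by"])
    return h, y


def run(params, xs, ny):
    """Scan the cell over the leading (time) axis and keep the last ny outputs."""
    nh = params["Whh"].shape[0]
    h0 = jnp.zeros((nh,))
    _, ys = jax.lax.scan(lambda h, x: cell(params, h, x), h0, xs)
    return ys[-ny:]


params = init_params(jax.random.PRNGKey(0), nx=3, nh=4)
out = run(params, jnp.ones((10, 3)), ny=2)
```

With 10 time steps of 3 features each, `out` has shape (2, 1). Keeping the parameters outside the scanned function is the design choice that matters here; how to get the same effect with Flax modules is what the linked discussion covers.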
from flax.
Hi @marcosrdac! I've copied your question to Github Discussion #1283 so it is more easily accessible for other users. Let's continue the discussion there!
from flax.