Comments (7)
Hmmm. Possibly this is not an n_devices-in-decode issue but a bug in Transformer() that appears in beam-search settings, as I can replicate the bug from my local system (GPU only) and also in your colab if I upload a Transformer model (rather than a Reformer model).
@friesel I am getting the same now:
Running:
import os
import trax
import gin
import numpy as np
from tensorflow_datasets.core.features.text import SubwordTextEncoder
from trax.models.beam_search import Search


def main():
    output_dir = '/tmp/trax/model'
    vocab_fp = os.path.join(output_dir, 'vocab')
    text_encoder = SubwordTextEncoder.load_from_file(vocab_fp)

    gin.parse_config_file('examples/transformer.cfg')
    gin.bind_parameter('Transformer.input_vocab_size', text_encoder.vocab_size)

    model_infer = trax.models.Transformer(mode='predict')
    model_infer.init_from_file('/tmp/trax/model/model.pkl', weights_only=True)

    sampling_decoder = Search(
        trax.models.Transformer,
        model_infer.weights,
        temperature=1.0,
        max_decode_len=32 * 64 * 3,
    )

    inputs = np.asarray(
        [
            [4, 5, 6, 1],
            [5, 7, 8, 1]
        ]
    )

    preds, scores = sampling_decoder.decode(inputs)
    print('All done.')


if __name__ == '__main__':
    main()
Will throw
TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 1) for shape (1, 2, 4, 1).
Full traceback:
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/transformer.py, line 299
layer input shapes: (ShapeDtype{shape:(1, 2, 4, 1), dtype:int32}, ShapeDtype{shape:(2, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state
weights_or_empty, state = sublayer.init(inputs)
LayerError: Exception passing through layer Branch (in init):
layer created in file [...]/trax/layers/combinators.py, line 472
layer input shapes: ShapeDtype{shape:(1, 2, 4, 1), dtype:int32}
File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer Parallel (in _forward_abstract):
layer created in file [...]/trax/layers/combinators.py, line 470
layer input shapes: (ShapeDtype{shape:(1, 2, 4, 1), dtype:int32}, ShapeDtype{shape:(1, 2, 4, 1), dtype:int32})
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 1602, in eval_shape
*map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 326, in abstract_eval_fun
instantiate=True, stage_out=True)
File [...]/jax/interpreters/partial_eval.py, line 421, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 479, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/combinators.py, line 239, in forward_with_state
sub_outputs, sub_state = layer.pure_fn(x, w, s, r)
LayerError: Exception passing through layer PaddingMask (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 284
layer input shapes: ShapeDtype{shape:(1, 2, 4, 1), dtype:int32}
File [...]/trax/layers/base.py, line 222, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 582, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 55, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
File [...]/jax/numpy/lax_numpy.py, line 1016, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File [...]/jax/numpy/lax_numpy.py, line 1054, in _reshape_method
return _reshape(a, newshape, order=order)
File [...]/jax/numpy/lax_numpy.py, line 1033, in _reshape
return lax.reshape(a, computed_newshape, None)
File [...]/jax/lax/lax.py, line 675, in reshape
dimensions=None if dimensions is None or same_dims else tuple(dimensions))
File [...]/site-packages/jax/core.py, line 202, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File [...]/jax/interpreters/partial_eval.py, line 133, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 141, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)
File [...]/jax/lax/lax.py, line 1672, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
File [...]/jax/lax/lax.py, line 2856, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))
TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 1) for shape (1, 2, 4, 1).
Process finished with exit code 1
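For reference, the failing reshape can be reproduced in isolation. PaddingMask (attention.py, line 55 in the traceback above) builds its target shape from dims 0 and -1 only, implicitly assuming a 2-D (batch, length) input; the rank-4 input of shape (1, 2, 4, 1) has 8 elements, which cannot be reshaped into (1, 1, 1, 1). A minimal plain-NumPy sketch (NumPy raises ValueError where jax.numpy raises TypeError):

import numpy as np

x = np.zeros((1, 2, 4, 1), dtype=np.int32)   # rank-4 input seen by PaddingMask
new_shape = (x.shape[0], 1, 1, x.shape[-1])  # -> (1, 1, 1, 1): only one element
np.reshape(x != 0, new_shape)                # ValueError: cannot reshape size-8 array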
However, I am able to get past this error when presenting an input like this:
inputs = np.asarray(
    [[
        [4, 5, 6, 1]
    ]]
)
In this case (shape (1, 1, 4), whose total size matches the product of dims 0 and -1, so the reshape goes through) I get predictions, but there is another crash right after, in trax/trax/models/beam_search.py, line 602 (commit 5b15659). However, those predictions are probably not what I want anyway.
Update
I noticed that Search has an argument eos_id:
sampling_decoder = Search(
    trax.models.Transformer,
    model_infer.weights,
    temperature=1.0,
    max_decode_len=32 * 64 * 3,
    eos_id=EOS_ID
)
Setting this to the EOS id I am actually using, 1, instead of the default -1, I get through the call of self._jit_beam_search, but then end up with another error caused by trax/trax/models/beam_search.py, line 602 (commit 5b15659), because seqs, at this point, is two-dimensional. Leaving me at:
seqs = seqs[:, :, 1:] # Strip start token
IndexError: too many indices for array
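A minimal illustration of that failure (array contents are made up; only the rank matters): the strip assumes a 3-D (batch, beam_size, length) array, but seqs arrives 2-D here:

import numpy as np

seqs = np.zeros((2, 16), dtype=np.int32)  # 2-D: (batch, length)
seqs = seqs[:, :, 1:]                     # assumes 3-D -> IndexError: too many indices for array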
I installed from the latest commit (0294404) and can confirm that my small example runs through:
sampling_decoder = Search(
    trax.models.Transformer,
    model_infer.weights,
    beam_size=1,
    alpha=0.6,
    max_decode_len=32 * 64 * 3,
    eos_id=EOS_ID
)

example_inputs = text_encoder.encode('versta') + [EOS_ID]
inputs = np.asarray([example_inputs])

preds, scores = sampling_decoder.decode(inputs)
print(preds)
print('All done.')
Giving me something like
[[[8 8 8 ... 8 8 8]]]
All done.
The output sequence looks a bit weird but that might be on me.
However, I am still not able to do something like this:
more_preds = model_infer((inputs, np.ones(shape=inputs.shape)))
print(more_preds)
print('All done. (2)')
which throws:
LayerError: Exception passing through layer Cache (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 270
layer input shapes: (ShapeDtype{shape:(1, 2), dtype:int64}, ShapeDtype{shape:(1, 1, 1, 2), dtype:bool})
File [...]/trax/layers/combinators.py, line 671, in forward_with_state
if state[0] is (): # pylint: disable=literal-comparison
IndexError: tuple index out of range
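A possible workaround, though this is an assumption on my part and not verified against this checkpoint: in 'predict' mode the layers carry per-step decoding state that a plain call does not initialize, so a second model built in 'eval' mode (which keeps no such cache) might accept the full batch in one call:

# Hypothetical sketch: reuse the same checkpoint with an 'eval'-mode model,
# which has no per-step decoding cache, for a single full forward pass.
model_eval = trax.models.Transformer(mode='eval')
model_eval.init_from_file('/tmp/trax/model/model.pkl', weights_only=True)
more_preds = model_eval((inputs, np.ones(shape=inputs.shape, dtype=np.int32)))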
I tried to reproduce your code and got this error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/content/trax/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
446 outputs, s = (
--> 447 self.forward_with_state(x, weights=weights, state=state, rng=rng))
448 else:
14 frames
/content/trax/trax/layers/combinators.py in forward_with_state(self, xs, weights, state, rng)
67 raise ValueError(
---> 68 f'Number of weight elements ({len(weights)}) does not equal '
69 f'number of sublayers ({n_layers}).')
ValueError: Number of weight elements (16) does not equal number of sublayers (31).
The above exception was the direct cause of the following exception:
LayerError Traceback (most recent call last)
<ipython-input-45-43120f5a1f8f> in <module>()
1
----> 2 preds, scores = sampling_decoder.decode(inputs)
3
4 print(preds)
5 print('All done.')
/content/trax/trax/models/beam_search.py in decode(self, inputs, targets_prefix, batch_size)
593 seqs, scores = self._jit_beam_search(
594 inputs, targets_prefix, (batch_size + pad_amount) // n_devices,
--> 595 dummy=np.zeros(n_devices))
596 seqs = onp.asarray(seqs)
597 scores = onp.asarray(scores)
/usr/local/lib/python3.6/dist-packages/jax/api.py in f_pmapped(*args, **kwargs)
930 global_axis_size=axis_size,
931 devices=tuple(devices) if devices is not None else devices,
--> 932 name=flat_fun.__name__)
933 return tree_unflatten(out_tree(), out)
934
/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
893 if top_trace is None:
894 with new_sublevel():
--> 895 outs = primitive.impl(f, *args, **params)
896 else:
897 tracers = map(top_trace.full_raise, args)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/pxla.py in xla_pmap_impl(fun, backend, axis_name, axis_size, global_axis_size, devices, name, mapped_invars, *args)
404 abstract_args = map(xla.abstractify, args)
405 compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
--> 406 global_axis_size, devices, name, *abstract_args)
407 return compiled_fun(*args)
408
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
218 fun.populate_stores(stores)
219 else:
--> 220 ans = call(fun, *args)
221 cache[key] = (ans, fun.stores)
222 return ans
/usr/local/lib/python3.6/dist-packages/jax/interpreters/pxla.py in parallel_callable(fun, backend, axis_name, axis_size, global_axis_size, devices, name, *avals)
460 pval = pe.PartialVal([core.abstract_unit, core.unit]) # dummy value for axis env
461 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
--> 462 dynamic_fun, [pval] + pvals, instantiate=False, stage_out_calls=True, bottom=True)
463 jaxpr.invars = jaxpr.invars[1:] # ignore dummy
464
/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out_calls, bottom)
356 with new_master(trace_type, bottom=bottom) as master:
357 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 358 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
359 assert not env
360 del master
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
147 gen = None
148
--> 149 ans = self.f(*args, **dict(self.params, **kwargs))
150 del args
151 while stack:
/usr/local/lib/python3.6/dist-packages/jax/interpreters/pxla.py in dynamic_fun(dummy, *args)
453 def dynamic_fun(dummy, *args):
454 with extend_dynamic_axis_env(axis_name, dummy._trace, global_axis_size):
--> 455 return fun.call_wrapped(*args)
456
457 avals = tuple(map(partial(shard_aval, axis_size), avals))
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
147 gen = None
148
--> 149 ans = self.f(*args, **dict(self.params, **kwargs))
150 del args
151 while stack:
/content/trax/trax/models/beam_search.py in _unreplicated_beam_search(***failed resolving arguments***)
526 return beam_search(
527 batch_size,
--> 528 self._get_initial_state(inputs, targets_prefix, batch_size),
529 tokens_to_logits,
530 max_decode_len,
/content/trax/trax/models/beam_search.py in _get_initial_state(self, inputs, targets_prefix, batch_size)
487 self.model_weights,
488 initial_state,
--> 489 jax.random.PRNGKey(0))
490 state_structure = jax.tree_structure(prompted_state)
491
/content/trax/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
454 name, trace = self._name, _short_traceback()
455 raise LayerError(name, 'pure_fn',
--> 456 self._caller, signature(x), trace) from e
457
458 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 299
layer input shapes: (ShapeDtype{shape:(1, 16), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 68, in forward_with_state
f'Number of weight elements ({len(weights)}) does not equal '
ValueError: Number of weight elements (16) does not equal number of sublayers (31).
I've understood why I had the problem mentioned above. When you create your Search object, you have to pass all of your model's hyperparameters via the gin config; otherwise, Search will initialise the model from its first argument with default parameters, which may not match the loaded weights. A sketch of what that looks like is below.
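For concreteness, a minimal sketch of that setup (the config path, vocab size, and decode settings are placeholders, not values from this thread): parse the same gin config you trained with before constructing Search, since Search re-instantiates the model class internally:

import gin
import trax
from trax.models.beam_search import Search

# Bind the training-time hyperparameters first; Search builds a fresh
# Transformer from the class below using whatever gin currently holds.
gin.parse_config_file('examples/transformer.cfg')          # placeholder path
gin.bind_parameter('Transformer.input_vocab_size', 32000)  # placeholder size

model_infer = trax.models.Transformer(mode='predict')
model_infer.init_from_file('/tmp/trax/model/model.pkl', weights_only=True)

sampling_decoder = Search(
    trax.models.Transformer,  # re-instantiated with the gin bindings above
    model_infer.weights,
    max_decode_len=64,
    eos_id=1,
)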
However, I'm getting output consisting of just one constant value, which for me is the <START_OF_SEQ> token.
Any idea why this may happen?
The only way I have found to run inference with the Transformer is the following:
# DO NOT FORGET TO PARSE THE GIN CONFIG WITH THE HYPERPARAMETERS YOU HAD DURING
# TRAINING; otherwise, you will get layer errors (number incompatibilities) and
# unfilled parameters in function calls.
model_infer_1 = trax.models.Transformer(mode='train')
model_infer_1.init_from_file(
    'gs://your/path/to/model.pkl', weights_only=False)

inputs = np.array([[vocab['<SOS>'], 1, 2, 3, 4, 5, 6, 7, 8, 9, vocab['<EOS>']]])
targets_prefix = np.array([[vocab['<SOS>'], 0]])

for char_i in range(inputs.shape[1]):
    ret = model_infer_1((inputs, targets_prefix))
    logits = ret[0]
    batch = logits[0]
    new_decoded_char = np.argmax(batch[-1], -1)
    if new_decoded_char == vocab['<EOS>']:
        output = targets_prefix[:, :-1]
        break
    # Take all previous chars except for the last one used to make a prediction,
    # then append the predicted one plus a new zero character to indicate that
    # the next character is to be decoded.
    targets_prefix = np.hstack((targets_prefix[:, :-1],
                                np.array([[new_decoded_char, 0]])))
    print(f'{targets_prefix}')
This code returns correct results. Nevertheless, the same model does not work in the beam-search wrapper (it returns the same character for the whole decoded string).