

friesel commented on May 3, 2024

Hmmm. Possibly this is not an n_devices-in-decode issue but a bug in Transformer() that appears in beam-search settings: I can replicate the bug on my local system (GPU only), and also in your Colab if I upload a Transformer model (rather than a Reformer model).


stefan-falk commented on May 3, 2024

@friesel I am getting the same now:

Running:

import os
import trax
import gin
import numpy as np
from tensorflow_datasets.core.features.text import SubwordTextEncoder
from trax.models.beam_search import Search


def main():

    output_dir = '/tmp/trax/model'
    vocab_fp = os.path.join(output_dir, 'vocab')

    text_encoder = SubwordTextEncoder.load_from_file(vocab_fp)

    gin.parse_config_file('examples/transformer.cfg')
    gin.bind_parameter('Transformer.input_vocab_size', text_encoder.vocab_size)

    model_infer = trax.models.Transformer(mode='predict')
    model_infer.init_from_file('/tmp/trax/model/model.pkl', weights_only=True)

    sampling_decoder = Search(
        trax.models.Transformer,
        model_infer.weights,
        temperature=1.0,
        max_decode_len=32 * 64 * 3,
    )

    inputs = np.asarray(
        [
            [4, 5, 6, 1],
            [5, 7, 8, 1]
        ]
    )

    preds, scores = sampling_decoder.decode(inputs)

    print('All done.')


if __name__ == '__main__':
    main()

This throws:

TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 1) for shape (1, 2, 4, 1).
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/trax/models/transformer.py, line 299
  layer input shapes: (ShapeDtype{shape:(1, 2, 4, 1), dtype:int32}, ShapeDtype{shape:(2, 1), dtype:int32})

  File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state
    weights_or_empty, state = sublayer.init(inputs)

LayerError: Exception passing through layer Branch (in init):
  layer created in file [...]/trax/layers/combinators.py, line 472
  layer input shapes: ShapeDtype{shape:(1, 2, 4, 1), dtype:int32}

  File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer Parallel (in _forward_abstract):
  layer created in file [...]/trax/layers/combinators.py, line 470
  layer input shapes: (ShapeDtype{shape:(1, 2, 4, 1), dtype:int32}, ShapeDtype{shape:(1, 2, 4, 1), dtype:int32})

  File [...]/trax/math/jax.py, line 175, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)

  File [...]/site-packages/jax/api.py, line 1602, in eval_shape
    *map(abstractify, args_flat))

  File [...]/jax/interpreters/partial_eval.py, line 326, in abstract_eval_fun
    instantiate=True, stage_out=True)

  File [...]/jax/interpreters/partial_eval.py, line 421, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/trax/layers/base.py, line 479, in call_on_input
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)

  File [...]/trax/layers/combinators.py, line 239, in forward_with_state
    sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

LayerError: Exception passing through layer PaddingMask (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 284
  layer input shapes: ShapeDtype{shape:(1, 2, 4, 1), dtype:int32}

  File [...]/trax/layers/base.py, line 222, in forward_with_state
    return self.forward(inputs, weights), state

  File [...]/trax/layers/base.py, line 582, in _forward
    raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access

  File [...]/trax/layers/attention.py, line 55, in PaddingMask
    return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

  File [...]/jax/numpy/lax_numpy.py, line 1016, in reshape
    return a.reshape(newshape, order=order)  # forward to method for ndarrays

  File [...]/jax/numpy/lax_numpy.py, line 1054, in _reshape_method
    return _reshape(a, newshape, order=order)

  File [...]/jax/numpy/lax_numpy.py, line 1033, in _reshape
    return lax.reshape(a, computed_newshape, None)

  File [...]/jax/lax/lax.py, line 675, in reshape
    dimensions=None if dimensions is None or same_dims else tuple(dimensions))

  File [...]/site-packages/jax/core.py, line 202, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)

  File [...]/jax/interpreters/partial_eval.py, line 133, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 141, in default_process_primitive
    out_aval = primitive.abstract_eval(*avals, **params)

  File [...]/jax/lax/lax.py, line 1672, in standard_abstract_eval
    return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

  File [...]/jax/lax/lax.py, line 2856, in _reshape_shape_rule
    raise TypeError(msg.format(new_sizes, onp.shape(operand)))

TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 1) for shape (1, 2, 4, 1).

Process finished with exit code 1
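The failing frame is the PaddingMask reshape in trax/layers/attention.py, which assumes a rank-2 (batch, length) input. A minimal NumPy sketch of that single line (padding_mask is a hypothetical helper, not the actual trax layer) shows why a (1, 2, 4, 1) input cannot be reshaped to (1, 1, 1, 1):

```python
import numpy as np


def padding_mask(x, pad=0):
    """Hypothetical helper mirroring the failing line from the traceback:
    np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))."""
    return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))


# A rank-2 (batch, length) input works: 2 * 4 = 8 elements -> (2, 1, 1, 4).
ok = padding_mask(np.asarray([[4, 5, 6, 1], [5, 7, 8, 1]]))
print(ok.shape)  # (2, 1, 1, 4)

# The shape from the traceback has 8 elements, but the target shape
# (x.shape[0], 1, 1, x.shape[-1]) = (1, 1, 1, 1) holds only 1.
# Plain NumPy raises ValueError here; jax.numpy raised TypeError above.
reshape_failed = False
try:
    padding_mask(np.zeros((1, 2, 4, 1), dtype=np.int32))
except ValueError as err:
    reshape_failed = True
    print('reshape failed:', err)
```

So the mask code itself is fine for (batch, length) inputs; the problem is that the input reaching it already carries extra axes.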

However, I am able to get past this error when passing an input like this:

    inputs = np.asarray(
        [[
            [4, 5, 6, 1]
        ]]
    )

In this case I get predictions, but there is another crash right after that in

seqs = seqs[:, :, 1:] # Strip start token


However, those predictions are probably not what I want anyway.


stefan-falk commented on May 3, 2024

Update

I noticed that Search has an eos_id argument:

    sampling_decoder = Search(
        trax.models.Transformer,
        model_infer.weights,
        temperature=1.0,
        max_decode_len=32 * 64 * 3,
        eos_id=EOS_ID
    )

After setting this to the EOS id I actually use (1 instead of the default -1), I get through the call to self._jit_beam_search, but then hit another error caused by

seqs = seqs[:, :, 1:] # Strip start token

because seqs is two-dimensional at this point. That leaves me with:

    seqs = seqs[:, :, 1:]  # Strip start token
IndexError: too many indices for array
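The IndexError follows directly from indexing a rank-2 array with three subscripts. A small NumPy sketch (the token ids are made up) reproduces it and shows one rank-agnostic alternative, seqs[..., 1:], which strips the first token along the last axis however many leading axes there are. This only illustrates the indexing; it is not a tested patch for beam_search.py:

```python
import numpy as np

# Made-up token ids. The slice in beam_search.py assumes three axes
# (presumably batch, beam, length); at this point seqs only has two.
seqs_3d = np.array([[[0, 8, 8, 1]]])  # shape (1, 1, 4)
seqs_2d = np.array([[0, 8, 8, 1]])    # shape (1, 4)

print(seqs_3d[:, :, 1:].shape)  # (1, 1, 3)

index_error = False
try:
    seqs_2d[:, :, 1:]  # three subscripts on a 2-D array
except IndexError as err:
    index_error = True
    print('IndexError:', err)

# An ellipsis slices only the last axis, so it works for either rank.
print(seqs_3d[..., 1:].shape)  # (1, 1, 3)
print(seqs_2d[..., 1:].shape)  # (1, 3)
```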


stefan-falk commented on May 3, 2024

I installed from the latest commit (0294404) and can confirm that my small example now runs without errors:

sampling_decoder = Search(
    trax.models.Transformer,
    model_infer.weights,
    beam_size=1,
    alpha=0.6,
    max_decode_len=32 * 64 * 3,
    eos_id=EOS_ID
)

example_inputs = text_encoder.encode('versta') + [EOS_ID]
inputs = np.asarray([example_inputs])
preds, scores = sampling_decoder.decode(inputs)

print(preds)
print('All done.')

This gives me something like:

[[[8 8 8 ... 8 8 8]]]
All done.

The output sequence looks a bit odd, but that might be my fault.
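For reference, alpha in the Search call above is a length-normalization exponent. A sketch assuming the GNMT-style brevity penalty ((5 + length) / 6) ** alpha used by Tensor2Tensor-style beam search (check beam_search.py in your trax version to confirm); brevity_penalty and normalized_score are hypothetical helpers:

```python
def brevity_penalty(alpha, length):
    """Assumed GNMT-style length penalty: ((5 + length) / 6) ** alpha."""
    return ((5.0 + length) / 6.0) ** alpha


def normalized_score(log_prob, length, alpha=0.6):
    # Summed log-probabilities only get more negative as a hypothesis grows,
    # so dividing by the penalty offsets the bias toward short outputs.
    # With alpha = 0 the penalty is 1 and scores are raw log-probabilities.
    return log_prob / brevity_penalty(alpha, length)


print(brevity_penalty(0.6, 1))  # 1.0
print(normalized_score(-4.0, 10) > normalized_score(-4.0, 2))  # True
```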

However, I am still not able to do something like this:

more_preds = model_infer((inputs, np.ones(shape=inputs.shape)))
print(more_preds)
print('All done. (2)')

which throws:

LayerError: Exception passing through layer Cache (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 270
  layer input shapes: (ShapeDtype{shape:(1, 2), dtype:int64}, ShapeDtype{shape:(1, 1, 1, 2), dtype:bool})

  File [...]/trax/layers/combinators.py, line 671, in forward_with_state
    if state[0] is ():  # pylint: disable=literal-comparison

IndexError: tuple index out of range


DevKretov commented on May 3, 2024

I tried to reproduce your example and got this error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/content/trax/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
    446         outputs, s = (
--> 447             self.forward_with_state(x, weights=weights, state=state, rng=rng))
    448       else:

14 frames
/content/trax/trax/layers/combinators.py in forward_with_state(self, xs, weights, state, rng)
     67       raise ValueError(
---> 68           f'Number of weight elements ({len(weights)}) does not equal '
     69           f'number of sublayers ({n_layers}).')

ValueError: Number of weight elements (16) does not equal number of sublayers (31).

The above exception was the direct cause of the following exception:

LayerError                                Traceback (most recent call last)
<ipython-input-45-43120f5a1f8f> in <module>()
      1 
----> 2 preds, scores = sampling_decoder.decode(inputs)
      3 
      4 print(preds)
      5 print('All done.')

/content/trax/trax/models/beam_search.py in decode(self, inputs, targets_prefix, batch_size)
    593     seqs, scores = self._jit_beam_search(
    594         inputs, targets_prefix, (batch_size + pad_amount) // n_devices,
--> 595         dummy=np.zeros(n_devices))
    596     seqs = onp.asarray(seqs)
    597     scores = onp.asarray(scores)

/usr/local/lib/python3.6/dist-packages/jax/api.py in f_pmapped(*args, **kwargs)
    930         global_axis_size=axis_size,
    931         devices=tuple(devices) if devices is not None else devices,
--> 932         name=flat_fun.__name__)
    933     return tree_unflatten(out_tree(), out)
    934 

/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    893   if top_trace is None:
    894     with new_sublevel():
--> 895       outs = primitive.impl(f, *args, **params)
    896   else:
    897     tracers = map(top_trace.full_raise, args)

/usr/local/lib/python3.6/dist-packages/jax/interpreters/pxla.py in xla_pmap_impl(fun, backend, axis_name, axis_size, global_axis_size, devices, name, mapped_invars, *args)
    404   abstract_args = map(xla.abstractify, args)
    405   compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
--> 406                                    global_axis_size, devices, name, *abstract_args)
    407   return compiled_fun(*args)
    408 

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    218       fun.populate_stores(stores)
    219     else:
--> 220       ans = call(fun, *args)
    221       cache[key] = (ans, fun.stores)
    222     return ans

/usr/local/lib/python3.6/dist-packages/jax/interpreters/pxla.py in parallel_callable(fun, backend, axis_name, axis_size, global_axis_size, devices, name, *avals)
    460   pval = pe.PartialVal([core.abstract_unit, core.unit])  # dummy value for axis env
    461   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
--> 462       dynamic_fun, [pval] + pvals, instantiate=False, stage_out_calls=True, bottom=True)
    463   jaxpr.invars = jaxpr.invars[1:]  # ignore dummy
    464 

/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out_calls, bottom)
    356   with new_master(trace_type, bottom=bottom) as master:
    357     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 358     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    359     assert not env
    360     del master

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    147     gen = None
    148 
--> 149     ans = self.f(*args, **dict(self.params, **kwargs))
    150     del args
    151     while stack:

/usr/local/lib/python3.6/dist-packages/jax/interpreters/pxla.py in dynamic_fun(dummy, *args)
    453   def dynamic_fun(dummy, *args):
    454     with extend_dynamic_axis_env(axis_name, dummy._trace, global_axis_size):
--> 455       return fun.call_wrapped(*args)
    456 
    457   avals = tuple(map(partial(shard_aval, axis_size), avals))

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    147     gen = None
    148 
--> 149     ans = self.f(*args, **dict(self.params, **kwargs))
    150     del args
    151     while stack:

/content/trax/trax/models/beam_search.py in _unreplicated_beam_search(***failed resolving arguments***)
    526     return beam_search(
    527         batch_size,
--> 528         self._get_initial_state(inputs, targets_prefix, batch_size),
    529         tokens_to_logits,
    530         max_decode_len,

/content/trax/trax/models/beam_search.py in _get_initial_state(self, inputs, targets_prefix, batch_size)
    487         self.model_weights,
    488         initial_state,
--> 489         jax.random.PRNGKey(0))
    490     state_structure = jax.tree_structure(prompted_state)
    491 

/content/trax/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
    454       name, trace = self._name, _short_traceback()
    455       raise LayerError(name, 'pure_fn',
--> 456                        self._caller, signature(x), trace) from e
    457 
    458   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/models/transformer.py, line 299
  layer input shapes: (ShapeDtype{shape:(1, 16), dtype:int32}, ShapeDtype{shape:(1, 1), dtype:int32})

  File [...]/trax/layers/combinators.py, line 68, in forward_with_state
    f'Number of weight elements ({len(weights)}) does not equal '

ValueError: Number of weight elements (16) does not equal number of sublayers (31).


DevKretov commented on May 3, 2024

I've figured out why I had the problem mentioned above. Before you create your Search object, you have to set all of your model's hyperparameters in the gin config. Otherwise, Search initializes the model class passed as its first argument with default parameters, which may not match the trained weights.

However, the output I get consists only of a constant value, which in my case is the <START_OF_SEQ> token.

Any idea why this may happen?


DevKretov commented on May 3, 2024

As far as I can tell, the only way to run inference with the Transformer is:

import numpy as np
import trax

# DO NOT FORGET TO PARSE THE GIN CONFIG WITH THE HYPERPARAMETERS YOU USED DURING TRAINING;
# otherwise you will get layer errors (number incompatibilities) and unfilled
# parameters in function calls.
model_infer_1 = trax.models.Transformer(mode='train')

model_infer_1.init_from_file(
    'gs://your/path/to/model.pkl', weights_only=False)

# `vocab` is your token-to-id mapping (defined elsewhere).
inputs = np.array([[vocab['<SOS>'], 1, 2, 3, 4, 5, 6, 7, 8, 9, vocab['<EOS>']]])

# Start with <SOS> plus a 0 placeholder for the position being predicted.
targets_prefix = np.array([[vocab['<SOS>'], 0]])

for char_i in range(inputs.shape[1]):
  ret = model_infer_1((inputs, targets_prefix))
  logits = ret[0]
  batch = logits[0]

  # Greedy choice: the most likely token for the last (placeholder) position.
  new_decoded_char = np.argmax(batch[-1], -1)

  if new_decoded_char == vocab['<EOS>']:
    output = targets_prefix[:, :-1]
    break

  # Take all previous chars except for the last one used to make a prediction,
  # then add the predicted one and a new zero character to indicate that the
  # next character has to be searched for.
  targets_prefix = np.hstack((targets_prefix[:, :-1], np.array([[new_decoded_char, 0]])))

  print(f'{targets_prefix}')

This code returns correct results. Nevertheless, the same model does not work in the beam search wrapper (it returns the same character for the whole decoded string).
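The loop above is greedy decoding: at each step, take the argmax of the logits at the last position, write it into the placeholder slot, open a new slot, and stop at <EOS>. A model-agnostic sketch of the same idea follows; logits_fn is a hypothetical stand-in for calling the trained model on (inputs, targets_prefix), and toy_logits_fn is a made-up model that exists only so the example runs:

```python
import numpy as np


def greedy_decode(logits_fn, inputs, sos_id, eos_id, max_len, placeholder_id=0):
    """Greedy decoding in the style of the loop above.

    `logits_fn(inputs, prefix)` is a hypothetical stand-in for the model call;
    it must return logits of shape (len(prefix), vocab_size). The trailing
    `placeholder_id` marks the position whose token is being predicted.
    """
    prefix = [sos_id, placeholder_id]
    for _ in range(max_len):
        logits = logits_fn(inputs, prefix)
        next_id = int(np.argmax(logits[-1]))
        if next_id == eos_id:
            break
        # Replace the placeholder with the prediction and open a new slot.
        prefix = prefix[:-1] + [next_id, placeholder_id]
    return prefix[:-1]  # drop the trailing placeholder (keeps <SOS>)


def toy_logits_fn(inputs, prefix, vocab_size=10, eos_id=1):
    # Made-up "model": copies the input token for the current step and
    # emits EOS once the input is exhausted.
    step = len(prefix) - 2  # number of tokens decoded so far
    target = inputs[step] if step < len(inputs) else eos_id
    logits = np.zeros((len(prefix), vocab_size))
    logits[-1, target] = 1.0
    return logits


print(greedy_decode(toy_logits_fn, [4, 5, 6], sos_id=2, eos_id=1, max_len=10))
# [2, 4, 5, 6]
```

Unlike the loop above, this sketch also returns a result when max_len is reached without an <EOS>, instead of leaving the output undefined.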

