
keras-utility-layer-collection's Issues

Masking not supported

I used:
encoder_input = ks.layers.Input(shape=(90,))
embed = Embedding(input_dim=598, output_dim=512, input_length=90, mask_zero=True)
encoder_inputs = embed(encoder_input)
……
I tried both SequenceAttention and AttentionRNNWrapper, and both fail with "Layer does not support masking……"

Version mismatch

I am trying to use
encoder = GRU(embedding_size, return_sequences=True, return_state=True, recurrent_dropout=0.1)
attented_encoder = ExternalAttentionRNNWrapper(encoder, return_attention=True)
but I get an error in the wrapper class:
super(ExternalAttentionRNNWrapper, self).__init__(layer, **kwargs)

File "C:\Users\NITS\Anaconda3\lib\site-packages\tensorflow\python\keras\layers\wrappers.py", line 52, in __init__
assert isinstance(layer, Layer)

AssertionError

print(type(encoder)) gives
keras.layers.recurrent.LSTM, but this assertion in Wrapper fails:

assert isinstance(layer, Layer)
self.layer = layer
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
# the inner layer has update ops that depend on its inputs (as opposed
# to the inputs to the Wrapper layer).
self._input_map = {}
super(Wrapper, self).__init__(**kwargs)
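
One possible cause, which is an assumption and not confirmed in this thread, is that the encoder comes from the standalone keras package while the wrapper checks against the Layer class from tensorflow.python.keras. Two classes that share a name but live in different packages are unrelated as far as isinstance is concerned. The sketch below reproduces that situation with hypothetical stand-in classes, so it runs without Keras or TensorFlow installed:

```python
# Hypothetical stand-ins: two unrelated base classes that share the name "Layer",
# playing the roles of keras' Layer and tensorflow.python.keras' Layer.
def make_layer_class():
    class Layer:
        pass
    return Layer


StandaloneLayer = make_layer_class()  # stands in for the standalone keras Layer
TfKerasLayer = make_layer_class()     # stands in for the tf.keras Layer


class GRU(StandaloneLayer):
    """Stand-in for a recurrent layer imported from the standalone package."""


class Wrapper(TfKerasLayer):
    """Stand-in for the wrapper base class shown in the traceback."""

    def __init__(self, layer):
        # Same kind of check as in wrappers.py: it only accepts *its own* Layer.
        assert isinstance(layer, TfKerasLayer)
        self.layer = layer


def wrap_fails(layer):
    """Return True if wrapping `layer` trips the isinstance assertion."""
    try:
        Wrapper(layer)
    except AssertionError:
        return True
    return False


encoder = GRU()
mismatch = wrap_fails(encoder)
print(mismatch)
```

If this is what is happening, importing the recurrent layer and the wrapper base class from the same package (all from keras, or all from tensorflow.keras) should make the assertion pass.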

AttentionWrapper States Tuple Out of Bounds (possibly a versioning issue)

Hello,

It looks like this was developed against a version of Keras with a different API. When I try to use it on Keras 2.2.0, I get an error where the code accesses index 3 of the states list:

  File "model.py", line 485, in <module>
    model.create_models()
  File "model.py", line 262, in create_models
    initial_state=[encoder_output])
  File "/root/anaconda3/lib/python3.6/site-packages/keras/engine/base_layer.py", line 460, in __call__
    output = self.call(inputs, **kwargs)
  File "/root/anaconda3/lib/python3.6/site-packages/kulc-0.0.5-py3.6.egg/kulc/attention.py", line 395, in call
    input_length=input_shape[1]
  File "/root/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2868, in rnn
    outputs, _ = step_function(inputs[0], initial_states + constants)
  File "/root/anaconda3/lib/python3.6/site-packages/kulc-0.0.5-py3.6.egg/kulc/attention.py", line 356, in step
    total_x_prod = states[3]
IndexError: list index out of range

The problem appears to be here: https://github.com/FlashTek/keras-utility-layer-collection/blob/master/kulc/attention.py#L350-L356

If you intended for this to be used with a particular version of Keras, you could use a requirements.txt file to indicate the version of Keras you wanted to use.

Thanks for working on this, it seems super useful!
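
As a complement to the requirements.txt suggestion above, here is a minimal sketch of a runtime version check that warns when the installed Keras does not match a pinned release. The pinned value "2.1" is only a placeholder: this thread does not say which Keras release the library was written against. The helper names installed_version and matches_pin are made up for illustration.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def matches_pin(installed, pinned):
    """True if `installed` agrees with `pinned` on every component `pinned` gives."""
    if installed is None:
        return False
    pin_parts = pinned.split(".")
    return installed.split(".")[:len(pin_parts)] == pin_parts


# PINNED is a placeholder, not the version the library actually targets.
PINNED = "2.1"

found = installed_version("keras")
if not matches_pin(found, PINNED):
    print(f"warning: expected keras {PINNED}.x, found {found}")
```

A pin in requirements.txt catches the mismatch at install time, while a check like this reports it when the module is imported.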

TypeError while running MultiHeadAttention

While trying out MultiHeadAttention I got this error:

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (100, Dimension(100)). Consider casting elements to a supported type.

and I fixed it like this:

    def build(self, input_shape):
        self._validate_input_shape(input_shape)
        
        d_k = self._d_k if self._d_k else input_shape[1][-1]
        d_model = self._d_model if self._d_model else input_shape[1][-1]
        d_v = self._d_v

        # Unwrap TF 1.x Dimension objects into plain ints.
        if isinstance(d_k, tf.Dimension):
            d_k = d_k.value

        if isinstance(d_model, tf.Dimension):
            d_model = d_model.value
        
        self._q_layers = []
        self._k_layers = []
        self._v_layers = []
        self._sdp_layer = ScaledDotProductAttention(return_attention=self._return_attention)
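
The same unwrapping can be sketched without naming tf.Dimension at all, by reading the value attribute when it exists and passing plain ints through. This is an illustration and not part of the library: dim_to_int and FakeDimension are hypothetical names, and FakeDimension only mimics the value attribute of a TF 1.x Dimension so the snippet runs without TensorFlow.

```python
def dim_to_int(dim):
    """Return a plain int (or None) for an int, None, or Dimension-like object."""
    # Dimension-like objects expose the wrapped integer as `.value`;
    # plain ints and None have no such attribute and are returned unchanged.
    return getattr(dim, "value", dim)


class FakeDimension:
    """Hypothetical stand-in for a TF 1.x Dimension, used only for this demo."""

    def __init__(self, value):
        self.value = value


# Mixed shape like the one in the error message: (100, Dimension(100)).
shape = (100, FakeDimension(100))
plain_shape = tuple(dim_to_int(d) for d in shape)
print(plain_shape)  # (100, 100)
```

In build, this would amount to passing d_k and d_model through such a helper before they are used to size the projection layers.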

Great project!
