minqi / hnatt
Train and visualize Hierarchical Attention Networks
License: MIT License
@minqi Any plan to update the Attention(Layer) to the latest version of TensorFlow?
Thanks!
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
texts_input (InputLayer)     (None, 1, 20)             0
_________________________________________________________________
time_distributed_1 (TimeDist (None, 1, 100)            99161300
_________________________________________________________________
bidirectional_2 (Bidirection (None, 1, 100)            45300
_________________________________________________________________
dense_transform_s (Dense)    (None, 1, 100)            10100
_________________________________________________________________
sentence_attention (Attentio (None, 100)               100
_________________________________________________________________
dense_1 (Dense)              (None, 25)                2525
=================================================================
Total params: 99,219,325
Trainable params: 99,219,325
Non-trainable params: 0
_________________________________________________________________
# Predict Label

def predict(x):
    encoded_x = _encode_texts(x)
    print(encoded_x.shape)
    print(encoded_x)
    return model.predict(encoded_x)

raw_text = [['the food was really good']]
predict(raw_text)
(1, 1, 20)
[[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 1. 146. 31. 623. 138.]]]
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-100-f093b3ce957d> in <module>()
8
9 raw_text = [['the food was really good']]
---> 10 predict(raw_text)
<ipython-input-100-f093b3ce957d> in predict(x)
5 print(encoded_x.shape)
6 print(encoded_x)
----> 7 return model.predict(encoded_x)
8
9 raw_text = [['the food was really good']]
~/anaconda3/lib/python3.6/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
1170 batch_size=batch_size,
1171 verbose=verbose,
-> 1172 steps=steps)
1173
1174 def train_on_batch(self, x, y,
~/anaconda3/lib/python3.6/site-packages/keras/engine/training_arrays.py in predict_loop(model, f, ins, batch_size, verbose, steps)
295 ins_batch[i] = ins_batch[i].toarray()
296
--> 297 batch_outs = f(ins_batch)
298 if not isinstance(batch_outs, list):
299 batch_outs = [batch_outs]
~/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
2659 return self._legacy_call(inputs)
2660
-> 2661 return self._call(inputs)
2662 else:
2663 if py_any(is_tensor(x) for x in inputs):
~/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
2629 symbol_vals,
2630 session)
-> 2631 fetched = self._callable_fn(*array_vals)
2632 return fetched[:len(self.outputs)]
2633
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __call__(self, *args)
1452 else:
1453 return tf_session.TF_DeprecatedSessionRunCallable(
-> 1454 self._session._session, self._handle, args, status, None)
1455
1456 def __del__(self):
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
517 None, None,
518 compat.as_text(c_api.TF_Message(self.status.status)),
--> 519 c_api.TF_GetCode(self.status.status))
520 # Delete the underlying status object from memory otherwise it stays alive
521 # as there is a reference to status from this from the traceback due to
InvalidArgumentError: Inputs to operation bidirectional_2/while/Select_1 of type Select must have the same size and shape. Input 0: [1,1000] != input 1: [1,50]
[[Node: bidirectional_2/while/Select_1 = Select[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](bidirectional_2/while/Tile, bidirectional_2/while/add_6, bidirectional_2/while/Switch_3:1)]]
So, I trained a model with max_sentence_count = 1 and max_sentence_len = 20. The model trains and runs, and I can get the correct word_activation_maps for each word. The issue is that when I try to predict the label with model.predict(), I get the error above. I don't understand the reason for "Input 0: [1,1000] != input 1: [1,50]" or how to fix this behavior.
Any thoughts?
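One possible reading of the two shapes, offered as an assumption and not confirmed anywhere in this thread: the 45,300 parameters of bidirectional_2 are consistent with a bidirectional GRU of 50 units per direction on a 100-dimensional input, and 1000 = 20 x 50. That would fit a per-timestep mask of width max_sentence_len (20) reaching the sentence-level RNN, where a mask of width 1 is expected before it is tiled across the units. The helper name below is hypothetical and only reproduces the arithmetic.

```python
def tiled_mask_shape(mask_shape, units):
    """Shape of a (batch, width) mask after tiling it `units` times on axis 1.

    Hypothetical helper for illustration; it only mirrors the shape
    arithmetic of a tile operation, it does not call Keras or TensorFlow.
    """
    batch, width = mask_shape
    return (batch, width * units)


units = 50             # assumed: GRU units per direction, inferred from the summary
max_sentence_len = 20  # from the issue text

# A per-timestep mask of width 1 tiles to the same shape as the RNN output.
expected = tiled_mask_shape((1, 1), units)
# A mask that still carries the word dimension tiles to 20 * 50 columns.
observed = tiled_mask_shape((1, max_sentence_len), units)

print(expected, observed)
```

If this reading is right, the place to look would be how the word-level mask propagates out of the TimeDistributed encoder, but that remains a guess until checked against the model code.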
Hi. I would like to visualize attention weights for a transformer, like you did here. Could you please help me?
It seems that the Attention layer is not computed properly. In the original paper, the output vector is the weighted sum of the hidden states (h_i) under the attention weights, not a weighted sum of the hidden representations (u_i).
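For reference, the attention step as described in the Hierarchical Attention Networks paper (u_i = tanh(W h_i + b), weights from a softmax over u_i . u_w, output as the weighted sum of h_i) can be sketched as below. This is a minimal plain-Python illustration, not the repository's layer; the function and argument names are hypothetical.

```python
import math


def han_attention(h, W, b, u_ctx):
    """Attention pooling over hidden states, following the paper's equations.

    h:     list of T hidden states, each a list of d floats
    W:     d x a projection matrix (list of rows)
    b:     bias, length a
    u_ctx: context vector u_w, length a
    Returns (s, alpha): pooled vector of length d and the T attention weights.
    """
    # Hidden representation u_i = tanh(W^T h_i + b); used only for scoring.
    u = [
        [math.tanh(sum(h_i[k] * W[k][j] for k in range(len(h_i))) + b[j])
         for j in range(len(b))]
        for h_i in h
    ]
    # Unnormalised scores u_i . u_w, then a numerically stable softmax.
    scores = [sum(x * c for x, c in zip(u_i, u_ctx)) for u_i in u]
    top = max(scores)
    exps = [math.exp(sc - top) for sc in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # The pooled vector sums the hidden states h_i, not the u_i.
    d = len(h[0])
    s = [sum(alpha[i] * h[i][k] for i in range(len(h))) for k in range(d)]
    return s, alpha


h = [[1.0, 0.0], [0.0, 1.0]]
W = [[1.0], [-1.0]]
s, alpha = han_attention(h, W, b=[0.0], u_ctx=[1.0])
print(s, alpha)
```

With one-hot hidden states as in this toy input, the pooled vector equals the attention weights, which makes it easy to see that the sum runs over h_i.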
The Yelp dataset has been updated, and it now consists entirely of JSON files. Perhaps update the README to link to an older version of the dataset (also available on Kaggle).
Can anyone tell me what this error means and how I can fix this?
I encountered an error. Does anyone have any suggestions? Thanks a lot!
Using TensorFlow backend.
loading Yelp reviews...
0%| | 0/10000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "main.py", line 10, in <module>
    (train_x, train_y), (test_x, test_y) = yelp.load_data(path=YELP_DATA_PATH, size=1e4, binary=False)
  File "/home/khanhng/Downloads/hnatt-master/util/yelp.py", line 48, in load_data
    df['text_tokens'] = df['text'].progress_apply(lambda x: normalize(x))
  File "/home/khanhng/Downloads/hnatt-master/.venv/local/lib/python2.7/site-packages/tqdm/_tqdm.py", line 612, in inner
    result = getattr(df, df_function)(wrapper, **kwargs)
  File "/home/khanhng/Downloads/hnatt-master/.venv/local/lib/python2.7/site-packages/pandas/core/series.py", line 3194, in apply
    mapped = lib.map_infer(values, f, convert=convert_dtype)
  File "pandas/_libs/src/inference.pyx", line 1472, in pandas._libs.lib.map_infer
  File "/home/khanhng/Downloads/hnatt-master/.venv/local/lib/python2.7/site-packages/tqdm/_tqdm.py", line 608, in wrapper
    return func(*args, **kwargs)
  File "/home/khanhng/Downloads/hnatt-master/util/yelp.py", line 48, in <lambda>
    df['text_tokens'] = df['text'].progress_apply(lambda x: normalize(x))
  File "/home/khanhng/Downloads/hnatt-master/util/text_util.py", line 11, in normalize
    doc = nlp(text)
  File "/home/khanhng/Downloads/hnatt-master/.venv/local/lib/python2.7/site-packages/spacy/language.py", line 346, in __call__
    doc = self.make_doc(text)
  File "/home/khanhng/Downloads/hnatt-master/.venv/local/lib/python2.7/site-packages/spacy/language.py", line 378, in make_doc
    return self.tokenizer(text)
TypeError: Argument 'string' has incorrect type (expected unicode, got str)
Exception KeyError: KeyError(<weakref at 0x7f1109825f70; to 'tqdm' at 0x7f111a0d7490>,) in <bound method tqdm.__del__ of 0%| | 1/10000 [00:00<11:54, 13.99it/s]> ignored
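The TypeError says the spaCy tokenizer received a byte string (Python 2 `str`) where it expects `unicode`. A minimal sketch of one workaround is to decode the text before it is passed to `nlp(...)`. The helper name `to_text` and the UTF-8 encoding are assumptions, not something the thread or the repository confirms.

```python
def to_text(value, encoding="utf-8"):
    """Return `value` as a text string, decoding byte strings if needed.

    Hypothetical helper: on Python 2 a plain `str` is a byte string, so it
    is decoded to `unicode`; on Python 3 only `bytes` values are decoded.
    The UTF-8 default is an assumption about how the reviews are encoded.
    """
    if isinstance(value, bytes):
        return value.decode(encoding)
    return value


decoded = to_text(b"the food was really good")
unchanged = to_text(u"the food was really good")
print(decoded == unchanged)
```

Under that assumption, calling the helper on the text at the start of `normalize` in util/text_util.py would hand spaCy a unicode string. Running the project on Python 3 would likely avoid the mismatch as well.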
Error when running run_hnatt_viewer.py
raise child_exception_type(errno_num, err_msg, err_filename)
FileNotFoundError: [Errno 2] No such file or directory: 'flask': 'flask'
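This FileNotFoundError is what Python's subprocess module raises when the program it is asked to launch, here `flask`, cannot be found on the PATH. Assuming the viewer script starts Flask as an external command (not confirmed in this thread), the sketch below shows one way to check for the executable and fall back to running Flask as a module with the current interpreter. The function name is hypothetical.

```python
import shutil
import sys


def flask_command(args):
    """Build a command list for launching Flask.

    Hypothetical helper: prefers a `flask` executable found on the PATH and
    otherwise falls back to `python -m flask` with the running interpreter,
    which still requires the Flask package to be installed.
    """
    exe = shutil.which("flask")
    if exe is not None:
        return [exe] + list(args)
    return [sys.executable, "-m", "flask"] + list(args)


cmd = flask_command(["run"])
print(cmd)
```

If the fallback branch is taken and the launch still fails, Flask is probably not installed in the active environment, and installing it there (for example with pip) would be the first thing to try.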