Comments (4)
@breadbread1984 Your code above looks OK to me (keep in mind that freeze_bert_layers will not freeze the LayerNorm layers, which is needed for adapter-BERT).
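For illustration, the trainable/frozen split needed for adapter-BERT can be sketched with a simple name filter (the helper and the layer names below are illustrative assumptions, not the library's actual implementation):

```python
# Hedged sketch: when freezing a BERT encoder for adapter-BERT, the
# LayerNorm (and adapter) weights must stay trainable; everything else
# is frozen. A name-based filter is one way to make that decision.
def should_stay_trainable(layer_name):
    keep = ("LayerNorm", "layer_norm", "adapter")
    return any(k in layer_name for k in keep)

names = [
    "encoder/layer_0/attention/dense",
    "encoder/layer_0/attention/LayerNorm",
    "encoder/layer_0/adapter/down_project",
]
print([n for n in names if should_stay_trainable(n)])
# ['encoder/layer_0/attention/LayerNorm', 'encoder/layer_0/adapter/down_project']
```

In the real code this predicate would be applied to each sub-layer's `trainable` flag while walking the encoder's layers.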
How do you call model.predict?
from bert-for-tf2.
I just tokenize the input string pair and feed it into BERT. My code is here; BERT is called here. The preprocessing code was copied from run_classifier.py in the official Google BERT project. I really appreciate your help.
@breadbread1984 - I think it doesn't work because there is no Model in your _classify() function. Keras needs a Model instance to wire up a TF graph. You have built a model in BERT.py (e.g. by calling model.build()), but the classifier in _classify() (dropout and dense layers belonging to no model) is not wired into the execution graph. For this you either need a separate Model calling the bert Model, or, as an alternative, you can return the bert layer from BERT.py and build a single classifier model to be used in _classify() (i.e. calling model.build() in Predictor only). Does this explanation make sense?
To make it work, you could change your code like this:
diff --git a/BERT.py b/BERT.py
index 722d4bb..1a3b7f4 100644
--- a/BERT.py
+++ b/BERT.py
@@ -39,7 +39,7 @@ def BERT(max_seq_len = 128, bert_model_dir = 'models/chinese_L-12_H-768_A-12', d
output = bert([input_token_ids, input_segment_ids]);
# create model containing only bert layer
model = tf.keras.Model(inputs = [input_token_ids, input_segment_ids], outputs = output);
- model.build(input_shape = [(None, max_seq_len), (None, max_seq_len)]);
+ #model.build(input_shape = [(None, max_seq_len), (None, max_seq_len)]);
# freeze_bert_layers
freeze_bert_layers(bert);
# load bert layer weights
diff --git a/Predictor.py b/Predictor.py
index 2c7e6da..dd1dcfe 100644
--- a/Predictor.py
+++ b/Predictor.py
@@ -151,15 +151,20 @@ class Predictor(object):
def _classify(self, inputs, mask, training = None):
# the first element of output sequence.
- outputs = self.bert(inputs, mask, training);
+ outputs = self.bert.predict(inputs)
+
+ cls_input = tf.keras.Input((128,768,), dtype=tf.float32)
# first_token.shape = (batch, hidden_size)
- first_token = tf.keras.layers.Lambda(lambda seq: seq[:, 0, :])(outputs);
- first_token = tf.keras.Dropout(rate = 0.5)(first_token);
+ first_token = tf.keras.layers.Lambda(lambda seq: seq[:, 0, :])(cls_input);
+ first_token = tf.keras.layers.Dropout(rate = 0.5)(first_token);
pooled_output = tf.keras.layers.Dense(units = first_token.shape[-1], activation = tf.math.tanh)(first_token);
dropout = tf.keras.layers.Dropout(rate = 0.5)(pooled_output);
logits = tf.keras.layers.Dense(units = 2, activation = tf.nn.softmax)(dropout);
- return logits;
+ model = tf.keras.models.Model(inputs=cls_input, outputs=logits)
+ model.build(input_shape = (None, 128, 768))
+
+ return model.predict(outputs)
def predict(self, question, answer):
@@ -167,6 +172,10 @@ class Predictor(object):
input_ids = tf.constant(input_ids, dtype = tf.int32);
input_mask = tf.constant(input_mask, dtype = tf.int32);
segment_ids = tf.constant(segment_ids, dtype = tf.int32);
+
+ input_ids = tf.expand_dims(input_ids, axis=0)
+ segment_ids = tf.expand_dims(segment_ids, axis=0)
+
logits = self._classify([input_ids, segment_ids], input_mask, False);
probabilities = tf.nn.softmax(logits);
out = tf.math.argmax(probabilities);
@@ -174,6 +183,7 @@ class Predictor(object):
if __name__ == "__main__":
+ tf.enable_eager_execution()
assert tf.executing_eagerly();
predictor = Predictor();
print(predictor.predict('今天天气如何?','感觉很不错!'));
but this is not what you want, as it first predicts on the BERT model, and then feeds its output to a second classifier model. A single classifier model with a bert layer would be better (and easier to train).
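The single-model wiring could look roughly like this (a hedged sketch: an Embedding layer stands in for the real BertModelLayer, and the shapes and vocabulary size are assumptions):

```python
import tensorflow as tf

max_seq_len, hidden_size = 128, 768

# One functional Model: encoder and classifier head wired into a single graph.
token_ids = tf.keras.Input(shape=(max_seq_len,), dtype=tf.int32)
# Stand-in for the BERT layer: maps (batch, seq_len) -> (batch, seq_len, hidden).
seq_output = tf.keras.layers.Embedding(21128, hidden_size)(token_ids)
# Take the first ([CLS]) token, then apply the classifier head.
first_token = tf.keras.layers.Lambda(lambda seq: seq[:, 0, :])(seq_output)
first_token = tf.keras.layers.Dropout(rate=0.5)(first_token)
pooled = tf.keras.layers.Dense(units=hidden_size, activation="tanh")(first_token)
dropout = tf.keras.layers.Dropout(rate=0.5)(pooled)
logits = tf.keras.layers.Dense(units=2, activation="softmax")(dropout)

model = tf.keras.Model(inputs=token_ids, outputs=logits)
print(model.output_shape)  # (None, 2)
```

Because everything lives in one Model, a single model.predict() (or model.fit()) runs the whole graph end to end.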
Also note that you need the batch dimension when feeding data into the model (hence the call to tf.expand_dims() above).
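The batch-dimension point can be seen with plain NumPy, since np.expand_dims behaves like tf.expand_dims here (the shapes are illustrative):

```python
import numpy as np

# A single tokenized example has shape (seq_len,); Keras models expect
# (batch, seq_len), so add a leading batch axis before predicting.
input_ids = np.zeros(128, dtype=np.int32)
print(input_ids.shape)              # (128,)
batched = np.expand_dims(input_ids, axis=0)
print(batched.shape)                # (1, 128)
```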
Got it. Thanks for your informative and kind reply!