Comments (15)

gifflarn commented on August 30, 2024

@PauloQuerido I too am trying to get a frozen graph to work. I got the .pb file from the link you posted, using his freeze_graph function with output_node_names=decoder/decoder/transpose_1.
I am now stuck on actually using the frozen graph: importing it gives me "You must feed a value for tensor Placeholder_2 and Placeholder_3", which are tensors used only during training (I think). It's odd, because in test.py running model.prediction with only three fed tensors works, but once frozen the model does not accept just those three.
If you manage to get further than this, please let me know.
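
The freezing step was roughly along these lines (a sketch, not the exact script; the checkpoint paths and the frozen_model.pb file name are assumptions, only the output node name is the one mentioned above):

import tensorflow as tf

# Restore the trained checkpoint (paths are assumptions).
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./saved_model/model.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./saved_model/'))

    # Fold the variables into constants, keeping only the subgraph that
    # feeds the decoder output node.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['decoder/decoder/transpose_1'])

    # Write the frozen GraphDef to disk.
    with tf.gfile.GFile('frozen_model.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())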

gifflarn commented on August 30, 2024

@gogasca From my understanding you only specify the last layer(s) of the graph as output nodes, 'freezing' everything between the input and output nodes. I only specified decoder/decoder/transpose_1 as the output node. I hoped I could get it to work like this, but without success:
output = graph.get_tensor_by_name('prefix/decoder/decoder/transpose_1:0')
input1 = graph.get_tensor_by_name('prefix/batch_size:0')
input2 = graph.get_tensor_by_name('prefix/Placeholder:0')
input3 = graph.get_tensor_by_name('prefix/Placeholder_1:0')

prediction = self.sess.run(output, feed_dict={
    input1: len(batch), input2: batch, input3: batch_x_len})
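
The frozen graph itself was imported roughly like this (a sketch; the frozen_model.pb file name is an assumption, the 'prefix' import name matches the tensor names above):

import tensorflow as tf

# Read the frozen GraphDef from disk.
with tf.gfile.GFile('frozen_model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import it into a fresh graph under the name scope 'prefix'.
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='prefix')

sess = tf.Session(graph=graph)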

dongjun-Lee commented on August 30, 2024

Hi. I think what you want to do is already implemented in test.py. When you run train.py, the model is saved at every epoch. test.py loads the last saved model and creates summaries of valid.article.filter.txt. If you have further questions, feel free to ask. Thank you!

gifflarn commented on August 30, 2024

@dongjun-Lee Do you possibly have any insight on why
model.decoder_input: batch_decoder_input,
model.decoder_len: batch_decoder_len,
model.decoder_target: batch_decoder_output
are needed in a frozen graph but not during a regular test.py session?

dongjun-Lee commented on August 30, 2024

@gifflarn I'm sorry, but I'm not familiar with frozen graphs. I'll look into it soon.

gogasca commented on August 30, 2024

@gifflarn I tried to use the following code to extract the output node names:

[n.name for n in tf.get_default_graph().as_graph_def().node]

This is my code to freeze the Graph:

https://gist.github.com/gogasca/ac743e3664c3e9cb668e9666c9e7b025

I'm unable to restore the .pb and generate predictions.

While test.py works by reading a file in a local environment, what I want to do is offer an API. Has anyone had any luck restoring the .pb?

gifflarn commented on August 30, 2024

@gogasca Isn't using every node in the graph as an output node counterproductive?
I am able to restore the .pb with dummy values for Placeholder_2 and Placeholder_3, but that gives me rather bad results. So decoder_input and decoder_len obviously have some impact during testing as well. However, running test.py never feeds these tensors. I'm a bit confused by this.
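
The dummy feed, continuing the snippet from my earlier comment, was roughly like this (a sketch; the shapes and the mapping of Placeholder_2/Placeholder_3 to decoder_input/decoder_len are guesses based on the error messages):

import numpy as np

# Hypothetical dummy values for the training-only placeholders.
dummy_decoder_input = np.zeros((len(batch), 15), dtype=np.int32)  # assumed decoder_input, shape [?, 15]
dummy_decoder_len = np.ones((len(batch),), dtype=np.int32)        # assumed decoder_len, shape [?]

prediction = self.sess.run(output, feed_dict={
    input1: len(batch), input2: batch, input3: batch_x_len,
    graph.get_tensor_by_name('prefix/Placeholder_2:0'): dummy_decoder_input,
    graph.get_tensor_by_name('prefix/Placeholder_3:0'): dummy_decoder_len})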

gogasca commented on August 30, 2024

@gifflarn It's possible that I don't really need to list all the output nodes; I'm just testing. I will continue working on it today. How did you freeze the model to .pb?

The only difference I see between train and test is the way run is executed and the parameters that are passed.

Train

train_feed_dict = {
            model.batch_size: len(batch_x),
            model.X: batch_x,
            model.X_len: batch_x_len,
            model.decoder_input: batch_decoder_input,
            model.decoder_len: batch_decoder_len,
            model.decoder_target: batch_decoder_output
        }

_, step, loss = sess.run([model.update, model.global_step, model.loss],
                                 feed_dict=train_feed_dict)

Test

valid_feed_dict = {
            model.batch_size: len(batch_x),
            model.X: batch_x,
            model.X_len: batch_x_len,
        }

prediction = sess.run(model.prediction,
                                feed_dict=valid_feed_dict)

gogasca commented on August 30, 2024

@gifflarn
I changed my script to use the SavedModelBuilder and now I can export and read the .pb successfully, but I'm still facing issues similar to the ones you described before:

Export to PB
https://gist.github.com/gogasca/305c14dea2ad342f163d3865e8576acd

Serving using .PB file
https://gist.github.com/gogasca/7d11b9cbb7f600fb3f4ecc026fa40929

Based on:
https://towardsdatascience.com/deploy-tensorflow-models-9813b5a705d5
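
For reference, a minimal SavedModelBuilder export along these lines would look roughly like the following (a generic sketch, not the exact gist; the export directory and the 'predict' signature key are placeholders, and the inputs/outputs mirror the test feed_dict shown earlier):

import tensorflow as tf

# Export the live session as a SavedModel (TF 1.x); './export/1' is an assumed path.
builder = tf.saved_model.builder.SavedModelBuilder('./export/1')
builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        'predict': tf.saved_model.signature_def_utils.predict_signature_def(
            inputs={'batch_size': model.batch_size,
                    'X': model.X,
                    'X_len': model.X_len},
            outputs={'prediction': model.prediction})})
builder.save()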

I get this error when I run the second script in the gist:

         valid_feed_dict = {
            batch_size: len(batch_x),
            X: batch_x,
            X_len: batch_x_len,
        }
        prediction = sess.run(transpose, feed_dict=valid_feed_dict)

Error

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_2' with dtype int32 and shape [?,15]
	 [[Node: Placeholder_2 = Placeholder[dtype=DT_INT32, shape=[?,15], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Any other suggestions?

I modified utils.py to read text instead of a file, and also replaced map and lambdas with list comprehensions to improve readability.

def build_prediction(text, word_dict, article_max_len):
    """Convert raw input text into padded word-id sequences for the model."""
    if not text:
        raise ValueError('Empty text')
    # Tokenize each input string.
    x = [word_tokenize(d) for d in _get_text_list(text)]
    # Map words to ids, falling back to the <unk> id for unknown words.
    x = [[word_dict.get(w, word_dict['<unk>']) for w in d] for d in x]
    # Truncate each article to the maximum article length.
    x = [d[:article_max_len] for d in x]
    # Pad each sequence up to article_max_len with the <padding> id.
    return [d + (article_max_len - len(d)) * [word_dict['<padding>']] for d in x]

gogasca commented on August 30, 2024

I did a slight modification of test.py and now I'm able to do API requests.

This is just a workaround, as I haven't solved the export to .pb issue.
Please take a look at the code below:

  1. Use the build_prediction function above.
  2. Replace map/lambdas with list comprehensions for readability.

"""Use a pre-trained model."""

import tensorflow as tf
import pickle

from model import Model
from utils import build_dict, build_prediction, batch_iter

with open('args.pickle', 'rb') as f:
    args = pickle.load(f)

print('Loading dictionary...')
word_dict, reversed_dict, article_max_len, summary_max_len = build_dict('test',
                                                                        args.toy)
print('Loading validation dataset...')

sess = tf.Session()
print('Loading saved model...')
model = Model(reversed_dict, article_max_len, summary_max_len, args,
              forward_only=True)
saver = tf.train.Saver(tf.global_variables())
checkpoint = tf.train.get_checkpoint_state('./saved_model/')
initialize = tf.global_variables_initializer()
sess.run(initialize)
saver.restore(sess, checkpoint.model_checkpoint_path)


def summarize(text):
    """

    Args:
        text: (List) A Text array. Example ['This is a long text']

    """
    valid_x = build_prediction(text, word_dict, article_max_len)
    valid_x_len = [len([y for y in x if y != 0]) for x in valid_x]
    batches = batch_iter(valid_x, [0] * len(valid_x), args.batch_size, 1)
    for batch_x, _ in batches:
        batch_x_len = [len([y for y in x if y != 0]) for x in batch_x]
        valid_feed_dict = {
            model.batch_size: len(batch_x),
            model.X: batch_x,
            model.X_len: batch_x_len,
        }
        prediction = sess.run(model.prediction, feed_dict=valid_feed_dict)
        prediction_output = [[reversed_dict[y] for y in x] for x in
                             prediction[:, 0, :]]
        for line in prediction_output:
            summary = []
            for word in line:
                if word == '</s>':
                    break
                if word not in summary:
                    summary.append(word)
            return ' '.join(summary)

This is the Flask server:

"""Server"""

import summarizer
from flask import Flask, request, Response, json

app = Flask(__name__)


@app.route('/')
def index():
    return Response('TensorFlow text summarizer')


@app.route('/summary', methods=['POST'])
def process_text():
    """Process text."""
    try:
        if request.is_json:
            content = request.json
            text = content.get('text')
            summary = summarizer.summarize([text])
            if summary:
                return app.response_class(
                    response=json.dumps(summary),
                    status=200,
                    mimetype='application/json'
                )
        return app.response_class(
            response=json.dumps('No JSON content found'),
            status=400,
            mimetype='application/json'
        )
    except Exception as exception:
        print('POST /summary error: %s' % exception)
        return Response(str(exception), status=500)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8081, debug=True)

API Request:

curl -H "Content-type: application/json" -X POST http://127.0.0.1:8081/summary -d '{"text": "australian foreign minister alexander downer called wednesday for the reform of the un security council and expressed support for brazil, india , japan and an african country to join the council ."}'

gifflarn commented on August 30, 2024

I too got the script to answer API calls, but it was very slow (~2 s per sentence), which is why I am trying to freeze it. Could you time your solution? Maybe freezing the graph is not necessary.

And to your earlier comment: I did feed dummy values to Placeholder_2 and Placeholder_3, but this yielded decimal values which could not be looked up in the dictionary, and when I floored the decimals I got really weird results.

gogasca commented on August 30, 2024

It is now running on a Mac Pro (~16 GB RAM / Intel Core i7).
It is taking ~1 s per request, which is not ideal. I would be happy with < 500 ms per request.

curl -o /dev/null -s -w 'Total: %{time_total}\n' -H "Content-type: application/json" -X POST http://127.0.0.1:8081/summary -d '{"text": "australian foreign minister alexander downer called wednesday for the reform of the un security council and expressed support for brazil, india , japan and an african country to join the council ."}'
Total: 1.094210

gifflarn commented on August 30, 2024

@gogasca Did you progress any further on this? I kind of put this project on the bench for now, but I really want it to work, so if you have any ideas, I'm willing to try.

gogasca commented on August 30, 2024

@gifflarn I'm resuming this project today; I need to present some results within the next two weeks. I will post updates on my progress.

gifflarn commented on August 30, 2024

@gogasca I had the idea of rewriting model.py and putting the decoder placeholders behind an if not forward_only: check, so they are not created at all during the testing phase. Do you think that might help?
Of course, that would mean retraining the model.
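
A rough sketch of that idea, with a heavily stripped-down constructor (the real Model in model.py takes more arguments and builds the actual network; the shapes and dtypes here are assumptions):

import tensorflow as tf

class Model(object):
    def __init__(self, article_max_len, summary_max_len, forward_only=False):
        # Inference-time inputs (always needed).
        self.batch_size = tf.placeholder(tf.int32, ())
        self.X = tf.placeholder(tf.int32, [None, article_max_len])
        self.X_len = tf.placeholder(tf.int32, [None])
        # Training-only inputs: skip them entirely when forward_only is True,
        # so a frozen or exported inference graph never references them.
        if not forward_only:
            self.decoder_input = tf.placeholder(tf.int32, [None, summary_max_len])
            self.decoder_len = tf.placeholder(tf.int32, [None])
            self.decoder_target = tf.placeholder(tf.int32, [None, summary_max_len])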
