Comments (6)

Zeleni9 commented on August 19, 2024

Hi, I found a solution. The middle part of the code block, starting at the "Fix nodes" comment, solves the problem.

import glob
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.gfile, tf.GraphDef)


def model(sess):
    """Load the frozen Gaze model into sess.graph and return its key tensors.

    Call as: sess = tf.Session(); eye, heatmaps, landmarks, radius = model(sess)
    """
    print("Trying to import Gaze Model.")
    model_dir = os.path.dirname(os.path.realpath(__file__)) + '/gaze'
    pb = glob.glob('%s/*.pb' % model_dir)[0]

    # Read graph definition
    with tf.gfile.FastGFile(pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Fix nodes of the frozen model: RefSwitch and AssignSub expect variable
    # (ref) inputs, which no longer exist once variables become constants.
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Export the fixed frozen model as a .pb file
    with tf.gfile.FastGFile('./gaze_better.pb', mode='wb') as model_fixed:
        model_fixed.write(graph_def.SerializeToString())

    # Import the graph into the session's graph
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')

    # Saving the important nodes of the Gaze model
    # Input nodes of the model
    frame_index = sess.graph.get_tensor_by_name('Webcam/fifo_queue_DequeueMany:0')
    eye = sess.graph.get_tensor_by_name('Webcam/fifo_queue_DequeueMany:1')
    eye_index = sess.graph.get_tensor_by_name('Webcam/fifo_queue_DequeueMany:2')

    # Output nodes of the model
    heatmaps = sess.graph.get_tensor_by_name('hourglass/hg_2/after/hmap/conv/BiasAdd:0')
    landmarks = sess.graph.get_tensor_by_name('upscale/mul:0')
    radius = sess.graph.get_tensor_by_name('radius/out/fc/BiasAdd:0')
    # No global_variables_initializer() here: a frozen graph has no variables.
    return eye, heatmaps, landmarks, radius
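The node fix can also be pulled out into a standalone function, which makes it easy to check on a toy graph. A minimal sketch, assuming nodes that expose `.op`, `.input` and `.attr` the way `tf.GraphDef().node` entries do; `fix_frozen_nodes` and the node names in the demo are hypothetical. The usual motivation for the rewrite is that freezing replaces variables with constants, so ops that expect a reference (`float_ref`) input, such as `RefSwitch` and `AssignSub`, typically fail on import.

```python
from types import SimpleNamespace


def fix_frozen_nodes(nodes):
    """Rewrite ref-type ops in a frozen graph's nodes, in place.

    Same rules as the loop in the comment above:
      RefSwitch -> Switch, with moving_* inputs redirected to their '/read' tensor
      AssignSub -> Sub, dropping the 'use_locking' attribute
    Returns the number of nodes that were rewritten.
    """
    rewritten = 0
    for node in nodes:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
            rewritten += 1
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            rewritten += 1
    return rewritten


# Toy nodes standing in for graph_def.node entries (names are hypothetical).
demo = [
    SimpleNamespace(op='RefSwitch', input=['bn/moving_mean', 'bn/cond/pred_id'], attr={}),
    SimpleNamespace(op='AssignSub', input=['bn/moving_variance', 'bn/delta'],
                    attr={'use_locking': False}),
    SimpleNamespace(op='Conv2D', input=['x', 'w'], attr={}),
]
print(fix_frozen_nodes(demo))  # 2
print(demo[0].op, demo[0].input)  # Switch ['bn/moving_mean/read', 'bn/cond/pred_id']
```

With a real graph the call would be `fix_frozen_nodes(graph_def.node)` before serializing or importing it.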

from gazeml.

Zeleni9 commented on August 19, 2024

Well, I added the commented-out part of the code from this answer to issue #23 (comment) at the end of the method def inference_generator(self) in https://github.com/swook/GazeML/blob/master/src/core/model.py. It is the code on lines 9-19 of https://github.com/parai/dms/blob/master/models/gaze.py.

The idea is that the model is exported when inference_generator is called; it is sufficient to save it once and then stop inference_generator.

I tried both the webcam and the video versions, so the input node names were 'Webcam/fifo_queue_DequeueMany' or 'Video/fifo_queue_DequeueMany'; the graph is exported with whichever name you put in the code. Hope this helps.
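Since the prefix of the input node follows the data source that was active when the graph was built, one way to avoid hard-coding 'Webcam/' or 'Video/' is to look the dequeue node up by name. A minimal sketch; `find_dequeue_nodes` and `output_tensor_names` are hypothetical helpers, and matching on the `_DequeueMany` suffix is an assumption based only on the node names quoted in this thread. Per the first comment, outputs 0, 1 and 2 of the webcam queue are `frame_index`, `eye` and `eye_index`.

```python
def find_dequeue_nodes(node_names):
    """Return node names that look like input-queue dequeue ops.

    Heuristic (an assumption, not a GazeML API): the queues quoted in this
    thread live under scopes such as 'Webcam/', 'Video/' or 'UnityEyes/',
    and their dequeue node name ends in '_DequeueMany'.
    """
    return [name for name in node_names if name.endswith('_DequeueMany')]


def output_tensor_names(node_name, count=3):
    """TensorFlow tensor names for the first `count` outputs: 'node:0', 'node:1', ..."""
    return ['%s:%d' % (node_name, i) for i in range(count)]


# In TensorFlow the names would come from e.g. [n.name for n in graph_def.node];
# this list is a hypothetical stand-in.
names = ['Video/fifo_queue_DequeueMany', 'upscale/mul', 'radius/out/fc/BiasAdd']
for queue in find_dequeue_nodes(names):
    print(output_tensor_names(queue))
```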

Zeleni9 commented on August 19, 2024

I don't know about OpenCV, but I used the command above to get the frozen model and load it in TensorFlow for inference.

with tf.gfile.FastGFile('./gaze_better.pb', mode='wb') as model_fixed:
    model_fixed.write(graph_def.SerializeToString())

funkfuzz commented on August 19, 2024

@parai I am getting the same issue. Did you find a solution?

funkfuzz commented on August 19, 2024

@Zeleni9 @parai has either of you managed to export a .pb model that can be successfully imported into OpenCV with readNetFromTensorflow()?

funkfuzz commented on August 19, 2024

@Zeleni9, how did you manage to export a frozen model with a node 'Webcam/fifo_queue_DequeueMany'?
When I try, by first loading the meta graph from the checkpoints and then using the code from @parai, TensorFlow tells me that there is no node named 'Webcam/fifo_queue_DequeueMany'.
However, if I use 'UnityEyes/random_shuffle_queue_DequeueMany', it exports fine.
Is it just a typo, or am I missing something?

Here is my code:

# This script freezes a trained checkpoint into a single .pb graph

import os
import tensorflow as tf
from tensorflow.python.framework import graph_util

# trained_checkpoint_prefix = 'checkpoints/dev'
trained_checkpoint_prefix = 'model-4672654'
export_dir = os.path.join('models', 'GazeML_010520') # IMPORTANT: each model folder must be named '0', '1', ... Otherwise it will fail!

# Initialize any variables that the checkpoint did not restore
def initialize_uninitialized(sess):
    global_vars          = tf.compat.v1.global_variables()
    is_not_initialized   = sess.run([tf.compat.v1.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]

    print ([str(i.name) for i in not_initialized_vars]) # only for testing
    if len(not_initialized_vars):
        sess.run(tf.compat.v1.variables_initializer(not_initialized_vars))

loaded_graph = tf.Graph()
with tf.compat.v1.Session(graph=loaded_graph) as sess:
    # Restore from checkpoint
    loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    loader.restore(sess, trained_checkpoint_prefix)
    initialize_uninitialized(sess)
    
    constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def,
            ['hourglass/hg_2/after/hmap/conv/BiasAdd', # heatmaps
             'upscale/Mean', # landmarks
             'radius/out/fc/BiasAdd', # radius
             'UnityEyes/random_shuffle_queue_DequeueMany', # frame_index, eye, eye_index
            ])
    # tf.compat.v1.gfile matches the tf.compat.v1 calls above (tf.gfile is gone in TF 2.x)
    with tf.compat.v1.gfile.FastGFile('./saved_GazeML.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

Any help would be much appreciated! :)
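Before calling `convert_variables_to_constants`, it can help to compare the requested output names with the nodes actually present in the restored meta graph; a requested name that is absent from the graph is the likely cause of the "no node named ..." error. A minimal sketch; `missing_output_nodes` is a hypothetical helper. One possible explanation, not confirmed in this thread, is that the meta graph stored with the checkpoint was built during training with the UnityEyes data source, so only that queue exists, whereas Zeleni9 exported from inside `inference_generator` while a webcam or video source was active.

```python
def missing_output_nodes(graph_node_names, requested):
    """Return requested output names that are not nodes in the graph.

    convert_variables_to_constants takes node names, so an output index such
    as ':0' is stripped before the lookup.
    """
    available = set(graph_node_names)
    missing = []
    for name in requested:
        node_name = name.split(':')[0]
        if node_name not in available:
            missing.append(node_name)
    return missing


# With a restored session this list would be [n.name for n in sess.graph_def.node];
# here it is a hypothetical stand-in for a graph built from the UnityEyes source.
graph_nodes = [
    'UnityEyes/random_shuffle_queue_DequeueMany',
    'hourglass/hg_2/after/hmap/conv/BiasAdd',
    'upscale/Mean',
    'radius/out/fc/BiasAdd',
]
print(missing_output_nodes(graph_nodes, ['Webcam/fifo_queue_DequeueMany', 'upscale/Mean:0']))
```

An empty result means every requested node exists and the freeze step can proceed.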
