
Comments (14)

ferreirafabio commented on May 16, 2024

Example:

def _build(self, input_op, num_processing_steps, is_training):
    latent = self._encoder(input_op, is_training)

--> fails with TypeError: _build() takes 2 positional arguments but 3 were given.

While _encoder is:

self._encoder = MLPGraphIndependent()

and MLPGraphIndependent is taken from the example and inherits from snt.AbstractModule.
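The error can be reproduced without TensorFlow. The sketch below assumes, as the traceback suggests, that the example's `_build` accepts only `inputs`. `Module`, `OneArgEncoder` and `FlagAwareEncoder` are hypothetical stand-ins written for illustration, not Sonnet or graph_nets classes.

```python
class Module:
    """Hypothetical stand-in for snt.AbstractModule: calling forwards to _build."""

    def __call__(self, *args, **kwargs):
        return self._build(*args, **kwargs)


class OneArgEncoder(Module):
    # Mirrors the assumed signature of the example module: no is_training.
    def _build(self, inputs):
        return inputs


class FlagAwareEncoder(Module):
    # Same module with the extra argument added to _build.
    def _build(self, inputs, is_training):
        return inputs, is_training


try:
    OneArgEncoder()("graph", True)  # self + 2 arguments = 3 positional
except TypeError as err:
    message = str(err)
    print(message)

# With the extended signature the same call goes through.
result = FlagAwareEncoder()("graph", True)
```

So one way out is to give the encoder's own `_build` an `is_training` parameter and forward it to whatever it wraps.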

from graph_nets.

vbapst commented on May 16, 2024


ferreirafabio commented on May 16, 2024

Thank you for getting back to me. Unfortunately, I care about both GraphIndependent and GraphNetwork.

Would something like this work? I construct operations for a train and test graph as follows:

# train
self.model.is_training = True
self.model.output_ops_train = self.model(self.model.input_ph, self.config.n_rollouts, self.model.is_training, self.sess)
# test
self.model.is_training = False
self.model.output_ops_test = self.model(self.model.input_ph_test, self.config.n_rollouts, self.model.is_training, self.sess)

I then assign the corresponding values to the is_training variable (created with tf.get_variable() in every encoder/decoder model) within the _build() function to modify the values in the TF graphs:

def _build(self, inputs, is_training, sess):
        out = self._network(inputs)

        # modify -is_training- flags accordingly
        with sess.as_default():
            for v in self._network.get_all_variables(collection=tf.GraphKeys.GLOBAL_VARIABLES):
                if "is_training" in v.name:
                    assign_op = v.assign(is_training)
                    sess.run(assign_op)
                    assert v.eval() == is_training

            # check if it is necessary to call _network(inputs) again
            variables = out[0].graph.get_collection("variables")
            for v in variables:
                if "is_training" in v.name:
                    assert v.eval() == is_training

        return out

As long as I maintain two different operation sets / GraphTuples lists (self.model.output_ops_test and self.model.output_ops_train) I believe this should work. What do you think?

Also, since the GraphTuples lists are quite large, how can I check whether these flags were set correctly?

Edited: your solution works well for GraphIndependent modules. I would still love to see how you check that it is actually setting the flags correctly.


vbapst commented on May 16, 2024

Hi Fabio,

I think the second solution I describe is exactly what you want, and it should work with both a GraphNetwork and a GraphIndependent?


ferreirafabio commented on May 16, 2024

outputs_test = inputs.replace(
    nodes=your_node_module(inputs.nodes, False),
    edges=your_edge_module(inputs.edges, False)
)
outputs_train = inputs.replace(
    nodes=your_node_module(inputs.nodes, True),
    edges=your_edge_module(inputs.edges, True)
)

--> Just to be clear, is this done in the _build() or the __init__() call?
Also, does this replace call need to be inside a with self._enter_variable_scope(): block, or is it enough if the module contains it?
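To make the `replace` pattern above concrete, here is a minimal sketch in plain Python. It assumes `replace` behaves like `namedtuple._replace`, that is, it returns a new tuple and leaves the input untouched. `ToyGraph`, `node_module` and `edge_module` are hypothetical stand-ins, not graph_nets or Sonnet objects.

```python
import collections

# Hypothetical stand-in for a GraphsTuple, reduced to two fields.
ToyGraph = collections.namedtuple("ToyGraph", ["nodes", "edges"])


def node_module(nodes, is_training):
    # Toy behaviour: training mode doubles the features, test mode keeps them.
    return [n * 2 if is_training else n for n in nodes]


def edge_module(edges, is_training):
    # Toy behaviour: training mode adds one, test mode keeps the values.
    return [e + 1 if is_training else e for e in edges]


inputs = ToyGraph(nodes=[1, 2], edges=[10])

# _replace returns a new tuple each time, so both outputs derive from the
# same unmodified `inputs` and the same two module functions.
outputs_test = inputs._replace(
    nodes=node_module(inputs.nodes, False),
    edges=edge_module(inputs.edges, False),
)
outputs_train = inputs._replace(
    nodes=node_module(inputs.nodes, True),
    edges=edge_module(inputs.edges, True),
)
```

The point of the pattern is that the flag is an ordinary call argument, so the train and test outputs differ only in the value passed, not in which modules are used.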


ferreirafabio commented on May 16, 2024

It is also unclear to me how I should call the replace function when the node module itself instantiates a class like this:

def __init__(self, model_id, is_training, name="a"):
    with self._enter_variable_scope():
            visual_encoder = get_model_from_config(self.model_id, model_type="visual_encoder")(is_training=self.is_training, name="visual_encoder")
            self._network = modules.GraphIndependent(
                ...
                nodes=lambda: get_model_from_config(self.model_id, model_type="visual_and_latent_encoder")(visual_encoder, name="visual_and_latent_node_encoder")
                ...
             )

How do I pass the flag in a replace statement in this case? A few more details, or a self-contained minimal example, would be helpful. Thank you.


ferreirafabio commented on May 16, 2024

@vbapst can you comment on this?


vbapst commented on May 16, 2024

If you go with the first option, then wherever you build your network (probably in your main training loop, or maybe in the __init__ of your module), you would construct outputs_test and outputs_train as described. You need to enter a variable scope only in the latter case. Note that in this case you don't call modules.GraphIndependent, but directly call replace on the nodes, edges and globals of your graph.

I am not sure I understand what your last comment is trying to achieve, as we don't usually pass the is_training flag at init time, but at build time. So it would look more like this:

def __init__(self):
  with tf.variable_scope("graph_modules"):
    self._edge_module = ..  # Define your module here. Their `_build` method should take an extra is_training argument
    self._node_module = ..
    self._global_module = ..

def _build(self, inputs, is_training):
  # Each *_model_fn is a zero-argument factory; the inner lambda forwards is_training.
  return modules.GraphNetwork(
      edge_model_fn=lambda: lambda x: self._edge_module(x, is_training=is_training),
      node_model_fn=lambda: lambda x: self._node_module(x, is_training=is_training),
      global_model_fn=lambda: lambda x: self._global_module(x, is_training=is_training),
  )(inputs)


ferreirafabio commented on May 16, 2024

Thank you for your reply @vbapst. I meant modules.GraphIndependent. To pass the is_training flag I now simply use the _build() functions of my models and keep it out of __init__(), like so:

def __init__(self, name="Encoder"):
    super(Encoder, self).__init__(name=name)

def _build(self, inputs, is_training, verbose=VERBOSITY):
    self._network = modules.GraphIndependent(
        edge_model_fn=lambda: ... (x, is_training=is_training),
        node_model_fn=lambda: ... (x, is_training=is_training),
        global_model_fn=lambda: ... (x, is_training=is_training),
    )
    return self._network(inputs)

My sanity check shows that the flags are set correctly. Does this also work in your opinion, or could it have negative consequences that I do not currently foresee?


vbapst commented on May 16, 2024

This looks good -- just check the variables are correctly shared


ferreirafabio commented on May 16, 2024

Thanks. Is there a way to check this within the GN framework?


ferreirafabio commented on May 16, 2024

@vbapst


vbapst commented on May 16, 2024

You can just list the variables with tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) and check that there are no duplicates (i.e. you should only see one edge_module, etc.).
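One way to turn that check into something mechanical is to snapshot the variable names after building the train ops and again after building the test ops, then diff the two lists. The helper below is a plain-Python sketch. In TF1 the names would presumably come from `[v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]`. The function name `unshared_variables` and the sample names (including the `_1` suffix) are made up for illustration.

```python
def unshared_variables(names_before, names_after):
    """Return names present after the second build but not after the first."""
    seen = set(names_before)
    return sorted(name for name in names_after if name not in seen)


# Hypothetical variable names, for illustration only.
after_train = ["Encoder/edge_model/w:0", "Encoder/node_model/w:0"]

# Case 1: building the test ops created nothing new, so variables are shared.
after_test_shared = list(after_train)

# Case 2: building the test ops created a second copy of every variable.
after_test_duplicated = after_train + [
    "Encoder_1/edge_model/w:0",
    "Encoder_1/node_model/w:0",
]

print(unshared_variables(after_train, after_test_shared))
print(unshared_variables(after_train, after_test_duplicated))
```

An empty result means the second build reused the existing variables. Any names in the result are variables that exist only for the test graph and so are not shared with training.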


ferreirafabio commented on May 16, 2024

thanks! works!

