Comments (14)
Example:
def _build(self, input_op, num_processing_steps, is_training):
    latent = self._encoder(input_op, is_training)
--> fails with TypeError: _build() takes 2 positional arguments but 3 were given.
Here, _encoder is:
self._encoder = MLPGraphIndependent()
and MLPGraphIndependent is taken from the example and inherits from snt.AbstractModule.
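The error makes sense once you see how Sonnet modules dispatch calls. Here is a plain-Python stand-in (not the real Sonnet API) showing why the TypeError appears: __call__ forwards its arguments to _build, so _build must declare every positional argument passed at call time.

```python
# Stand-in for snt.AbstractModule's call-forwarding behaviour.
class ModuleLike:
    def __call__(self, *args, **kwargs):
        # Sonnet's __call__ ultimately forwards to _build
        return self._build(*args, **kwargs)

class OneArgEncoder(ModuleLike):
    def _build(self, inputs):  # accepts only `inputs`
        return inputs

encoder = OneArgEncoder()
encoder([1, 2])  # fine: one positional argument

try:
    encoder([1, 2], True)  # extra is_training argument
except TypeError as err:
    print(err)  # _build() takes 2 positional arguments but 3 were given
```

The fix is to make the called module's _build accept the extra arguments you pass (here, is_training).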
Thank you for getting back to me. Unfortunately, I care about both GraphIndependent and GraphNetwork.
Would something like this work? I construct operations for a train and test graph as follows:
# train
self.model.is_training = True
self.model.output_ops_train = self.model(self.model.input_ph, self.config.n_rollouts, self.model.is_training, self.sess)
# test
self.model.is_training = False
self.model.output_ops_test = self.model(self.model.input_ph_test, self.config.n_rollouts, self.model.is_training, self.sess)
I would then assign the corresponding value to the is_training variable (created with tf.get_variable() in every encoder/decoder model) within the _build() function, to modify the values in the TF graphs:
def _build(self, inputs, is_training, sess):
    out = self._network(inputs)
    # modify -is_training- flags accordingly
    with sess.as_default():
        for v in self._network.get_all_variables(collection=tf.GraphKeys.GLOBAL_VARIABLES):
            if "is_training" in v.name:
                assign_op = v.assign(is_training)
                sess.run(assign_op)
                assert v.eval() == is_training
    # check if it is necessary to call _network(inputs) again
    variables = out[0].graph.get_collection("variables")
    for v in variables:
        if "is_training" in v.name:
            assert v.eval() == is_training
    return out
As long as I maintain two different operation sets / GraphsTuple lists (self.model.output_ops_test and self.model.output_ops_train), I believe this should work. What do you think?
Also, since the GraphsTuple lists are quite large, how can I check whether these flags were set correctly?
Edited: your solution works well for GraphIndependent modules. I would still love to see how you check whether it is actually setting the flags correctly.
Hi Fabio,
I think the second solution I describe is exactly what you want, and it should work with both GraphNetwork and GraphIndependent.
outputs_test = inputs.replace(
    nodes=your_node_module(inputs.nodes, False),
    edges=your_edge_module(inputs.edges, False)
)
outputs_train = inputs.replace(
    nodes=your_node_module(inputs.nodes, True),
    edges=your_edge_module(inputs.edges, True)
)
--> just to be clear, is this done in the _build() or the __init__() call?
Also, does this replace call need to be inside a with self._enter_variable_scope(): block, or is it enough if the module contains it?
It is further unclear to me how I should call the replace function when the node module itself instantiates a class like this:
def __init__(self, model_id, is_training, name="a"):
    with self._enter_variable_scope():
        visual_encoder = get_model_from_config(self.model_id, model_type="visual_encoder")(is_training=self.is_training, name="visual_encoder")
        self._network = modules.GraphIndependent(
            ...
            nodes=lambda: get_model_from_config(self.model_id, model_type="visual_and_latent_encoder")(visual_encoder, name="visual_and_latent_node_encoder")
            ...
        )
How do I pass the flag in a replace statement in this case? A little more detail would be helpful, or a self-contained minimal example. Thank you.
@vbapst can you comment on this?
If you go down the first option, then wherever you build your network (probably in your main training loop, or maybe in the __init__ of your module), you would construct outputs_test and outputs_train as described. You need to enter a variable scope only in the latter case. Note that in this case you don't call modules.GraphIndependent, but directly call replace on the nodes, edges and globals of your graph.
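To make the replace pattern concrete, here is a minimal plain-Python sketch. It uses a namedtuple stand-in rather than the real graph_nets.graphs.GraphsTuple (which is itself a namedtuple), and node_module is a hypothetical per-node model:

```python
from collections import namedtuple

# Stand-in for graph_nets.graphs.GraphsTuple: replace returns a new
# tuple with the given fields swapped, leaving the rest untouched.
GraphsTupleLike = namedtuple("GraphsTupleLike", ["nodes", "edges", "globals_"])

def node_module(nodes, is_training):
    # hypothetical per-node model whose behaviour depends on is_training
    scale = 2 if is_training else 1
    return [n * scale for n in nodes]

inputs = GraphsTupleLike(nodes=[1, 2], edges=[10], globals_=[0])

# Build both variants directly from the same inputs:
outputs_train = inputs._replace(nodes=node_module(inputs.nodes, True))
outputs_test = inputs._replace(nodes=node_module(inputs.nodes, False))

print(outputs_train.nodes)  # [2, 4]
print(outputs_test.nodes)   # [1, 2]
```

The untouched fields (edges, globals) are shared between the two outputs, which is why this is cheap even for large graphs.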
I am not sure I understand what your last comment is trying to achieve, as we don't usually pass the is_training flag at build time, but at init time. So it would look more like this:
def __init__(self):
    with tf.variable_scope("graph_modules"):
        self._edge_module = ..  # Define your module here. Its `_build` method should take an extra is_training argument
        self._node_module = ..
        self._global_module = ..

def _build(self, inputs, is_training):
    return modules.GraphNetwork(
        edge_model_fn=lambda: lambda x: self._edge_module(x, is_training=is_training),
        node_model_fn=lambda: lambda x: self._node_module(x, is_training=is_training),
        global_model_fn=lambda: lambda x: self._global_module(x, is_training=is_training)
    )(inputs)
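The double lambda (lambda: lambda x: ...) is easy to trip over. Here is a plain-Python sketch (stand-in functions, not the real graph_nets API) of why it is needed: the outer zero-argument callable is what the framework invokes to construct the model, and the inner callable it returns is what gets applied to the data, with is_training captured by closure.

```python
def make_edge_module():
    # stand-in for a module whose forward pass takes is_training
    def edge_module(x, is_training):
        return x * (0.5 if is_training else 1.0)
    return edge_module

edge_module = make_edge_module()

def build(inputs, is_training):
    # mirrors edge_model_fn=lambda: lambda x: self._edge_module(x, is_training=is_training)
    edge_model_fn = lambda: lambda x: edge_module(x, is_training=is_training)
    model = edge_model_fn()  # the framework calls the outer lambda once to construct the model
    return model(inputs)     # the inner callable is then applied to the data

print(build(4.0, True))   # 2.0
print(build(4.0, False))  # 4.0
```

Because is_training is closed over at _build time, calling _build twice (once per flag value) wires two graphs against the same underlying modules.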
Thank you for your reply @vbapst. I meant modules.GraphIndependent. For passing the is_training flag, I now simply used the _build() functions of my models and omitted the __init__() function, like so:
def __init__(self, name="Encoder"):
    super(Encoder, self).__init__(name=name)

def _build(self, inputs, is_training, verbose=VERBOSITY):
    self._network = modules.GraphIndependent(
        edge_model_fn=lambda: lambda x: ...(x, is_training=is_training),
        node_model_fn=lambda: lambda x: ...(x, is_training=is_training),
        global_model_fn=lambda: lambda x: ...(x, is_training=is_training)
    )
    return self._network(inputs)
My sanity check shows that the flags are set correctly. Does this also work in your opinion, or can it have negative consequences that I currently do not foresee?
This looks good -- just check that the variables are correctly shared.
Thanks. Is there a way to check this within the GN framework?
You can just list the variables with tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) and check that there are no duplicates (i.e. you should only see one edge_module, etc.).
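A minimal sketch of that check (plain Python; in real code the names would come from tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), and the variable names below are made up for illustration):

```python
from collections import Counter

# Hypothetical variable names as they might appear in the collection.
var_names = [
    "encoder/graph_independent/edge_model/mlp/w:0",
    "encoder/graph_independent/edge_model/mlp/b:0",
    "encoder/graph_independent/node_model/mlp/w:0",
]

# Variables are shared correctly if no name occurs more than once,
# i.e. each edge/node/global module was created in a single scope.
duplicates = [name for name, count in Counter(var_names).items() if count > 1]
assert not duplicates, "variables were created twice instead of being shared"
print("no duplicated variables")
```

If a module is accidentally rebuilt instead of reused, you would instead see the same module under two different scope prefixes, which this name-based check can also surface.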
thanks! works!