
Comments (9)

zsdonghao commented on May 12, 2024

@lucidfrontier45 Hi, your suggestion is good.

Currently, BatchNormLayer has an is_train argument, but DropoutLayer doesn't. However, if a model contains a BatchNormLayer, we need to build separate inference graphs for training and testing, as in the PTB example. In that case, we can use

if is_train:
    network = DropoutLayer(network, 0.8, name='xxx')

instead of putting is_train inside DropoutLayer. Alternatively, we can enable or disable the dropout layer through feed_dict; see the MNIST CNN example.
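
A rough illustration of the feed_dict route: in TensorLayer 1.x each placeholder-based DropoutLayer records its keep probability in network.all_drop, and the MNIST examples feed network.all_drop during training and tl.utils.dict_to_one(network.all_drop), which maps every keep probability to 1, at test time. The sketch below imitates that bookkeeping in plain Python, with strings standing in for placeholders; `forward`, `inverted_dropout` and the layer names are hypothetical, and `dict_to_one` here only mimics the TL helper.

```python
import random


def inverted_dropout(values, keep, rng):
    """Keep each value with probability `keep` and scale survivors by 1/keep.

    This matches what tf.nn.dropout does; with keep == 1.0 it is the identity.
    """
    return [v / keep if rng.random() < keep else 0.0 for v in values]


def dict_to_one(dp_dict):
    """Plain-Python stand-in for tl.utils.dict_to_one: every keep prob -> 1."""
    return {key: 1.0 for key in dp_dict}


# Hypothetical all_drop: placeholder name -> keep probability used in training.
all_drop = {"drop1/keep": 0.8, "drop2/keep": 0.5}

train_feed = dict(all_drop)        # training: dropout active
test_feed = dict_to_one(all_drop)  # testing: dropout disabled


def forward(values, feed, rng):
    # The same "graph" runs in both phases; only the fed keep probs differ.
    for name in ("drop1/keep", "drop2/keep"):
        values = inverted_dropout(values, feed[name], rng)
    return values


rng = random.Random(0)
print(forward([1.0, 2.0, 3.0, 4.0], test_feed, rng))  # [1.0, 2.0, 3.0, 4.0]
```

Because the graph is identical in both phases, only the feed changes; with every keep probability at 1.0 the dropout step passes values through untouched.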

Please let me know if you have any suggestions.

lucidfrontier45 commented on May 12, 2024

I think one way to unify the API is to add a new Dropout layer that takes an is_train argument.
See my test implementation below.

import tensorflow as tf
from tensorlayer.layers import Layer, set_keep

class Dropout(Layer):
    def __init__(self,
                layer = None,
                keep = 0.5,
                is_train = True,
                name = 'dropout_layer'):

        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        print("  tensorlayer:Instantiate Dropout %s: keep: %f" % (self.name, keep))

        # keep is a graph constant here; the train/test switch happens once,
        # at graph-construction time, through is_train.
        set_keep[name] = tf.constant(keep, dtype=tf.float32)
        if is_train:
            self.outputs = tf.nn.dropout(self.inputs, set_keep[name], name=name)
        else:
            # Inference: pass the inputs through unchanged.
            self.outputs = self.inputs

        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        self.all_drop.update( {set_keep[name]: keep} )
        self.all_layers.extend( [self.outputs] )
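
The Dropout(is_train=...) layer above decides the phase once, when the graph is built. Below is a minimal plain-Python analogue, assuming inverted dropout as in tf.nn.dropout (survivors are scaled by 1/keep); the class name `DropoutSketch` and its `seed` parameter are hypothetical.

```python
import random


class DropoutSketch:
    """Plain-Python analogue of a Dropout layer with an is_train switch.

    The train/test decision is fixed when the object is constructed, just as
    the `if is_train:` branch runs at graph-construction time.
    """

    def __init__(self, keep=0.5, is_train=True, seed=None):
        if not 0.0 < keep <= 1.0:
            raise ValueError("keep must be in (0, 1]")
        self.keep = keep
        self.is_train = is_train
        self.rng = random.Random(seed)

    def __call__(self, values):
        if not self.is_train:
            return list(values)  # identity, like self.outputs = self.inputs
        # Inverted dropout: zero with prob 1 - keep, scale survivors by 1/keep.
        return [v / self.keep if self.rng.random() < self.keep else 0.0
                for v in values]


x = [2.0, 4.0, 6.0]
print(DropoutSketch(keep=0.5, is_train=False)(x))  # [2.0, 4.0, 6.0]
train_out = DropoutSketch(keep=0.5, is_train=True, seed=1)(x)
# each element of train_out is either 0.0 or twice the input
```

Scaling by 1/keep during training keeps the expected activation unchanged, which is why the inference branch can simply pass inputs through without rescaling.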

zsdonghao commented on May 12, 2024

[NEW] FYI, the latest version of DropoutLayer has an is_fix setting; you can fix the keep probability by setting it to True.


Previous answer:

Would this be better?

import tensorflow as tf
from tensorlayer.layers import Layer, set_keep

class Dropout(Layer):
    def __init__(self,
                layer = None,
                keep = 0.5,
                is_fix = False,
                name = 'dropout_layer'):

        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        print("  tensorlayer:Instantiate Dropout %s: keep: %f" % (self.name, keep))

        if is_fix:
            # keep is baked into the graph as a Python float.
            self.outputs = tf.nn.dropout(self.inputs, keep, name=name)
        else:
            # keep is fed at run time through a placeholder.
            set_keep[name] = tf.placeholder(tf.float32)
            self.outputs = tf.nn.dropout(self.inputs, set_keep[name], name=name)

        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        if not is_fix:
            # Only placeholders need feeding, so only they go into all_drop.
            self.all_drop.update( {set_keep[name]: keep} )
        self.all_layers.extend( [self.outputs] )

wagamamaz commented on May 12, 2024

@lucidfrontier45 Does @zsdonghao's code work for you? If so, you can make a pull request.

lucidfrontier45 commented on May 12, 2024

@wagamamaz

Before discussing my code or @zsdonghao's, I want to clarify how batch normalization is meant to be used.

class tensorlayer.layers.BatchNormLayer(
        layer = None,
        decay = 0.999,
        epsilon = 0.00001,
        act = tf.identity,
        is_train = None,
        beta_init = tf.zeros_initializer,
        gamma_init = tf.ones_initializer,
        name ='batchnorm_layer')

BatchNormLayer accepts is_train as a constructor argument, so the training/test mode is fixed when the graph is built, not at run time.
I couldn't find any example of batch normalization except the DCGAN example. It builds two networks, one with is_train=True and the other with is_train=False. Is this the intended usage of BatchNormLayer?
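
To see why is_train matters for BatchNormLayer: in training mode, batch normalization normalizes with the current mini-batch statistics and updates moving averages; in inference mode it uses those moving averages instead, so the two modes compute different functions of the same input. Below is a minimal single-feature sketch of the standard algorithm, using the decay and epsilon defaults from the signature above; `BatchNormSketch` is hypothetical and is not TensorLayer's implementation.

```python
class BatchNormSketch:
    """Single-feature batch normalization with separate train/test behaviour.

    decay/epsilon defaults mirror the BatchNormLayer signature above; this is
    an illustrative sketch of the standard algorithm only.
    """

    def __init__(self, decay=0.999, epsilon=1e-5, gamma=1.0, beta=0.0):
        self.decay = decay
        self.epsilon = epsilon
        self.gamma = gamma      # scale (learned in a real layer)
        self.beta = beta        # shift (learned in a real layer)
        self.moving_mean = 0.0  # running estimates start at 0 and 1
        self.moving_var = 1.0

    def __call__(self, batch, is_train):
        if is_train:
            # Normalize with statistics of the current mini-batch ...
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ... and update the running averages used later at test time.
            d = self.decay
            self.moving_mean = d * self.moving_mean + (1 - d) * mean
            self.moving_var = d * self.moving_var + (1 - d) * var
        else:
            # Inference: use the accumulated estimates; nothing is updated.
            mean, var = self.moving_mean, self.moving_var
        scale = self.gamma / (var + self.epsilon) ** 0.5
        return [(x - mean) * scale + self.beta for x in batch]


bn = BatchNormSketch()
train_out = bn([1.0, 2.0, 3.0], is_train=True)  # roughly [-1.2247, 0.0, 1.2247]
test_out = bn([1.0, 2.0, 3.0], is_train=False)  # close to the raw inputs
```

Since the mode is passed per call here, one object serves both phases; when is_train is fixed at construction, as in BatchNormLayer, the same effect needs two networks that share parameters, which is what the DCGAN example does.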

If so, I think it's confusing that DropoutLayer and BatchNormLayer have different APIs for switching between the training and test phases, and they should be unified.

One way is my Dropout implementation, which accepts an is_train argument.
Is @zsdonghao's code meant for switching between the training and test phases?

zsdonghao commented on May 12, 2024

FYI, the latest version of DropoutLayer has an is_fix setting; you can fix the keep probability by setting it to True.

lucidfrontier45 commented on May 12, 2024

@zsdonghao

if is_train:
    network = DropoutLayer(network, 0.8, name='xxx')

This looks fine to me. Thank you.

zsdonghao commented on May 12, 2024

IMPORTANT

@lucidfrontier45 @wagamamaz The latest version of TL has an is_fix argument, so you can do the following:

if is_train:
    network = DropoutLayer(network, 0.8, is_fix=True, name='xxx')

quelle1 commented on May 12, 2024
network = Conv2d(net_in, df_dim, (k, k), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2), padding='SAME', W_init=w_init, name='h0/conv2d')
tf.summary.histogram('h0/conv2d',tf.get_collection(tf.GraphKeys.VARIABLES, 'h0/conv2d'))

How can I get the variables of the network? TensorBoard shows nothing.
