Comments (9)
@lucidfrontier45 Hi, your suggestion is good. Currently `BatchNormLayer` has an `is_train` argument, but `DropoutLayer` doesn't. However, if a model contains a `BatchNormLayer`, we need to build separate inferences for training and testing anyway, as in the PTB example. In that case, we can use

```python
if is_train:
    network = DropoutLayer(network, 0.8, name='xxx')
```

instead of putting `is_train` inside `DropoutLayer`. Alternatively, we can enable/disable the dropout layer by setting `feed_dict`; see the MNIST CNN example.
Please let me know if you have any suggestions.
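For comparison, the two switching styles can be sketched outside TensorFlow (a minimal numpy analogy; the function name and signature are mine, not TensorLayer's): deciding at build time via `is_train`, versus deciding at run time by changing the keep probability, which is what the `feed_dict` approach does by feeding `1.0` at test time.

```python
import numpy as np

def dropout(x, keep=0.5, is_train=True, rng=None):
    """Inverted dropout: keep each unit with probability `keep` and
    scale survivors by 1/keep, so the expected activation is unchanged.
    Passing is_train=False (the build-time switch) or keep=1.0 (the
    feed_dict-style run-time switch) both make it a no-op."""
    if not is_train or keep >= 1.0:
        return x
    rng = rng or np.random.default_rng(0)
    mask = rng.random(x.shape) < keep
    return x * mask / keep

x = np.ones((4, 3))
train_out = dropout(x, keep=0.8, is_train=True)   # units are 0 or 1/0.8
test_out = dropout(x, keep=0.8, is_train=False)   # identity
```

Both switches produce the same numbers; the API question in this thread is only which switch `DropoutLayer` should expose.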
from tensorlayer.
I think one way to unify the API is to add a new `Dropout` layer that receives an `is_train` argument. See my test implementation below.
```python
class Dropout(Layer):
    def __init__(self,
                 layer=None,
                 keep=0.5,
                 is_train=True,
                 name='dropout_layer'):
        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        print("  tensorlayer:Instantiate Dropout %s: keep: %f" % (self.name, keep))
        # set_keep is the module-level dict in tensorlayer.layers
        set_keep[name] = tf.constant(keep, dtype=tf.float32)
        if is_train:
            self.outputs = tf.nn.dropout(self.inputs, set_keep[name], name=name)
        else:
            # test phase: dropout is a pass-through
            self.outputs = self.inputs
        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        self.all_drop.update({set_keep[name]: keep})
        self.all_layers.extend([self.outputs])
```
[NEW] FYI, the latest version of `DropoutLayer` has an `is_fix` setting; you can fix the keeping probability by setting it to `True`.
Previous answer:
This may be better?
```python
class Dropout(Layer):
    def __init__(self,
                 layer=None,
                 keep=0.5,
                 is_fix=False,
                 name='dropout_layer'):
        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        print("  tensorlayer:Instantiate Dropout %s: keep: %f" % (self.name, keep))
        if is_fix:
            # fixed keeping probability, baked into the graph
            self.outputs = tf.nn.dropout(self.inputs, keep, name=name)
        else:
            # keeping probability fed at run time via feed_dict
            set_keep[name] = tf.placeholder(tf.float32)
            self.outputs = tf.nn.dropout(self.inputs, set_keep[name], name=name)
        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        if not is_fix:
            self.all_drop.update({set_keep[name]: keep})
        self.all_layers.extend([self.outputs])
```
@lucidfrontier45 Does @zsdonghao's code work for you? If yes, you can make a pull request.
Before talking about my code or @zsdonghao's, I want to make clear how to use batch normalization.

```python
class tensorlayer.layers.BatchNormLayer(
    layer=None,
    decay=0.999,
    epsilon=0.00001,
    act=tf.identity,
    is_train=None,
    beta_init=tf.zeros_initializer,
    gamma_init=tf.ones_initializer,
    name='batchnorm_layer')
```
`BatchNormLayer` accepts `is_train` as a constructor argument, so it is a build-time switch, not a run-time one. I couldn't find any example of batch normalization except the DCGAN example. It builds two nets, one with `is_train=True` passed and the other with `is_train=False`. Is this the intended usage of `BatchNormLayer`?
If so, I think it's confusing that `DropoutLayer` and `BatchNormLayer` have different APIs for switching between the training and test phases, and they should be unified. One way is my `Dropout` implementation that accepts an `is_train` argument.
Is @zsdonghao's code for switching the training/test phase?
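The unification being asked for can be sketched in plain numpy (a toy model, not TensorLayer code; all names here are mine): both layers expose the same `is_train` switch, and only their internals differ. Dropout becomes a pass-through at test time, while batch norm swaps batch statistics for running averages.

```python
import numpy as np

class BatchNorm:
    """Train: normalize with batch statistics and update the running
    averages. Test: normalize with the running averages instead."""
    def __init__(self, decay=0.999, epsilon=1e-5):
        self.decay, self.epsilon = decay, epsilon
        self.moving_mean, self.moving_var = 0.0, 1.0

    def __call__(self, x, is_train=True):
        if is_train:
            mean, var = x.mean(axis=0), x.var(axis=0)
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.epsilon)

class Dropout:
    def __init__(self, keep=0.5):
        self.keep = keep

    def __call__(self, x, is_train=True, rng=None):
        if not is_train:
            return x  # same switch as BatchNorm, different behavior
        rng = rng or np.random.default_rng(0)
        return x * (rng.random(x.shape) < self.keep) / self.keep
```

Under this convention a single `is_train` flag threads through the whole network at build time, which is exactly what the DCGAN example emulates by constructing two nets.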
FYI, the latest version of `DropoutLayer` has an `is_fix` setting; you can fix the keeping probability by setting it to `True`.
```python
if is_train:
    network = DropoutLayer(network, 0.8, name='xxx')
```

This looks fine to me. Thank you.
IMPORTANT
@lucidfrontier45 @wagamamaz The latest version of TL has an `is_fix` argument, so you can do the following:

```python
if is_train:
    network = DropoutLayer(network, 0.8, is_fix=True, name='xxx')
```
```python
network = Conv2d(net_in, df_dim, (k, k), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2),
                 padding='SAME', W_init=w_init, name='h0/conv2d')
tf.summary.histogram('h0/conv2d', tf.get_collection(tf.GraphKeys.VARIABLES, 'h0/conv2d'))
```

How do I get the variables of the network? TensorBoard shows nothing.