Comments (5)
Here is one idea of how this could be implemented:
######### in module.py
class ArgsOverride(Module):
    def __init__(self, base_module, **kwargs):
        self.base_module = base_module
        self.kwargs = kwargs

    def vars(self, scope=''):
        # The presence of the ArgsOverride module won't affect variable names
        return self.base_module.vars(scope)

    def __call__(self, *args, **kwargs):
        # In practice this should only override kwargs which are present
        # in the signature of base_module
        kwargs.update(self.kwargs)
        return self.base_module(*args, **kwargs)

######### in utils.py
def reset_args_override(module):
    # Removes ArgsOverride from module and all submodules
    …
######### in user code
model = Resnet50(nclasses=1000)
…
# Example: keep most of the network in training mode,
# but force a few modules into eval mode (training=False)
model.block_1.bn_1 = objax.ArgsOverride(model.block_1.bn_1, training=False)
model.block_2.bn_2 = objax.ArgsOverride(model.block_2.bn_2, training=False)
# Reset all args overrides
objax.utils.reset_args_override(model)
# Set the override on a different submodule
model.block_3.bn_3 = objax.ArgsOverride(model.block_3.bn_3, training=False)
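A minimal pure-Python sketch of the proposal, runnable without objax. It assumes submodules live in plain instance attributes; Module, BatchNorm, Block and Net are hypothetical stand-ins for objax classes, and this reset_args_override is just one possible way to fill in the elided helper, not objax's implementation. It also follows the "only override kwargs present in the signature" note by filtering with inspect.signature, so the unused=1 override below is silently dropped.

```python
import inspect


class Module:
    """Hypothetical stand-in for objax.Module."""


class ArgsOverride(Module):
    """Wraps a module and forces some keyword arguments on every call."""

    def __init__(self, base_module, **kwargs):
        self.base_module = base_module
        # Keep only overrides that base_module's __call__ can accept.
        params = inspect.signature(base_module.__call__).parameters
        takes_any = any(p.kind is p.VAR_KEYWORD for p in params.values())
        self.kwargs = {k: v for k, v in kwargs.items() if takes_any or k in params}

    def __call__(self, *args, **kwargs):
        kwargs.update(self.kwargs)  # forced values win over caller values
        return self.base_module(*args, **kwargs)


def reset_args_override(module):
    """Hypothetical helper: unwrap every ArgsOverride in-place, recursively."""
    for name, child in list(vars(module).items()):
        while isinstance(child, ArgsOverride):
            child = child.base_module
        if isinstance(child, Module):
            setattr(module, name, child)
            reset_args_override(child)
    return module


class BatchNorm(Module):
    def __call__(self, x, training):
        return ("train" if training else "eval", x)


class Block(Module):
    def __init__(self):
        self.bn_1 = BatchNorm()

    def __call__(self, x, training):
        return self.bn_1(x, training=training)


class Net(Module):
    def __init__(self):
        self.block_1 = Block()
        self.block_2 = Block()

    def __call__(self, x, training):
        return [self.block_1(x, training=training),
                self.block_2(x, training=training)]


model = Net()
model.block_1.bn_1 = ArgsOverride(model.block_1.bn_1, training=False, unused=1)
print(model(0, training=True))  # [('eval', 0), ('train', 0)]
reset_args_override(model)
print(model(0, training=True))  # [('train', 0), ('train', 0)]
```

Because the wrapper replaces the attribute in place, resetting only needs to walk attributes and unwrap; the wrapped module's own state is never copied.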
@rwightman do you have any feedback about this one?
@AlexeyKurakin That could work. As implemented, ArgsOverride is quite generic and could be used for any argument. Is there any other argument passed through the __call__ chain that you think one would want to override? If not, something like ForceNotTraining(Module) and ForceTraining(Module), without the need to specify kwargs, would be a bit clearer.
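The suggested ForceTraining / ForceNotTraining classes could be layered on top of the generic wrapper as thin subclasses, so callers never spell out kwargs. A pure-Python sketch under that assumption; ArgsOverride is re-declared here in simplified form so the snippet runs on its own, and toy_module is a hypothetical callable that returns 0 in training mode and x in eval mode so the forced flag is visible.

```python
class ArgsOverride:
    """Simplified stand-in: forces keyword arguments on every call."""

    def __init__(self, base_module, **kwargs):
        self.base_module = base_module
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        kwargs.update(self.kwargs)
        return self.base_module(*args, **kwargs)


class ForceTraining(ArgsOverride):
    """Always call the wrapped module with training=True."""

    def __init__(self, base_module):
        super().__init__(base_module, training=True)


class ForceNotTraining(ArgsOverride):
    """Always call the wrapped module with training=False."""

    def __init__(self, base_module):
        super().__init__(base_module, training=False)


def toy_module(x, training):
    # Hypothetical module whose output reveals which mode it ran in.
    return 0 if training else x


print(ForceNotTraining(toy_module)(5, training=True))  # 5
print(ForceTraining(toy_module)(5, training=False))    # 0
```

Since both classes are still ArgsOverride instances, a single reset helper that looks for ArgsOverride would remove them too.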
In the absence of additional functionality here, I was likely going to go the subclassing / alternate-module-implementation route: basically, create FrozenBatchNorm, EvalBatchNorm and EvalDropout style classes, plus helpers that walk the module hierarchy within a subset of the model and switch class types (copying state). But that would run into the checkpoint compatibility issues discussed.
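For comparison, the class-swapping route might look roughly like the pure-Python sketch below. BatchNorm, EvalBatchNorm, swap_modules and var_names are hypothetical names, and var_names is a toy that only imitates objax's "(ClassName).attr" naming style. Under that assumption it shows the checkpoint concern: the state is copied intact, but any variable name that embeds the class name changes after the swap.

```python
class Module:
    """Hypothetical stand-in for objax.Module."""


class BatchNorm(Module):
    def __init__(self):
        self.running_mean = 0.0  # state a checkpoint would store

    def __call__(self, x, training):
        return "train" if training else "eval"


class EvalBatchNorm(BatchNorm):
    """Same state as BatchNorm, but always runs in eval mode."""

    def __call__(self, x, training):
        return super().__call__(x, training=False)


class Block(Module):
    def __init__(self):
        self.bn = BatchNorm()


def swap_modules(module, old_cls, new_cls):
    """Recursively replace exact old_cls instances with new_cls, copying state."""
    for name, child in list(vars(module).items()):
        if type(child) is old_cls:
            replacement = new_cls.__new__(new_cls)  # skip __init__
            replacement.__dict__.update(child.__dict__)
            setattr(module, name, replacement)
        elif isinstance(child, Module):
            swap_modules(child, old_cls, new_cls)
    return module


def var_names(module, prefix=None):
    """Toy naming: '(Root).attr(ChildClass).var', built from attribute paths."""
    prefix = f"({type(module).__name__})" if prefix is None else prefix
    names = []
    for name, child in vars(module).items():
        if isinstance(child, Module):
            names += var_names(child, f"{prefix}.{name}({type(child).__name__})")
        else:
            names.append(f"{prefix}.{name}")
    return names


block = Block()
print(var_names(block))  # ['(Block).bn(BatchNorm).running_mean']
swap_modules(block, BatchNorm, EvalBatchNorm)
print(block.bn(0, training=True))  # eval
print(var_names(block))  # ['(Block).bn(EvalBatchNorm).running_mean']
```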
Right now I can't really think of arguments other than training, though overriding other arguments might be convenient for some non-standard modules.
Also, ArgsOverride enables syntax like the following:
# without ArgsOverride
predict = objax.Jit(lambda x: model(x, training=False), model.vars())
# with ArgsOverride
predict = objax.Jit(objax.ArgsOverride(model, training=False))
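The second form can drop the explicit model.vars() because the wrapper is itself a module that reports its base module's variables, while a bare lambda exposes none. A pure-Python sketch of that difference; Model and the dict returned by vars() are hypothetical stand-ins for an objax model and its VarCollection, and nothing is JIT-compiled here.

```python
class Model:
    """Hypothetical model with one 'variable' and a training-dependent call."""

    def __init__(self):
        self.w = 2.0

    def vars(self, scope=""):
        return {f"{scope}(Model).w": self.w}

    def __call__(self, x, training):
        return x * self.w if training else x * self.w + 1


class ArgsOverride:
    def __init__(self, base_module, **kwargs):
        self.base_module = base_module
        self.kwargs = kwargs

    def vars(self, scope=""):
        # Delegate so that wrapping does not change variable names.
        return self.base_module.vars(scope)

    def __call__(self, *args, **kwargs):
        kwargs.update(self.kwargs)
        return self.base_module(*args, **kwargs)


model = Model()
lam = lambda x: model(x, training=False)       # variables must be passed separately
wrapped = ArgsOverride(model, training=False)  # carries its own vars()
print(lam(3.0), wrapped(3.0))                  # 7.0 7.0
print(wrapped.vars() == model.vars())          # True
```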
@rwightman We eventually decided to use the name ForceArgs for this feature; as I mentioned above, it is somewhat more generic than simply forcing the training flag. The change is now merged into the repository and available for use.
@AlexeyKurakin thanks for the heads up, looks good