Comments (5)

AlexeyKurakin commented on July 28, 2024

Here is one idea of how this could be implemented:

######### in module.py

class ArgsOverride(Module):

    def __init__(self, base_module, **kwargs):
        self.base_module = base_module
        self.kwargs = kwargs

    def vars(self, scope=''):
        # The ArgsOverride wrapper does not affect variable names
        return self.base_module.vars(scope)

    def __call__(self, *args, **kwargs):
        # In practice, this should only override kwargs that are present
        # in the signature of base_module
        kwargs.update(self.kwargs)
        return self.base_module(*args, **kwargs)

######### in utils.py

def reset_args_override(module):
    # Removes ArgsOverride from module and all of its submodules
    ...

######### in user code

model = Resnet50(nclasses=1000)
…
# Example: keep most of the network in training mode,
# but force a few modules into eval mode (training=False)
model.block_1.bn_1 = objax.ArgsOverride(model.block_1.bn_1, training=False)
model.block_2.bn_2 = objax.ArgsOverride(model.block_2.bn_2, training=False)

# Reset all argument overrides
objax.utils.reset_args_override(model)

# Set an override on a different submodule
model.block_3.bn_3 = objax.ArgsOverride(model.block_3.bn_3, training=False)
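To make the proposal above concrete, here is a minimal, self-contained sketch in plain Python. It assumes a hypothetical stand-in `Module` base class in place of `objax.Module` (submodules are plain attributes, and `vars()` collects float attributes by dotted name), so nothing here is the objax API. It also implements the "only override kwargs in the signature of base_module" note with `inspect.signature`, and gives `reset_args_override` a body that unwraps overrides recursively. `BatchNorm`, `Dense`, and `Block` are toy modules for illustration only.

```python
import inspect


class Module:
    """Hypothetical stand-in for objax.Module: submodules are plain attributes."""

    def vars(self, scope=''):
        # Collect float attributes by dotted name, recursing into submodules.
        out = {}
        for name, value in self.__dict__.items():
            if isinstance(value, Module):
                out.update(value.vars(f'{scope}{name}.'))
            elif isinstance(value, float):
                out[scope + name] = value
        return out


class ArgsOverride(Module):
    """Wraps base_module and forces some of its keyword arguments."""

    def __init__(self, base_module, **kwargs):
        self.base_module = base_module
        self.kwargs = kwargs

    def vars(self, scope=''):
        # Delegate without adding a scope level, so variable names are unchanged.
        return self.base_module.vars(scope)

    def __call__(self, *args, **kwargs):
        # Only force kwargs that base_module's __call__ accepts by name.
        params = inspect.signature(self.base_module.__call__).parameters
        kwargs.update({k: v for k, v in self.kwargs.items() if k in params})
        return self.base_module(*args, **kwargs)


def reset_args_override(module):
    """Replace every ArgsOverride attribute in the tree with the module it wraps."""
    for name, value in list(module.__dict__.items()):
        if isinstance(value, Module):
            while isinstance(value, ArgsOverride):
                value = value.base_module
            setattr(module, name, value)
            reset_args_override(value)
    return module


class BatchNorm(Module):
    def __init__(self):
        self.momentum = 0.9

    def __call__(self, x, training=True):
        return (x, 'train' if training else 'eval')


class Dense(Module):
    def __call__(self, x):  # no training argument at all
        return x * 2.0


class Block(Module):
    def __init__(self):
        self.bn = BatchNorm()

    def __call__(self, x, training=True):
        return self.bn(x, training=training)


block = Block()
block.bn = ArgsOverride(block.bn, training=False)
print(block(1.0, training=True))                   # (1.0, 'eval')
print(sorted(block.vars()))                        # ['bn.momentum']
print(ArgsOverride(Dense(), training=False)(3.0))  # 6.0
reset_args_override(block)
print(block(1.0, training=True))                   # (1.0, 'train')
```

The signature check is what lets a single override such as `training=False` be applied to modules like `Dense` that never take that argument, without raising a `TypeError`.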

@rwightman do you have any feedback about this one?

rwightman commented on July 28, 2024

@AlexeyKurakin that could work... As implemented, ArgsOverride is quite generic and could be used for any arg. Is there any other arg passed through the __call__ chain that you think one would want to override? If not, something like ForceNotTraining(Module) or ForceTraining(Module), without the need to specify kwargs, would be a bit clearer.
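The specific names suggested here can be layered on top of the generic wrapper rather than replacing it. A minimal sketch, assuming a hypothetical plain-Python wrapper called `ForceArgsWrapper` (not an objax class) and a toy `dropout` function standing in for a layer; `ForceTraining` and `ForceNotTraining` simply pre-fill the `training` keyword:

```python
class ForceArgsWrapper:
    """Hypothetical generic wrapper: forces keyword arguments on every call."""

    def __init__(self, base_module, **forced):
        self.base_module = base_module
        self.forced = forced

    def __call__(self, *args, **kwargs):
        kwargs.update(self.forced)  # forced values win over caller values
        return self.base_module(*args, **kwargs)


class ForceTraining(ForceArgsWrapper):
    def __init__(self, base_module):
        super().__init__(base_module, training=True)


class ForceNotTraining(ForceArgsWrapper):
    def __init__(self, base_module):
        super().__init__(base_module, training=False)


def dropout(x, training=True):
    # Stand-in layer: report which mode it ran in.
    return 'train' if training else 'eval'


print(ForceNotTraining(dropout)(0.5, training=True))  # eval
print(ForceTraining(dropout)(0.5, training=False))    # train
```

With this layout the call sites stay readable for the common `training` case, while the generic form remains available for any other argument.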

In the absence of additional functionality here, I was likely going to go the subclassing / alternate module implementation route: basically, create FrozenBatchNorm, EvalBatchNorm, and EvalDropout style classes, plus helpers to walk the module hierarchy within a subset of the model and switch class types (copying state). But that would run into the checkpoint compatibility issues discussed.
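The subclass-and-swap route described above could look roughly like this. It is a plain-Python sketch with hypothetical `Module`, `BatchNorm`, `EvalBatchNorm`, `Block`, and `Model` classes (not objax classes): `swap_to_eval` walks the attributes of a chosen subtree and replaces each `BatchNorm` with an `EvalBatchNorm` carrying a copy of its state.

```python
class Module:
    """Hypothetical minimal base class; submodules are plain attributes."""


class BatchNorm(Module):
    def __init__(self, mean=0.0, var=1.0):
        self.mean = mean  # stand-in for running statistics
        self.var = var

    def __call__(self, x, training=True):
        return 'train' if training else 'eval'


class EvalBatchNorm(BatchNorm):
    """Same state as BatchNorm, but always runs in eval mode."""

    def __call__(self, x, training=True):
        return super().__call__(x, training=False)


class Block(Module):
    def __init__(self):
        self.bn_1 = BatchNorm(mean=0.5)
        self.bn_2 = BatchNorm()


class Model(Module):
    def __init__(self):
        self.block_1 = Block()
        self.block_2 = Block()


def swap_to_eval(module):
    """Recursively replace BatchNorm attributes with EvalBatchNorm, copying state."""
    for name, value in list(vars(module).items()):
        if type(value) is BatchNorm:
            frozen = EvalBatchNorm.__new__(EvalBatchNorm)  # skip __init__
            frozen.__dict__.update(value.__dict__)         # copy the state over
            setattr(module, name, frozen)
        elif isinstance(value, Module):
            swap_to_eval(value)
    return module


model = Model()
swap_to_eval(model.block_1)  # only a subset of the model is switched
print(type(model.block_1.bn_1).__name__, model.block_1.bn_1.mean)  # EvalBatchNorm 0.5
print(type(model.block_2.bn_1).__name__)                           # BatchNorm
```

The swapped object has a different class. Since objax variable names include the module's class name, a swap like this would presumably change the names stored in a checkpoint, which is one way the compatibility problem could show up.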

AlexeyKurakin commented on July 28, 2024

Right now I can't really think of other arguments besides training, though I guess overriding others might be convenient for some non-standard modules.

ArgsOverride also enables syntax like the following:

# without ArgsOverride
predict = objax.Jit(lambda x: model(x, training=False), model.vars())

# with ArgsOverride
predict = objax.Jit(objax.ArgsOverride(model, training=False))
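The practical difference between the two forms is that a lambda carries no variables, so `objax.Jit` has to be handed `model.vars()` explicitly, while the wrapper is itself a module that can report its own variables. A toy sketch of that idea, assuming a hypothetical `toy_jit` stand-in (it does no compilation; it only records the variable collection) and plain-Python `Model` and `ArgsOverride` classes rather than the objax ones:

```python
class Model:
    """Hypothetical stand-in module with one variable and a training flag."""

    def __init__(self):
        self.w = 2.0

    def vars(self):
        return {'w': self.w}

    def __call__(self, x, training=True):
        return (x * self.w, training)


class ArgsOverride:
    """Plain-Python version of the proposed wrapper."""

    def __init__(self, base_module, **kwargs):
        self.base_module = base_module
        self.kwargs = kwargs

    def vars(self):
        return self.base_module.vars()

    def __call__(self, *args, **kwargs):
        kwargs.update(self.kwargs)
        return self.base_module(*args, **kwargs)


def toy_jit(f, vc=None):
    """Hypothetical stand-in for objax.Jit: records which variables f uses."""
    if vc is None:
        vc = f.vars()  # only works when f is module-like and reports its own vars

    def compiled(*args):
        return f(*args)

    compiled.vc = vc
    return compiled


model = Model()
predict_a = toy_jit(lambda x: model(x, training=False), model.vars())
predict_b = toy_jit(ArgsOverride(model, training=False))
print(predict_a(3.0), predict_b(3.0))  # (6.0, False) (6.0, False)
print(predict_a.vc == predict_b.vc)    # True
```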

AlexeyKurakin commented on July 28, 2024

@rwightman We eventually decided to name this feature ForceArgs; as I mentioned above, it's somewhat more generic than simply forcing the training flag. The change is now merged into the repository and available to use.

rwightman commented on July 28, 2024

@AlexeyKurakin thanks for the heads up, looks good
