
Comments (8)

ydwu4 commented on July 23, 2024

Closing this for now. Feel free to re-open it when needed!

from pytorch.

JackCaoG commented on July 23, 2024

cc @angelayi @ydwu4


ManfeiBai commented on July 23, 2024

Thanks. The MNIST test code above passes for DispatchKey.XLA with self.bn1 = torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False), i.e. with the learnable parameters disabled.


ydwu4 commented on July 23, 2024

For a batch norm module, its states are lifted as inputs to the body_fn, and the functionalization pass creates a torch.ops.aten._native_batch_norm_legit_functional.default. This operator is non-mutating and can pass the mutating-input check. We can check this behavior by exporting the model:

ep = torch.export.export(mnist, (iteri, l_in_0, l_out))
ep.module().print_readable()

But it seems that F.batch_norm is turned into a different operator in your case. Can you show which operator the batch norm is turned into? Or is it because batch norm is decomposed somewhere?


ManfeiBai commented on July 23, 2024

Yes, sure.

With self.bn1 = torch.nn.BatchNorm2d(10), the code failed. Printed output: https://gist.github.com/ManfeiBai/c4672daa3ce35e7f5e2e1b1c6303561d

With self.bn1 = torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False), the code passed. Printed output: https://gist.github.com/ManfeiBai/0583b2b084a5cc276536718dce3571be

Comparing the two printouts, torch.nn.BatchNorm2d(10) has more args than torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False):

bn1_bias, bn1_num_batches_tracked, bn1_running_mean, bn1_running_var, bn1_weight, bn2_bias, bn2_num_batches_tracked, bn2_running_mean, bn2_running_var, bn2_weight

It also looks like torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False) made the BatchNorm2d layer's weight and bias static.

Are these args the learnable args mentioned in the [doc](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#:~:text=affine%20(bool,modes.%20Default%3A%20True)?
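The extra args can be traced back to the state a BatchNorm2d module holds. The short sketch below uses only the public named_parameters and named_buffers methods; the helper state_names is a hypothetical name introduced for illustration. It separates the learnable parameters (weight, bias) from the buffers (running_mean, running_var, num_batches_tracked), which are not learnable but are updated during training-mode forward passes.

```python
import torch

default_bn = torch.nn.BatchNorm2d(10)
stateless_bn = torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False)


def state_names(bn):
    """Names of the learnable parameters and of the buffers a module holds."""
    params = [name for name, _ in bn.named_parameters()]
    buffers = [name for name, _ in bn.named_buffers()]
    return params, buffers


default_params, default_buffers = state_names(default_bn)
stateless_params, stateless_buffers = state_names(stateless_bn)

print(default_params, default_buffers)
print(stateless_params, stateless_buffers)
```

With both flags set to False the module holds neither parameters nor buffers, which matches the shorter arg list in the passing printout.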


ManfeiBai commented on July 23, 2024

By the way, @ydwu4, do we expect torch.nn.BatchNorm2d to be non-mutating when we use while_loop?

And if I understand correctly, for mutating ops, would while_loop call a pure Python while loop? Or do we suggest that users use scan for mutating ops?


ydwu4 commented on July 23, 2024

> Are these args the learnable args mentioned

Yeah, they're learnable in training. Is the mnist demo for inference only? If so, setting affine=False, track_running_stats=False may not be the proper way to unblock. Can you try self.bn2 = torch.nn.BatchNorm2d(20).eval()? This is better than setting the two flags.
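As a rough illustration of why .eval() helps for inference, the sketch below runs one forward pass through a training-mode layer and an eval-mode layer and compares their buffers afterwards. The input shape is an arbitrary assumption chosen for the example.

```python
import torch

x = torch.randn(4, 20, 8, 8)  # arbitrary example batch

train_bn = torch.nn.BatchNorm2d(20)        # modules start in training mode
eval_bn = torch.nn.BatchNorm2d(20).eval()  # inference mode

with torch.no_grad():
    train_bn(x)
    eval_bn(x)

# Training mode updates the running statistics in place on every forward.
train_updates = int(train_bn.num_batches_tracked)
# Eval mode only reads the running statistics and leaves them untouched.
eval_updates = int(eval_bn.num_batches_tracked)
eval_mean_untouched = torch.equal(eval_bn.running_mean, torch.zeros(20))

print(train_updates, eval_updates, eval_mean_untouched)
```

Unlike the two-flag approach, the eval-mode layer keeps its weight, bias, and running statistics, so a trained checkpoint would still normalize with the statistics it learned.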

> for mutating ops, would while_loop call a pure Python while loop?

Generally speaking, if there isn't a functionalization key, while_loop goes here. All mutating ops are allowed.

But in your case, I think there's a functionalize transformation invoked somewhere in the stack, which triggers the functionalization key implementation of while_loop. If mutating ops mutate the inputs, i.e. produce side effects, while_loop detects them in the functionalization key implementation and raises the error, because, by definition, a functionalized function shouldn't have side effects. while_loop doesn't handle this case currently, so the operators in the user program must not mutate their inputs. In your case, we should set the batch_norm layer to be non-mutating.
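The difference between a body that mutates its inputs and one that does not can be sketched in plain Python. Note the assumptions: python_while_loop is a hypothetical reference loop written for this example, not the actual while_loop implementation, and the two body functions are made up. It only illustrates the side effect that a functionalized while_loop cannot tolerate.

```python
import torch


def python_while_loop(cond_fn, body_fn, carried_inputs):
    """Hypothetical plain-Python reference loop over carried values."""
    carried = tuple(carried_inputs)
    while cond_fn(*carried):
        carried = tuple(body_fn(*carried))
    return carried


def cond_fn(i, x):
    return bool(i < 3)


def mutating_body(i, x):
    x.add_(1)  # in-place: writes into the tensor the caller passed in
    return i + 1, x


def pure_body(i, x):
    return i + 1, x + 1  # out-of-place: returns a fresh tensor


x_mut = torch.zeros(2)
python_while_loop(cond_fn, mutating_body, (torch.tensor(0), x_mut))

x_pure = torch.zeros(2)
_, result = python_while_loop(cond_fn, pure_body, (torch.tensor(0), x_pure))

# x_mut was changed behind the caller's back; x_pure was not.
print(x_mut, x_pure, result)
```

Both bodies compute the same result, but only pure_body leaves the caller's input unchanged, which is the property the batch norm layer needs inside body_fn.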


ManfeiBai commented on July 23, 2024

> Is the mnist demo for inference only?

Yes, this mnist demo is for inference only.

> Can you try self.bn2 = torch.nn.BatchNorm2d(20).eval()?

Yeah, sure. I tested it and it passed locally too.

Thanks for the detailed info, Yidi. For now, I will use affine=False, track_running_stats=False to make the batch_norm layer non-mutating in this inference test case.

cc @JackCaoG

