Comments (8)
Closing this for now. Feel free to re-open it when needed!
Thanks! The above MNIST test code for DispatchKey.XLA passed with self.bn1 = torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False), i.e. by disabling the learnable parameters.
For a batch norm module, its states are lifted as inputs to the body_fn, and we create a torch.ops.aten._native_batch_norm_legit_functional.default in the functionalization pass. This operator is non-mutating and passes the mutating-input check. We can check this behavior by exporting the model:
ep = torch.export.export(mnist, (iteri, l_in_0, l_out))
ep.module().print_readable()
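For context, a self-contained sketch of that export check on a toy module (the module, input shapes, and names below are stand-ins, not the actual MNIST model from this issue):

```python
import torch

class TinyBN(torch.nn.Module):
    # stand-in module: a single BatchNorm2d layer with default arguments
    def __init__(self):
        super().__init__()
        self.bn1 = torch.nn.BatchNorm2d(10)

    def forward(self, x):
        return self.bn1(x)

m = TinyBN()
ep = torch.export.export(m, (torch.randn(2, 10, 8, 8),))
# The readable printout shows which aten op batch norm was lowered to,
# e.g. torch.ops.aten._native_batch_norm_legit_functional.default in training mode.
ep.module().print_readable()
```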
But it seems that F.batch_norm is turned into a different operator in your case. Can you show which operator the batch norm is turned into? Or is it because batch norm is decomposed into something else?
Yes, sure:
with self.bn1 = torch.nn.BatchNorm2d(10): the code failed; output: https://gist.github.com/ManfeiBai/c4672daa3ce35e7f5e2e1b1c6303561d
with self.bn1 = torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False): the code passed; output: https://gist.github.com/ManfeiBai/0583b2b084a5cc276536718dce3571be
Comparing the two printouts, torch.nn.BatchNorm2d(10) has more args than torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False):
bn1_bias, bn1_num_batches_tracked, bn1_running_mean, bn1_running_var, bn1_weight, bn2_bias, bn2_num_batches_tracked, bn2_running_mean, bn2_running_var, bn2_weight
and it looks like torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False) moved the BatchNorm2d layer's weight and bias to static too.
Are these args the learnable args mentioned in the [doc](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#:~:text=affine%20(bool,modes.%20Default%3A%20True)?
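For reference, a quick sketch (just the two layer configurations in isolation, not the full model) showing what each variant registers as learnable parameters vs. buffers:

```python
import torch

bn_default = torch.nn.BatchNorm2d(10)
bn_static = torch.nn.BatchNorm2d(10, affine=False, track_running_stats=False)

# affine=True registers the learnable weight/bias parameters;
# track_running_stats=True registers the running_mean/running_var/num_batches_tracked buffers.
print(list(dict(bn_default.named_parameters())))  # ['weight', 'bias']
print(list(dict(bn_default.named_buffers())))     # ['running_mean', 'running_var', 'num_batches_tracked']
print(list(dict(bn_static.named_parameters())))   # []
print(list(dict(bn_static.named_buffers())))      # []
```

The extra bn1_*/bn2_* args in the failing printout line up with exactly these per-layer parameters and buffers.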
btw, @ydwu4, do we expect torch.nn.BatchNorm2d to be non-mutating when we use while_loop?
And, iiuc, for mutating ops, would while_loop fall back to a pure Python while loop? Or do we suggest users use scan for mutating ops?
> are these args the learnable args mentioned in the doc?

Yeah, they're learnable in training. Is the MNIST demo for inference only? If it's only for inference, setting affine=False, track_running_stats=False may not be the proper approach to unblock. Can you try self.bn2 = torch.nn.BatchNorm2d(20).eval()? That is better than setting the two flags.
> for mutating ops, would while_loop fall back to a pure Python while loop?

Generally speaking, if there isn't a functionalization key, while_loop goes here, and all mutating ops are allowed.
But in your case, I think there's a functionalize transformation invoked somewhere in the stack, which triggers the functionalization-key implementation of while_loop. For mutating ops, if they mutate the inputs, i.e. produce side effects, while_loop will detect them under the functionalization key and raise the error, because, by definition, a functionalized function shouldn't have side effects. while_loop doesn't handle this case currently, so we have to make the operators non-mutating in the user program. In your case, we should set the batch_norm layer to be non-mutating.
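A minimal sketch of the eval-mode approach (the module below is a placeholder, not the real MNIST model; the point is just that .eval() keeps the default BatchNorm2d arguments but stops the running-stat updates, so the layer no longer mutates its buffers):

```python
import torch

class MNISTLike(torch.nn.Module):
    # placeholder module standing in for the real MNIST model
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.bn1 = torch.nn.BatchNorm2d(10)

    def forward(self, x):
        return self.bn1(self.conv1(x))

model = MNISTLike()
model.eval()  # switches every BatchNorm2d to eval mode: running stats are read, not updated
# or, per layer, as suggested above:
# self.bn2 = torch.nn.BatchNorm2d(20).eval()
```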
> Is the MNIST demo for inference only?

Yes, this MNIST demo is for inference only.

> Can you try self.bn2 = torch.nn.BatchNorm2d(20).eval()?

Yeah, sure, tested it and it passed locally too.
Thanks for the detailed info, Yidi. I would use affine=False, track_running_stats=False to set the batch_norm layer to be non-mutating in this inference test case for now.
cc @JackCaoG