Comments (24)
cc @shazqadeer
from pytorch.
Is this also related to these recompilations, or not?
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles] Recompiling function forward in /workspace/networks/encoders/swin/swin_transformer.py:425
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles] triggered by the following guard failure(s):
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles] - Eq(IntTrueDiv(L['H'], 7), 19.4285714285714) # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles] - L['H'] == 272
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles] - Eq(IntTrueDiv(L['H'], 7), 9.71428571428571) # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles] - Eq(IntTrueDiv(L['H'], 7), 19.4285714285714) # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles] - Eq(IntTrueDiv(L['H'], 7), 9.71428571428571) # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles] - Eq(IntTrueDiv(L['H'], 7), 19.4285714285714) # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles] Recompiling function torch_dynamo_resume_in_forward_at_433 in /workspace/networks/encoders/swin/swin_transformer.py:433
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles] triggered by the following guard failure(s):
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles] - Eq(IntTrueDiv(L['W'], 7), 19.4285714285714) # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles] - L['W'] == 272
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles] - Eq(IntTrueDiv(L['W'], 7), 9.71428571428571) # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles] - Eq(IntTrueDiv(L['W'], 7), 19.4285714285714) # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles] - Eq(IntTrueDiv(L['W'], 7), 9.71428571428571) # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles] - Eq(IntTrueDiv(L['W'], 7), 19.4285714285714) # _dynamo/output_graph.py:451 in init_ambient_guards
To see why it's recompiling, run it with TORCH_LOGS=explain.

Do we have explain?

Valid settings:
all, dynamo, aot, autograd, inductor, dynamic, torch, distributed, c10d, ddp, pp, fsdp, onnx, export, compiled_autograd_verbose, trace_call, bytecode, aot_joint_graph, fusion, custom_format_test_artifact, output_code, trace_source, aot_graphs, recompiles, onnx_diagnostics, not_implemented, verbose_guards, graph, graph_sizes, kernel_code, sym_node, graph_code, ddp_graphs, graph_breaks, post_grad_graphs, cudagraphs, perf_hints, trace_bytecode, guards, compiled_autograd, schedule, recompiles_verbose, overlap
Sorry, I meant TORCH_LOGS=recompiles, but that's the one you have used.
In your logs, it's clear that H / 7 != 9.71428571428571 and H / 7 != 19.4285714285714 for H = 272. Same for W, so it sounds reasonable that it's recompiling (and probably tracking it with a symint to not recompile later), no?
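As a sanity check, the failing guards can be reproduced with plain arithmetic (a standalone sketch, not PyTorch code; the observation that the two decimals print as 68/7 and 136/7 is my own reading of the logs):

```python
from fractions import Fraction

# Guards from earlier compilations pinned H / 7 to exact rational values;
# 9.71428571428571 and 19.4285714285714 are 68/7 and 136/7 rounded.
H = 272
assert Fraction(H, 7) != Fraction(68, 7)   # guard H / 7 == 9.714... fails
assert Fraction(H, 7) != Fraction(136, 7)  # guard H / 7 == 19.428... fails
print(H / 7)  # a third distinct value, so Dynamo must recompile
```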
Probably, but I am trying to investigate why the compilation never ends, because I have only this message and this other one at #127677 (comment).
OK, here's another interesting application of reasoning about rationals.
Suppose we have s0 in [4, 10] and s0 = s1 * 2 (but s1 is otherwise unbounded). We would like to replace s0 with s1 * 2, but if we do so naively, we will lose knowledge that s0 in [4, 10]. We would like to update the bounds on s1. One way to do this is solve for s1 in terms of s0. However, the solution for this equation is s1 = s0 / 2: rationals! In #126905 I made these solutions illegal, and as a result also made it impossible to do replacements when the thing being replaced had min/max bounds.
The fix is to temporarily allow for rational compute, and then when you're all done requantize back to integers. So maybe this is a good reason to allow for rationals.
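A toy version of that workflow using Python's exact rationals (my own sketch; `refine_bounds` is a made-up name, not ShapeEnv API):

```python
import math
from fractions import Fraction

def refine_bounds(lo, hi, coeff):
    """Given s0 in [lo, hi] and s0 = coeff * s1, solve s1 = s0 / coeff
    with exact rational arithmetic, then requantize the rational bounds
    back to the tightest enclosing integer range for s1."""
    r_lo, r_hi = Fraction(lo, coeff), Fraction(hi, coeff)
    return math.ceil(r_lo), math.floor(r_hi)

# s0 in [4, 10] with s0 = s1 * 2 refines s1 to [2, 5]
print(refine_bounds(4, 10, 2))  # (2, 5)
```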
@ezyang : Your example above is interesting. If you did allow for reasoning with rationals, would the upper and lower bounds of value ranges for ints become rational numbers?
@ezyang : I believe I have your permission to ask random questions on your PRs and issues to fill background gaps. If not, ignore this message.
The PR summary refers to an "offline solver". Can you provide some context on the need for this offline solver and how it works?
If you did allow for reasoning with rationals, would the upper and lower bounds of value ranges for ints become rational numbers?
You can end up computing a rational upper/lower bound, but it's always OK to round them: if x is in [1/2, 3/2] and x is an integer, then it can only ever be 1.
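The [1/2, 3/2] case, spelled out (sketch; rounding inward, with ceil on the lower bound and floor on the upper, is always sound for an integer variable):

```python
import math
from fractions import Fraction

# x is an integer constrained to the rational range [1/2, 3/2];
# rounding the bounds inward recovers the only possible value, 1.
lo, hi = Fraction(1, 2), Fraction(3, 2)
int_lo, int_hi = math.ceil(lo), math.floor(hi)
print(int_lo, int_hi)  # 1 1
```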
The PR summary refers to an "offline solver". Can you provide some context on the need for this offline solver and how it works?
The offline solver is all the code in DimConstraints. It's offline because we assume we've collected all of the guards and then can solve them all together. The online solver in ShapeEnv has to be able to answer queries as we go, since symbolic evaluation cannot proceed without answers.
Suppose we have s0 in [4, 10] and s0 = s1 * 2 (but s1 is otherwise unbounded). [...]

I mean, s1 = s0 / 2 is a rational, but you can represent it with your SymPy functions, which is completely fine, right?
I was going to say "but the division is generated by the sympy solver", but we already reimplemented the solver, so I can fix it.
But I'm not too sure what the rounding behavior should be if I turn this into an integer division. Let's suppose I have s0 in [3, 11] and s0 = s1 * 2. The desired refined range for s1 is [2, 5]. If I solve s1 = s0 // 2, I will get a less accurate s1 in [1, 5]. I need distinct rounding behavior for the upper and lower bound.
To get tight bounds you want to map [lower, upper] into [CeilDiv(lower, div), FloorDiv(upper, div)].
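On the earlier example (s0 in [3, 11], div = 2), the two rounding directions give exactly the desired [2, 5] (sketch; `ceildiv` here is a stand-in for CeilDiv, not a sympy import):

```python
def ceildiv(a, b):
    # Ceiling division expressed via floor division (b > 0 assumed).
    return -((-a) // b)

lo, hi, div = 3, 11, 2
print(ceildiv(lo, div), hi // div)  # 2 5
```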
Actually, that's not quite true: you first need to normalise the range by multiplying it by the sign of div (i.e. potentially flip it) so that div is positive, but yeah. If div is itself a range, you first make sure that 0 is not in that range, and then compute as above.
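A sketch of the full rule with the sign normalisation folded in (hypothetical helper, not ValueRanges code; a div range containing 0 would be rejected before this point):

```python
def refine_range(lo, hi, div):
    """Tight integer range for s1, given s0 in [lo, hi] and s0 = s1 * div."""
    if div == 0:
        raise ValueError("div must be nonzero")
    if div < 0:
        # Multiply the range by the sign of div, flipping the endpoints,
        # so that div is positive for the ceil/floor step below.
        lo, hi, div = -hi, -lo, -div
    return -((-lo) // div), hi // div  # (CeilDiv(lo, div), FloorDiv(hi, div))

print(refine_range(3, 11, 2))   # (2, 5)
print(refine_range(3, 11, -2))  # (-5, -2)
```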
My point is that, the way the reverse propagation is currently written, there's no opportunity to directly remap the ValueRange. We have equation s0 = s1 * 2, we solve s1 = s0 / 2, and then we run value ranges prop on the new expression. Concretely, to do what you suggest, I need to introduce some new weird div operator s1 = MyDiv(s0, 2) that has the correct value ranges formula you described and... I'm not even sure what its runtime semantics are, maybe this is CleanDiv (where we guarantee that the output is integral?) Or are you suggesting that we stop using try_solve for reverse propagation and do it some other way?
So, there are a few possibilities here. The first one, which should be quite uncontroversial, is to check that lower % div == 0 and upper % div == 0. In this case, every division agrees, and you can say s1 = FloorDiv(s0, div) and compute rg_s1 = cls.floordiv(rg_s0, div) (or use TruncDiv if you know that you are going to generate this in a kernel with C semantics). I would assume that most uses would fall into this case.
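That exact-division precondition can be sketched as a guard before substituting FloorDiv (illustrative only; the real check would consult the ValueRanges of s0):

```python
def try_exact_floordiv(lo, hi, div):
    """If both endpoints divide evenly (div > 0 assumed), floor, trunc,
    and true division all agree, so FloorDiv propagation on the
    endpoints is exact."""
    if lo % div == 0 and hi % div == 0:
        return lo // div, hi // div
    return None  # fall back to the general (coercion) case

print(try_exact_floordiv(4, 10, 2))  # (2, 5)
print(try_exact_floordiv(3, 11, 2))  # None
```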
Then, for the general case, we could have an op handler that coerces a float range into an int range as described above. We could write this as Coerce(FloatTrueDiv(s0, 2)). Then, after the solve, we'd need to propagate the ranges from s1 to s0 using the initial equation s0 = s1 * 2, and finally turn Coerce o FloatTrueDiv into FloorDiv (or TruncDiv), now that we have reduced the problem to the previous one.
While I agree your first proposal is sound, it feels very edge-casey to me. It also changes some of our layering: right now our try_solve is value-range oblivious, but under your proposal we would need to consult the value ranges to determine whether we could do a replacement.

For the general solution, there are problems. First, I don't want to use a FloatTrueDiv, because s0 and 2 are not floats, they are integers, and to satisfy typing I'd have to coerce them to floats too. And now this looks very suspicious: why am I doing floating-point division for the reasoning here? In particular, I can now torture you with something like: what if my integer is outside the exactly representable floating-point range? Like, I know this is not going to happen in practice, but if we just want something that works, doing rationals works too!
I mean, I'm team rationals, so you don't have to convince me on the "rationals work" front :P So what about implementing the general solution but using sympy.Div instead of FloatTrueDiv? We'd temporarily have a Div, but we'd turn it straight away into one of our divs, so life is good?
But you're also on team asserts, and I'm giving those up here. The rational logic with the sympy.Div solution is what has landed to main. This works great... I just can't assert that stuff is_integer now, because it might not be. In particular, the asserts cannot live in eval, because even if the div is temporary, eval gets run immediately when I create the expression.
Yeah, eval running eagerly and not having any sort of global information means that it may not be the best place to put the asserts. We might want to have a pass that simplifies the expressions using global information (VRs and otherwise) and then, after that, asserts that the expressions satisfy certain postconditions.
I'm not sure whether we have a place like that in the code at the moment, but it would certainly be desirable.
A simple place we can do asserts is in guard accumulation, since this is where expressions persist. Otherwise, we can do it on SymNode creation. But I am worried about the latter as the invariant test is recursive so you will keep iterating into the same structs over and over again.
The PR summary refers to an "offline solver". Can you provide some context on the need for this offline solver and how it works?
The offline solver is all the code in DimConstraints. It's offline because we assume we've collected all of the guards and then can solve them all together. The online solver in ShapeEnv has to be able to answer queries as we go, since symbolic evaluation cannot proceed without answers.
Thanks. Why is an offline solver needed?
Export has a feature where, if you specify dimensions as dynamic but they turn out not to be fully dynamic due to guards, it will try to compute the simplest set of constraints that precisely describes the extra restrictions your guards have imposed. Also cc @avikchaudhuri