
Comments (3)

laithsakka commented on July 2, 2024

Running with TORCH_LOGS="recompiles" without inlining:

[[email protected] /data/users/lsakka/pytorch/pytorch (fix_op_count)]$ TORCH_LOGS="recompiles" python test/dynamo/test_skip_non_tensor.py -k test_do_not_skip_side_effects
V0529 16:36:59.419000 140372644329280 torch/_dynamo/guards.py:2598] [1/1] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:59.419000 140372644329280 torch/_dynamo/guards.py:2598] [1/1] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:59.419000 140372644329280 torch/_dynamo/guards.py:2598] [1/1] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:59.438000 140372644329280 torch/_dynamo/guards.py:2598] [1/3] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:59.438000 140372644329280 torch/_dynamo/guards.py:2598] [1/3] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:59.438000 140372644329280 torch/_dynamo/guards.py:2598] [1/3] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:59.438000 140372644329280 torch/_dynamo/guards.py:2598] [1/3] [__recompiles]     - ___check_obj_id(L['self'], 140369290455472)                 
V0529 16:36:59.438000 140372644329280 torch/_dynamo/guards.py:2598] [1/3] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:59.456000 140372644329280 torch/_dynamo/guards.py:2598] [1/5] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:59.456000 140372644329280 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:59.456000 140372644329280 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:59.456000 140372644329280 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - ___check_obj_id(L['self'], 140369275927328)                 
V0529 16:36:59.456000 140372644329280 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:59.456000 140372644329280 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - ___check_obj_id(L['self'], 140369290455472)                 
V0529 16:36:59.456000 140372644329280 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:59.476000 140372644329280 torch/_dynamo/guards.py:2598] [1/7] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:59.476000 140372644329280 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:59.476000 140372644329280 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:59.476000 140372644329280 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - ___check_obj_id(L['self'], 140369272172752)                 
V0529 16:36:59.476000 140372644329280 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:59.476000 140372644329280 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - ___check_obj_id(L['self'], 140369275927328)                 
V0529 16:36:59.476000 140372644329280 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:59.476000 140372644329280 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - ___check_obj_id(L['self'], 140369290455472)                 
V0529 16:36:59.476000 140372644329280 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - G['_variable'] == 0                          

With inlining:

[[email protected] /data/users/lsakka/pytorch/pytorch (fix_op_count)]$ TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 TORCH_LOGS="recompiles" python test/dynamo/test_skip_non_tensor.py -k test_do_not_skip_side_effects
V0529 16:36:30.551000 139731447994176 torch/_dynamo/guards.py:2598] [1/1] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:30.551000 139731447994176 torch/_dynamo/guards.py:2598] [1/1] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:30.551000 139731447994176 torch/_dynamo/guards.py:2598] [1/1] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.561000 139731447994176 torch/_dynamo/guards.py:2598] [1/2] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:30.561000 139731447994176 torch/_dynamo/guards.py:2598] [1/2] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:30.561000 139731447994176 torch/_dynamo/guards.py:2598] [1/2] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.561000 139731447994176 torch/_dynamo/guards.py:2598] [1/2] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.572000 139731447994176 torch/_dynamo/guards.py:2598] [1/3] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:30.572000 139731447994176 torch/_dynamo/guards.py:2598] [1/3] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:30.572000 139731447994176 torch/_dynamo/guards.py:2598] [1/3] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.572000 139731447994176 torch/_dynamo/guards.py:2598] [1/3] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.572000 139731447994176 torch/_dynamo/guards.py:2598] [1/3] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.582000 139731447994176 torch/_dynamo/guards.py:2598] [1/4] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:30.582000 139731447994176 torch/_dynamo/guards.py:2598] [1/4] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:30.582000 139731447994176 torch/_dynamo/guards.py:2598] [1/4] [__recompiles]     - L['self'].mode == 2                                         
V0529 16:36:30.582000 139731447994176 torch/_dynamo/guards.py:2598] [1/4] [__recompiles]     - L['self'].mode == 2                                         
V0529 16:36:30.582000 139731447994176 torch/_dynamo/guards.py:2598] [1/4] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.582000 139731447994176 torch/_dynamo/guards.py:2598] [1/4] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.656000 139731447994176 torch/_dynamo/guards.py:2598] [1/5] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:30.656000 139731447994176 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:30.656000 139731447994176 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.656000 139731447994176 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - L['self'].mode == 2                                         
V0529 16:36:30.656000 139731447994176 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.656000 139731447994176 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.656000 139731447994176 torch/_dynamo/guards.py:2598] [1/5] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.677000 139731447994176 torch/_dynamo/guards.py:2598] [1/6] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:30.677000 139731447994176 torch/_dynamo/guards.py:2598] [1/6] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:30.677000 139731447994176 torch/_dynamo/guards.py:2598] [1/6] [__recompiles]     - G['_variable'] == 1                                         
V0529 16:36:30.677000 139731447994176 torch/_dynamo/guards.py:2598] [1/6] [__recompiles]     - L['self'].mode == 3                                           # _dynamo/output_graph.py:448 in init_ambient_guards
V0529 16:36:30.677000 139731447994176 torch/_dynamo/guards.py:2598] [1/6] [__recompiles]     - L['self'].mode == 2                                         
V0529 16:36:30.677000 139731447994176 torch/_dynamo/guards.py:2598] [1/6] [__recompiles]     - L['self'].mode == 2                                         
V0529 16:36:30.677000 139731447994176 torch/_dynamo/guards.py:2598] [1/6] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.677000 139731447994176 torch/_dynamo/guards.py:2598] [1/6] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.701000 139731447994176 torch/_dynamo/guards.py:2598] [1/7] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:30.701000 139731447994176 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:30.701000 139731447994176 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.701000 139731447994176 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - L['self'].mode == 3                                           # _dynamo/output_graph.py:448 in init_ambient_guards
V0529 16:36:30.701000 139731447994176 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.701000 139731447994176 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - L['self'].mode == 2                                         
V0529 16:36:30.701000 139731447994176 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.701000 139731447994176 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.701000 139731447994176 torch/_dynamo/guards.py:2598] [1/7] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles] Recompiling function forward in /data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles]     triggered by the following guard failure(s):
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles]     - L['self'].mode == 4                                           # _dynamo/output_graph.py:448 in init_ambient_guards
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles]     - L['self'].mode == 3                                           # _dynamo/output_graph.py:448 in init_ambient_guards
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles]     - L['self'].mode == 2                                         
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles]     - G['_variable'] == 0                                         
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles]     - L['self'].mode == 1                                         
V0529 16:36:30.730000 139731447994176 torch/_dynamo/guards.py:2598] [1/8] [__recompiles]     - G['_variable'] == 0                                         
W0529 16:36:30.730000 139731447994176 torch/_dynamo/convert_frame.py:752] [1/8] torch._dynamo hit config.cache_size_limit (8)
W0529 16:36:30.730000 139731447994176 torch/_dynamo/convert_frame.py:752] [1/8]    function: 'forward' (/data/users/lsakka/pytorch/pytorch/test/dynamo/test_skip_non_tensor.py:37)
W0529 16:36:30.730000 139731447994176 torch/_dynamo/convert_frame.py:752] [1/8]    last reason: G['_variable'] == 0                                         
W0529 16:36:30.730000 139731447994176 torch/_dynamo/convert_frame.py:752] [1/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0529 16:36:30.730000 139731447994176 torch/_dynamo/convert_frame.py:752] [1/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
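
For reference, the same two runs can be driven from Python instead of the environment variables; a minimal sketch, assuming this build exposes torch._dynamo.config.inline_inbuilt_nn_modules and torch._dynamo.config.cache_size_limit (the config names are my assumption, not taken from the logs above):

import torch
import torch._dynamo

# Assumed in-code equivalent of TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1.
torch._dynamo.config.inline_inbuilt_nn_modules = True

# The inlined run above stops once it hits config.cache_size_limit (8);
# raising the limit lets the remaining recompiles show up in the log.
torch._dynamo.config.cache_size_limit = 16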


laithsakka commented on July 2, 2024

The unit test code is:

        for mode in range(1, 7):
            _variable = 0
            _variable_2 = 0

            mod = MyModule(mode=mode)
            model = torch._dynamo.optimize(backend="eager", nopython=mode != 6)(mod)
            assert _variable == 0
            assert _variable_2 == 0

            model(torch.tensor([1]))
            assert _variable == 1
            assert _variable_2 == 0

            model(torch.tensor([1]))
            assert _variable == 2
            assert _variable_2 == 0

class MyModule(torch.nn.Module):
    def __init__(self, mode: int):
        super().__init__()
        self.mode = mode
        self.register_forward_pre_hook(self.pre_forward, with_kwargs=True)

    def pre_forward(self, module, args, kwargs):
        if self.mode == 5:
            if user_function():
                global _variable
                _variable += 1
        return args, kwargs

    def forward(self, x):
        global _variable, _variable_2

        if self.mode == 1:
            if torch._utils.is_compiling():
                _variable += 1
            else:
                _variable_2 += 1
        elif self.mode == 2:
            if user_function():
                _variable += 1
        elif self.mode == 3:
            lambda_f = lambda: torch._utils.is_compiling()  # noqa: E731
            if lambda_f():
                _variable += 1
        elif self.mode == 4:
            for cond in user_generator():
                if cond:
                    _variable += 1
        elif self.mode == 5:
            x += 1
        elif self.mode == 6:
            if user_function():
                torch._dynamo.graph_break()
                _variable += 1
        return x
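
The snippet references a few module-level names that are not shown above (_variable, _variable_2, user_function, user_generator). A minimal sketch of plausible definitions, only to make the excerpt self-contained; the real ones live in test_skip_non_tensor.py and may differ:

_variable = 0
_variable_2 = 0

def user_function():
    # hypothetical: true only while compiling, mirroring the
    # torch._utils.is_compiling() checks used elsewhere in the test
    return torch._utils.is_compiling()

def user_generator():
    # hypothetical: yields a single compilation-dependent condition
    for _ in range(1):
        yield torch._utils.is_compiling()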


laithsakka commented on July 2, 2024

One thing I am trying to understand, which is not directly related to the issue above, is whether it is legitimate not to recompile in iterations 5 and 6 (it sounds legitimate).
A) Iteration 5: _variable is not added as a guard for mode 5. The increment happens in pre_forward, and it seems that guards from pre_forward are not captured (it looks like the hook is treated as a different frame, and guard checking happens on forward after pre_forward has already been called).
B) Iteration 6: we also do not guard on _variable, probably due to torch._dynamo.graph_break()? I am not sure why yet, but it is not related to this issue, so I will ask about it later.

        elif self.mode == 5:
            x += 1
        elif self.mode == 6:
            if user_function():
                torch._dynamo.graph_break()
                _variable += 1
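
A minimal standalone repro for point (A), using only public torch.compile and hook APIs (the module and counter names here are hypothetical, not from the test): the side effect lives entirely in a forward pre-hook, so if the hook is traced as a separate frame, forward's guards should not mention the global.

import torch

_counter = 0  # hypothetical stand-in for _variable

class HookOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_forward_pre_hook(self.pre_forward, with_kwargs=True)

    def pre_forward(self, module, args, kwargs):
        # the side effect happens in the pre-hook frame, not in forward
        global _counter
        _counter += 1
        return args, kwargs

    def forward(self, x):
        return x + 1

compiled = torch.compile(HookOnly(), backend="eager")
compiled(torch.tensor([1]))
compiled(torch.tensor([1]))
# Running with TORCH_LOGS="guards" (or "recompiles") shows which frame,
# if any, ends up guarding on _counter.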
