
Comments (10)

bdhirsh commented on June 4, 2024

@warner-benjamin I ran locally with a nightly and this actually passes for me. Can you try out a nightly? https://pytorch.org/get-started/locally/


warner-benjamin commented on June 4, 2024

@bdhirsh I tested my replication script with yesterday's nightly and 2.3. You can see my environment in the "PyTorch Nightly Environment" section. These errors are only with DDP. Single GPU compiles and trains without issue.

I installed today's nightly pytorch-2.4.0.dev20240507 with both CUDA 12.4 and 12.1. Setting dynamic shapes with DDP via torch._dynamo.mark_dynamic using the following command still errors out with the same ConstraintViolationError.

torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic
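
A rough sketch of what the --use_mark_dynamic path presumably does (the helper name is made up; dim 1 is inferred from the logs below, where L['x'].size()[1] is the marked dimension):

import torch
import torch._dynamo

def mark_seq_dynamic(x: torch.Tensor) -> torch.Tensor:
    # Mark the sequence dimension (dim 1) of each batch as dynamic before it
    # reaches the compiled model, so Dynamo traces it with a symbolic size
    # instead of specializing on the first batch's length.
    torch._dynamo.mark_dynamic(x, 1)
    return x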

And setting torch.compile(..., dynamic=True) or torch.compile(..., dynamic=None) using the following commands still results in recompilations every batch until torch._dynamo.config.cache_size_limit (8) is hit.

# torch.compile(..., dynamic=True)
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --dynamic_true

# torch.compile(..., dynamic=None)
torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen
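
Roughly, those two variants correspond to the following (a sketch; the model here is a stand-in, not the replication script's):

import torch
import torch.nn as nn
import torch._dynamo

model = nn.Linear(8, 8)  # stand-in model

# dynamic=None (the default) compiles static first and only marks a size
# dynamic after it observes a recompile for that size; dynamic=True asks for
# dynamic shapes up front.
compiled_model = torch.compile(model, dynamic=True)  # or dynamic=None

# Each sequence length that fails to stay dynamic costs a recompile; once the
# per-frame limit below (default 8) is exhausted, Dynamo stops recompiling and
# runs the frame eagerly.
print(torch._dynamo.config.cache_size_limit)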


algal commented on June 4, 2024

I am seeing the same issue this morning, running the same three commands on the replication script, on my system using CUDA 12.1.

Details below:

PyTorch Nightly Environment details
Collecting environment information...
PyTorch version: 2.4.0.dev20240507
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 12.2.0-14) 12.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.1.0-20-amd64-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090

Nvidia driver version: 525.147.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               16
On-line CPU(s) list:                  0-15
Vendor ID:                            AuthenticAMD
Model name:                           AMD Ryzen 7 5700G with Radeon Graphics
CPU family:                           25
Model:                                80
Thread(s) per core:                   2
Core(s) per socket:                   8
Socket(s):                            1
Stepping:                             0
Frequency boost:                      enabled
CPU(s) scaling MHz:                   64%
CPU max MHz:                          4672.0698
CPU min MHz:                          1400.0000
BogoMIPS:                             7585.74
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                       AMD-V
L1d cache:                            256 KiB (8 instances)
L1i cache:                            256 KiB (8 instances)
L2 cache:                             4 MiB (8 instances)
L3 cache:                             16 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-15
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; safe RET, no microcode
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.4.0.dev20240507
[pip3] torchvision==0.19.0.dev20240507
[pip3] triton==3.0.0
[conda] blas                      1.0                         mkl    conda-forge
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] filelock                  3.9.0                   py311_0    pytorch-nightly
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch-nightly
[conda] libopenvino-pytorch-frontend 2024.0.0             he02047a_5    conda-forge
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     1.26.4          py311h64a7726_0    conda-forge
[conda] pillow                    9.3.0           py311h3fd9d12_2    pytorch-nightly
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.4.0.dev20240507 py3.11_cuda12.1_cudnn8.9.2_0    pytorch-nightly
[conda] pytorch-cuda              12.1                 ha16c6d3_6    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] requests                  2.28.1                  py311_0    pytorch-nightly
[conda] torchtriton               3.0.0+45fff310c8           py311    pytorch-nightly
[conda] torchvision               0.19.0.dev20240507     py311_cu121    pytorch-nightly
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly


ezyang commented on June 4, 2024

I'm going to look into this. But my recollection is that HF added some error checking code which forces specialization, and I haven't gotten around to yelling at them to stop running this logic when being torch compiled.

BTW, the two errors here are one and the same. mark_dynamic is yelling at you because it tried to make it dynamic, but failed due to specialization. You can use TORCH_LOGS=dynamic to find out where the specialization happened.
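
A minimal, hypothetical illustration of that interaction (not the replication script): any guard that pins the marked dimension to a single value makes the ShapeEnv replace the symbol with a constant, and produce_guards then rejects it because of the RelaxedUnspecConstraint that mark_dynamic installs.

import torch
import torch._dynamo

def f(x):
    # Branching on the exact size adds a guard Eq(s0, 16), which specializes
    # s0 and conflicts with mark_dynamic on dim 0.
    if x.size(0) == 16:
        return x + 1
    return x - 1

x = torch.randn(16, 4)
torch._dynamo.mark_dynamic(x, 0)
torch.compile(f)(x)  # expected: ConstraintViolationError ("inferred to be a constant (16)")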


warner-benjamin commented on June 4, 2024

> I'm going to look into this. But my recollection is that HF added some error checking code which forces specialization, and I haven't gotten around to yelling at them to stop running this logic when being torch compiled.

It's not just HF models that trigger this when using DDP. My replication script uses a simple two-layer model with an Embedding and a Linear layer. A single layer doesn't reproduce the issue; it seems to have something to do with adding a second layer.

import torch.nn as nn
from torch import Tensor

class EmbedHeadModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.vocab_embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x: Tensor):
        out = self.vocab_embed(x)  # (batch, seq) -> (batch, seq, hidden)
        out = self.head(out)       # (batch, seq, hidden) -> (batch, seq, vocab)
        return out
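
For context, the surrounding setup in the script is roughly the following (a sketch launched via the torchrun commands above; the vocab/hidden sizes, the wrapping order, and the fake batches are assumptions, not the script verbatim):

import os
import torch
import torch._dynamo
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = EmbedHeadModel(vocab_size=32000, hidden_size=2048).cuda()  # placeholder sizes
model = torch.compile(model)
model = DDP(model, device_ids=[local_rank])

for seqlen in (977, 1024, 640):  # variable sequence length per batch
    x = torch.randint(0, 32000, (16, seqlen), device="cuda")
    torch._dynamo.mark_dynamic(x, 1)  # the --use_mark_dynamic path
    out = model(x)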

> BTW, the two errors here are one and the same. mark_dynamic is yelling at you because it tried to make it dynamic, but failed due to specialization. You can use TORCH_LOGS=dynamic to find out where the specialization happened.

When I run my replication script with TORCH_LOGS=+dynamic

TORCH_LOGS=+dynamic torchrun --nproc_per_node=2 replication.py --ddp --compile --variable_seqlen --use_mark_dynamic

I get the following output for rank 0:

TORCH_LOGS=+dynamic Rank 0 Output
torch/fx/experimental/symbolic_shapes.py:2268] [0/0] create_env
torch/fx/experimental/symbolic_shapes.py:3239] [0/0] create_symbol s0 = 977 for L['x'].size()[1] [2, 9223372036854775806] at test/replication.py:52 in forward (_dynamo/variables/builder.py:2137 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval False == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval False == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Eq(16*s0, 16) == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(Mod(16, 16*s0), 0) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval 2048*s0 > 2048 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(s0, 1) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(16*s0, 16) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Eq(s0, 1) == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval s0 > 1 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4634] [0/0] eval 32768*s0 < 2147483648 [guard added] (_inductor/codegen/triton.py:3409 in can_use_32bit_indexing), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="32768*s0 < 2147483648"
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval 16*s0 < 2147483648 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4394] [0/0] set_replacement s0 = 977 (solve) ValueRanges(lower=977, upper=977, is_bool=False)
torch/fx/experimental/symbolic_shapes.py:4824] [0/0] eval Eq(2048*s0, 2000896) [guard suppressed]
torch/fx/experimental/symbolic_shapes.py:3326] [0/0] produce_guards
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].size()[0] 16 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].size()[1] 977 RelaxedUnspecConstraint(warn_only=False)
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].stride()[0] 977 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].stride()[1] 1 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].storage_offset() 0 None
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].size()[0] == 16
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].size()[1] == 977
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].stride()[0] == 977
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].stride()[1] == 1
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].storage_offset() == 0
torch/_guards.py:261] [0/0] Error while creating guard:
torch/_guards.py:261] [0/0] Name: ''
torch/_guards.py:261] [0/0]     Source: shape_env
torch/_guards.py:261] [0/0]     Create Function: SHAPE_ENV
torch/_guards.py:261] [0/0]     Guard Types: None
torch/_guards.py:261] [0/0]     Code List: None
torch/_guards.py:261] [0/0]     Object Weakref: None
torch/_guards.py:261] [0/0]     Guarded Class Weakref: None
torch/_guards.py:261] [0/0] Traceback (most recent call last):
torch/_guards.py:261] [0/0]   File "torch/_guards.py", line 259, in create
torch/_guards.py:261] [0/0]     return self.create_fn(builder, self)
torch/_guards.py:261] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0]   File "torch/_dynamo/guards.py", line 1683, in SHAPE_ENV
torch/_guards.py:261] [0/0]     guards = output_graph.shape_env.produce_guards(
torch/_guards.py:261] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0]   File "torch/fx/experimental/symbolic_shapes.py", line 3830, in produce_guards
torch/_guards.py:261] [0/0]     raise ConstraintViolationError(
torch/_guards.py:261] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
torch/_guards.py:261] [0/0]   - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (977).
torch/_guards.py:263] [0/0] Created at:
torch/_guards.py:263] [0/0]   File "torch/_dynamo/convert_frame.py", line 499, in transform
torch/_guards.py:263] [0/0]     tracer = InstructionTranslator(
torch/_guards.py:263] [0/0]   File "torch/_dynamo/symbolic_convert.py", line 2143, in __init__
torch/_guards.py:263] [0/0]     output=OutputGraph(
torch/_guards.py:263] [0/0]   File "torch/_dynamo/output_graph.py", line 308, in __init__
torch/_guards.py:263] [0/0]     self.init_ambient_guards()
torch/_guards.py:263] [0/0]   File "torch/_dynamo/output_graph.py", line 447, in init_ambient_guards
torch/_guards.py:263] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats evaluate_expr: CacheInfo(hits=342, misses=15, maxsize=256, currsize=15)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=2, maxsize=256, currsize=2)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _find: CacheInfo(hits=31, misses=3, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats has_hint: CacheInfo(hits=1, misses=2, maxsize=256, currsize=2)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats size_hint: CacheInfo(hits=1, misses=3, maxsize=256, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats simplify: CacheInfo(hits=6, misses=17, maxsize=None, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats replace: CacheInfo(hits=3257, misses=58, maxsize=None, currsize=18)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=1, misses=21, maxsize=None, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats get_implications: CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats get_axioms: CacheInfo(hits=17, misses=3, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats safe_expand: CacheInfo(hits=606, misses=53, maxsize=256, currsize=53)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats uninteresting_files: CacheInfo(hits=18, misses=1, maxsize=None, currsize=1)

I'm not seeing anything about specialization, but I might be misinterpreting the logs.


ezyang commented on June 4, 2024

It's this:

torch/fx/experimental/symbolic_shapes.py:4394] [0/0] set_replacement s0 = 977 (solve) ValueRanges(lower=977, upper=977, is_bool=False)
torch/fx/experimental/symbolic_shapes.py:4824] [0/0] eval Eq(2048*s0, 2000896) [guard suppressed]

Very strange though, why is this suppressed 🤔. You could get a full backtrace for this log with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(2048*s0, 2000896)"


eellison commented on June 4, 2024

@ezyang could it be related to this? https://github.com/pytorch/pytorch/pull/120523/files#diff-cb8e02fc8f37e53904ab1b151c46dd109cf50d8121bbd340834b2e976b22ebc4R74

Maybe the idiom there is not correct. We're trying to update the meta strides without adding guards or specializations.


ezyang commented on June 4, 2024

Oh yeah, this looks very very naughty. Hmmmm


ezyang commented on June 4, 2024

As a stopgap, I guess we could prevent replacements from happening when guards are suppressed. This still seems very naughty though.....


warner-benjamin commented on June 4, 2024

Here's the additional backtrace with the "Eq(2048*s0, 2000896)" guard added:

TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(2048*s0, 2000896)"
torch/fx/experimental/symbolic_shapes.py:2268] [0/0] create_env
torch/fx/experimental/symbolic_shapes.py:3239] [0/0] create_symbol s0 = 977 for L['x'].size()[1] [2, 9223372036854775806] at test/replication.py:52 in forward (_dynamo/variables/builder.py:2137 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval False == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval False == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Eq(16*s0, 16) == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(Mod(16, 16*s0), 0) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval 2048*s0 > 2048 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(s0, 1) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Ne(16*s0, 16) == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval Eq(s0, 1) == False [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval s0 > 1 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4634] [0/0] eval 32768*s0 < 2147483648 [guard added] (_inductor/codegen/triton.py:3409 in can_use_32bit_indexing), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="32768*s0 < 2147483648"
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval 16*s0 < 2147483648 == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4719] [0/0] eval True == True [statically known]
torch/fx/experimental/symbolic_shapes.py:4394] [0/0] set_replacement s0 = 977 (solve) ValueRanges(lower=977, upper=977, is_bool=False)
torch/fx/experimental/symbolic_shapes.py:4824] [0/0] eval Eq(2048*s0, 2000896) [guard suppressed]
torch/fx/experimental/symbolic_shapes.py:3326] [0/0] produce_guards
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].size()[0] 16 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].size()[1] 977 RelaxedUnspecConstraint(warn_only=False)
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].stride()[0] 977 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].stride()[1] 1 None
torch/fx/experimental/symbolic_shapes.py:3508] [0/0] track_symint L['x'].storage_offset() 0 None
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].size()[0] == 16
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].size()[1] == 977
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].stride()[0] == 977
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].stride()[1] == 1
torch/fx/experimental/symbolic_shapes.py:3652] [0/0] Skipping guard L['x'].storage_offset() == 0
torch/_guards.py:261] [0/0] Error while creating guard:
torch/_guards.py:261] [0/0] Name: ''
torch/_guards.py:261] [0/0]     Source: shape_env
torch/_guards.py:261] [0/0]     Create Function: SHAPE_ENV
torch/_guards.py:261] [0/0]     Guard Types: None
torch/_guards.py:261] [0/0]     Code List: None
torch/_guards.py:261] [0/0]     Object Weakref: None
torch/_guards.py:261] [0/0]     Guarded Class Weakref: None
torch/_guards.py:261] [0/0] Traceback (most recent call last):
torch/_guards.py:261] [0/0]   File "torch/_guards.py", line 259, in create
torch/_guards.py:261] [0/0]     return self.create_fn(builder, self)
torch/_guards.py:261] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0]   File "torch/_dynamo/guards.py", line 1683, in SHAPE_ENV
torch/_guards.py:261] [0/0]     guards = output_graph.shape_env.produce_guards(
torch/_guards.py:261] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch/_guards.py:261] [0/0]   File "torch/fx/experimental/symbolic_shapes.py", line 3830, in produce_guards
torch/_guards.py:261] [0/0]     raise ConstraintViolationError(
torch/_guards.py:261] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
torch/_guards.py:261] [0/0]   - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (977).
torch/_guards.py:263] [0/0] Created at:
torch/_guards.py:263] [0/0]   File "torch/_dynamo/convert_frame.py", line 499, in transform
torch/_guards.py:263] [0/0]     tracer = InstructionTranslator(
torch/_guards.py:263] [0/0]   File "torch/_dynamo/symbolic_convert.py", line 2143, in __init__
torch/_guards.py:263] [0/0]     output=OutputGraph(
torch/_guards.py:263] [0/0]   File "torch/_dynamo/output_graph.py", line 308, in __init__
torch/_guards.py:263] [0/0]     self.init_ambient_guards()
torch/_guards.py:263] [0/0]   File "torch/_dynamo/output_graph.py", line 447, in init_ambient_guards
torch/_guards.py:263] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "replication.py", line 139, in <module>
    train()
  File "replication.py", line 127, in train
    output = model(data)
             ^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/parallel/distributed.py", line 1620, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/parallel/distributed.py", line 1438, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 977, in catch_errors
    return hijacked_callback(frame, cache_entry, hooks, frame_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 822, in _convert_frame
    result = inner_convert(
             ^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 410, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/benja/.conda/envs/torchnight/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 703, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 660, in compile_inner
    check_fn = CheckFunctionManager(
               ^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/guards.py", line 2086, in __init__
    guard.create(builder)
  File "torch/_guards.py", line 259, in create
    return self.create_fn(builder, self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/guards.py", line 1683, in SHAPE_ENV
    guards = output_graph.shape_env.produce_guards(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/fx/experimental/symbolic_shapes.py", line 3830, in produce_guards
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['x'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['x'].size()[1]) are valid because L['x'].size()[1] was inferred to be a constant (977).


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats evaluate_expr: CacheInfo(hits=342, misses=15, maxsize=256, currsize=15)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=2, maxsize=256, currsize=2)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _find: CacheInfo(hits=31, misses=3, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats has_hint: CacheInfo(hits=1, misses=2, maxsize=256, currsize=2)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats size_hint: CacheInfo(hits=1, misses=3, maxsize=256, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats simplify: CacheInfo(hits=6, misses=17, maxsize=None, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats replace: CacheInfo(hits=3257, misses=58, maxsize=None, currsize=18)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=1, misses=21, maxsize=None, currsize=3)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats get_implications: CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats get_axioms: CacheInfo(hits=17, misses=3, maxsize=None, currsize=1)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats safe_expand: CacheInfo(hits=606, misses=53, maxsize=256, currsize=53)
torch/fx/experimental/symbolic_shapes.py:110] lru_cache_stats uninteresting_files: CacheInfo(hits=18, misses=1, maxsize=None, currsize=1)

