
Comments (2)

BlackSamorez commented on September 25, 2024

Full error:


RuntimeError                              Traceback (most recent call last)

[<ipython-input-5-388f70847b39>](https://localhost:8080/#) in <cell line: 24>()
     22 )
     23 model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
---> 24 trainer.train()

37 frames

[/usr/local/lib/python3.10/dist-packages/transformers/](https://localhost:8080/#) in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1857                 hf_hub_utils.enable_progress_bars()
   1858         else:
-> 1859             return inner_training_loop(
   1860                 args=args,
   1861                 resume_from_checkpoint=resume_from_checkpoint,

[/usr/local/lib/python3.10/dist-packages/transformers/](https://localhost:8080/#) in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2202                 with self.accelerator.accumulate(model):
-> 2203                     tr_loss_step = self.training_step(model, inputs)
   2205                 if (

[/usr/local/lib/python3.10/dist-packages/transformers/](https://localhost:8080/#) in training_step(self, model, inputs)
   3137         with self.compute_loss_context_manager():
-> 3138             loss = self.compute_loss(model, inputs)
   3140         if self.args.n_gpu > 1:

[/usr/local/lib/python3.10/dist-packages/transformers/](https://localhost:8080/#) in compute_loss(self, model, inputs, return_outputs)
   3159         else:
   3160             labels = None
-> 3161         outputs = model(**inputs)
   3162         # Save past state if it exists
   3163         # TODO: this needs to be fixed and made cleaner later.

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1522         try:

[/usr/local/lib/python3.10/dist-packages/accelerate/utils/](https://localhost:8080/#) in forward(*args, **kwargs)
    824     def forward(*args, **kwargs):
--> 825         return model_forward(*args, **kwargs)
    827     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`

[/usr/local/lib/python3.10/dist-packages/accelerate/utils/](https://localhost:8080/#) in __call__(self, *args, **kwargs)
    812     def __call__(self, *args, **kwargs):
--> 813         return convert_to_fp32(self.model_forward(*args, **kwargs))
    815     def __getstate__(self):

[/usr/local/lib/python3.10/dist-packages/torch/amp/](https://localhost:8080/#) in decorate_autocast(*args, **kwargs)
     14     def decorate_autocast(*args, **kwargs):
     15         with autocast_instance:
---> 16             return func(*args, **kwargs)
     18     decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode"  # type: ignore[attr-defined]

[/usr/local/lib/python3.10/dist-packages/peft/](https://localhost:8080/#) in forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1302             with self._enable_peft_forward_hooks(**kwargs):
   1303                 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1304                 return self.base_model(
   1305                     input_ids=input_ids,
   1306                     attention_mask=attention_mask,

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1522         try:

[/usr/local/lib/python3.10/dist-packages/peft/tuners/](https://localhost:8080/#) in forward(self, *args, **kwargs)
    178     def forward(self, *args: Any, **kwargs: Any):
--> 179         return self.model.forward(*args, **kwargs)
    181     def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None:

[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/](https://localhost:8080/#) in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)
   1358         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1359         outputs = self.model(
   1360             input_ids=input_ids,
   1361             attention_mask=attention_mask,

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1522         try:

[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/](https://localhost:8080/#) in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)
   1215             if self.gradient_checkpointing and
-> 1216                 layer_outputs = self._gradient_checkpointing_func(
   1217                     decoder_layer.__call__,
   1218                     hidden_states,

[/usr/local/lib/python3.10/dist-packages/torch/](https://localhost:8080/#) in inner(*args, **kwargs)
     22             import torch._dynamo
---> 24             return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
     26         return inner

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/](https://localhost:8080/#) in _fn(*args, **kwargs)
    487                 dynamo_config_ctx.__enter__()
    488             try:
--> 489                 return fn(*args, **kwargs)
    490             finally:
    491                 set_eval_frame(prior)

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/](https://localhost:8080/#) in inner(*args, **kwargs)
     15     @functools.wraps(fn)
     16     def inner(*args, **kwargs):
---> 17         return fn(*args, **kwargs)
     19     return inner

[/usr/local/lib/python3.10/dist-packages/torch/utils/](https://localhost:8080/#) in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
    480                 "use_reentrant=False."
    481             )
--> 482         return CheckpointFunction.apply(function, preserve, *args)
    483     else:
    484         gen = _checkpoint_without_reentrant_generator(

[/usr/local/lib/python3.10/dist-packages/torch/autograd/](https://localhost:8080/#) in apply(cls, *args, **kwargs)
    551             # See NOTE: [functorch vjp and autograd interaction]
    552             args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553             return super().apply(*args, **kwargs)  # type: ignore[misc]
    555         if not is_setup_ctx_defined:

[/usr/local/lib/python3.10/dist-packages/torch/utils/](https://localhost:8080/#) in forward(ctx, run_function, preserve_rng_state, *args)
    260         with torch.no_grad():
--> 261             outputs = run_function(*args)
    262         return outputs

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1522         try:

[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/](https://localhost:8080/#) in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, output_router_logits, use_cache, **kwargs)
    943         residual = hidden_states
    944         hidden_states = self.post_attention_layernorm(hidden_states)
--> 945         hidden_states, router_logits = self.block_sparse_moe(hidden_states)
    946         hidden_states = residual + hidden_states

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1522         try:

[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/](https://localhost:8080/#) in forward(self, hidden_states)
    873             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
    874             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
--> 875             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
    877             # However `index_add_` only support torch tensors for indexing so we'll use

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1522         try:

[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/](https://localhost:8080/#) in forward(self, hidden_states)
    802     def forward(self, hidden_states):
--> 803         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
    804         current_hidden_states = self.w2(current_hidden_states)
    805         return current_hidden_states

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1513     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1522         try:

[/usr/local/lib/python3.10/dist-packages/aqlm/](https://localhost:8080/#) in forward(self, input)
     72         if self.use_gemv_rule(input):
---> 73             return self.gemv_op.apply(input,, self.codebooks, self.scales, self.bias)
     74         else:
     75             return self.gemm_op.apply(input,, self.codebooks, self.scales, self.bias)

[/usr/local/lib/python3.10/dist-packages/torch/autograd/](https://localhost:8080/#) in apply(cls, *args, **kwargs)
    551             # See NOTE: [functorch vjp and autograd interaction]
    552             args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553             return super().apply(*args, **kwargs)  # type: ignore[misc]
    555         if not is_setup_ctx_defined:

[/usr/local/lib/python3.10/dist-packages/aqlm/](https://localhost:8080/#) in forward(ctx, input, codes, codebooks, scales, bias)
    114                 bias,
    115             )
--> 116             return forward_pass_kernel(
    117                 input,
    118                 codes,

[/usr/local/lib/python3.10/dist-packages/torch/](https://localhost:8080/#) in __call__(self, *args, **kwargs)
    753         # We save the function ptr as the `op` attribute on
    754         # OpOverloadPacket to access it here.
--> 755         return self._op(*args, **(kwargs or {}))
    757     # TODO: use this to make a __dir__

RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous

I'm almost certain the error is raised here:

For some reason, checkpointing passes a zero-element tensor through the model, and this particular reshape doesn't handle that case. The fix should be straightforward; I'll try to get to it soon.
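To illustrate why the reshape in the error message fails, here is a minimal pure-Python sketch of the rule for resolving a `-1` dimension. It is not PyTorch's or aqlm's code: `infer_reshape` and `flatten_leading` are hypothetical helpers written for this example, and the guard shown is only one possible workaround, not necessarily what the actual fix does.

```python
def infer_reshape(numel, shape):
    """Resolve at most one -1 in `shape` for a tensor holding `numel` elements."""
    shape = list(shape)
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")

    # Product of the dimensions that were given explicitly.
    known = 1
    for dim in shape:
        if dim != -1:
            known *= dim

    if -1 not in shape:
        if known != numel:
            raise ValueError(f"shape {shape} is invalid for {numel} elements")
        return tuple(shape)

    if known == 0:
        # 0 * anything == 0, so every value would satisfy the -1: ambiguous.
        raise ValueError(
            f"cannot reshape {numel} elements into {shape}: -1 is ambiguous"
        )
    if numel % known != 0:
        raise ValueError(f"shape {shape} is invalid for {numel} elements")

    inferred = numel // known
    return tuple(inferred if dim == -1 else dim for dim in shape)


def flatten_leading(shape):
    """Collapse all leading dims into one while keeping the last dim explicit.

    Hypothetical guard: passing the last dimension explicitly avoids -1,
    so an empty batch such as (0, 8) stays well defined.
    """
    *leading, last = shape
    rows = 1
    for dim in leading:
        rows *= dim
    return (rows, last)
```

With a non-empty input, `infer_reshape(12, [3, -1])` resolves the missing dimension, while `infer_reshape(0, [0, -1])` raises for the same reason as the `RuntimeError` above. Spelling out the trailing dimension, as `flatten_leading((0, 8))` does, sidesteps the ambiguity.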

from aqlm.

BlackSamorez commented on September 25, 2024

Should be fixed in aqlm==1.1.5
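To confirm that an environment already has the release mentioned above, a small check along these lines may help. It is a sketch using only the standard library; `parse_release`, `has_fix`, and `installed_aqlm_version` are hypothetical helper names, and the version parsing is deliberately naive (it ignores pre-release and post-release suffixes).

```python
from importlib.metadata import PackageNotFoundError, version

FIXED_IN = (1, 1, 5)  # release named in the comment above


def parse_release(text):
    """Turn '1.1.5' into (1, 1, 5), dropping any non-numeric suffix."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def has_fix(installed):
    """True if the given version string is at least the fixed release."""
    return parse_release(installed) >= FIXED_IN


def installed_aqlm_version():
    """Return the installed aqlm version string, or None if it is absent."""
    try:
        return version("aqlm")
    except PackageNotFoundError:
        return None
```

Calling `has_fix(installed_aqlm_version())` after checking for `None` tells you whether an upgrade is needed; if it is, `pip install "aqlm>=1.1.5"` should pull in the fixed release.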

