
Comments (13)

jbegaint commented on August 22, 2024

So, this line optimizer.load_state_dict(checkpoint["optimizer"]) results in the RuntimeError. It has to do with the use of set in the optimizer definitions. Since sets are not ordered, the indexing of the parameter info in the optimizer state dict differs between initializations. It took me a while to figure out the issue, since internally I had coded it correctly ;-).
I'll push a fix soon.
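A minimal sketch of why this breaks (illustrative shapes, not CompressAI code): an optimizer's state_dict refers to parameters by their position in the parameter list, not by name, so two optimizers built over the same parameters in a different order are not interchangeable.

```python
import torch
from torch import optim

# Two parameters with different shapes (hypothetical, for illustration)
a = torch.nn.Parameter(torch.zeros(2, 3))
b = torch.nn.Parameter(torch.zeros(5))

opt1 = optim.Adam([a, b])
opt2 = optim.Adam([b, a])  # same parameters, different order

sd = opt1.state_dict()
# Parameters are replaced by positional indices, not names:
print(sd["param_groups"][0]["params"])  # -> [0, 1]
# Loading sd into opt2 would silently pair a's state with b and vice
# versa, which is what happens when the order comes from an unordered set.
```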

Thanks again for reporting!

from compressai.

navid-mahmoudian commented on August 22, 2024

Hi Jean,
Sorry to reopen this issue. I fully agree that the problem comes from the use of set in the parameter definition. So why not get rid of it completely? As you mentioned, sets are a bit tricky because they are not ordered. I checked several other PyTorch code examples, and all of them avoid using set for the parameters they pass to the optimizer.

Using sorted to fix the issue might be a bit risky as well: if PyTorch ever changes something internally (which, I agree, is very unlikely for this part), your code might break, since it relies on an assumed convention.
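For context, the sorted-based approach can be sketched as follows (an assumed shape of the fix, not the actual CompressAI code): sorting the names makes the optimizer's positional parameter order deterministic across runs, even if the names were collected in a set.

```python
import torch
from torch import nn, optim

net = nn.Linear(2, 2)  # stand-in model, for illustration only
names = {n for n, _ in net.named_parameters()}  # a set: unordered
named = dict(net.named_parameters())
# Sorting the names pins down a reproducible parameter order:
optimizer = optim.Adam(named[n] for n in sorted(names))
```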

I understand that this is a matter of taste, but if it is fine with you, I think the following is more understandable for others and conforms with other PyTorch code in the community:



from torch import optim


def configure_optimizers(net, args):
    """Separate parameters for the main optimizer and the auxiliary optimizer.

    Return two optimizers.
    """
    # Use a list of tuples instead of a dict so we can later check that the
    # names are unique and that the two groups do not intersect.
    parameters = [(n, p) for n, p in net.named_parameters() if not n.endswith(".quantiles")]
    aux_parameters = [(n, p) for n, p in net.named_parameters() if n.endswith(".quantiles")]

    # Make sure we don't have duplicated parameter names
    parameters_name_set = {n for n, _ in parameters}
    aux_parameters_name_set = {n for n, _ in aux_parameters}
    assert len(parameters) == len(parameters_name_set)
    assert len(aux_parameters) == len(aux_parameters_name_set)

    # Make sure we don't have an intersection of parameters
    inter_params = parameters_name_set & aux_parameters_name_set
    union_params = parameters_name_set | aux_parameters_name_set

    assert len(inter_params) == 0
    assert len(union_params) == len(dict(net.named_parameters()))

    optimizer = optim.Adam(
        (p for (n, p) in parameters if p.requires_grad),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (p for (n, p) in aux_parameters if p.requires_grad),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer


jbegaint commented on August 22, 2024

Hi Navid, I'll keep our current implementation as it's easier to maintain w.r.t. our internal codebase. Also, our intention was just to provide an example training script for people to get started with and take it from there.


jbegaint commented on August 22, 2024

Hi Navid, thanks for the report. I can't reproduce the bug locally. Can you give me the commands you used? Also, can you make sure you have the latest version installed locally? Thanks,


navid-mahmoudian commented on August 22, 2024

Hi Jean.
I just reinstalled everything so that I am sure I am using the latest version, and I still get the same error. Here are the commands I used:

Step 1) Run the following command for 1-2 epochs (just to save the checkpoint file), then cancel the program (so it does not run to the last epoch):

python compressai_train.py -m bmshj2018-hyperprior -d /home/navid/data --save

Step 2) Run the following command:

python compressai_train.py -m bmshj2018-hyperprior -d /home/navid/data --save --checkpoint-file /home/navid/code/checkpoint_best_loss.pth.tar

This is the complete error:

Traceback (most recent call last):
  File "compressai_train.py", line 354, in <module>
    main(sys.argv[1:])
  File "compressai_train.py", line 321, in main
    train_one_epoch(
  File "compressai_train.py", line 129, in train_one_epoch
    optimizer.step()
  File "/home/navid/venv/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/navid/venv/lib/python3.8/site-packages/torch/optim/adam.py", line 108, in step
    F.adam(params_with_grad,
  File "/home/navid/venv/lib/python3.8/site-packages/torch/optim/functional.py", line 86, in adam
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: output with shape [128, 3, 1] doesn't match the broadcast shape [128, 3, 3]
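The shape mismatch in the last frame can be reproduced in isolation. A sketch of the failure mode, assuming the mis-ordered state_dict load handed one parameter's exp_avg buffer to a differently-shaped parameter:

```python
import torch

# After a mis-ordered state_dict load, Adam applies the exp_avg buffer
# saved for one parameter to the gradient of another. In-place ops
# cannot broadcast the output tensor to a larger shape, hence the error.
exp_avg = torch.zeros(128, 3, 1)  # state saved for parameter A
grad = torch.ones(128, 3, 3)      # gradient of a different parameter B
beta1 = 0.9
try:
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
except RuntimeError as e:
    print(e)  # same error as in the traceback above
```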


jbegaint commented on August 22, 2024

Hi Navid, any chance the checkpoint you are using was created with a prior version of compressai?


navid-mahmoudian commented on August 22, 2024

No, I just created them right now.


navid-mahmoudian commented on August 22, 2024

All you need to do is run the same script twice: once without --checkpoint-file and once with --checkpoint-file.


jbegaint commented on August 22, 2024

Yes, I did; no issue on my side. I'll keep investigating.


jbegaint commented on August 22, 2024

Ok great, I managed to reproduce it. Thanks, I'll fix this.


navid-mahmoudian commented on August 22, 2024

Thank you very much, Jean. I'm eagerly looking forward to seeing the bug fixed :).

By the way, as a side note, if you agree, you could replace your CompressAI/examples/train.py with the above code. It is essentially your code; I just added the checkpoint argument and the related logic to be able to continue from a stored checkpoint. If you agree, feel free to change the variable names as you want.


jbegaint commented on August 22, 2024

Yes, I updated the code with your improvements, thanks :-). I'll merge this soon, just waiting for our internal tests to complete.


navid-mahmoudian commented on August 22, 2024

I used the list of tuples for parameters and aux_parameters on purpose because:

  • I want to keep exactly the same order returned by net.named_parameters(), i.e. the parameters themselves plus their names, to be able to check for intersections. A dict is not appropriate here because it behaves somewhat like a set: if there are several parameters (values) with the same name (key), it silently keeps only the last one, which defeats the uniqueness check.

  • As you did, I want to check the intersection based on names (comparing parameter names rather than parameter values).
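The dict pitfall in the first point can be sketched without PyTorch (the parameter names here are hypothetical):

```python
# A duplicated name by mistake: the dict silently keeps only the last
# value, so a uniqueness check against the dict alone can never fire.
pairs = [("conv.weight", "p1"), ("conv.weight", "p2")]
as_dict = dict(pairs)
print(len(pairs), len(as_dict))  # -> 2 1

# With the list form, comparing lengths catches the duplicate:
names = {n for n, _ in pairs}
assert len(names) != len(pairs)  # duplicate detected
```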

