
Comments (6)

yuval-alaluf commented on August 22, 2024

Where do you get this error? I would assume that it has something to do with your choice of layers_to_tune. You may have provided layers that do not exist in a generator with an output size of 256.
However, if you provide a full stack trace I could try to better guide you regarding where this error could be coming from.
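A rough way to check which layer indices could exist for a given output size is sketched below. The indexing scheme is an assumption inferred from the layers_to_tune list in this thread, which skips every third index starting at 1: one initial conv, one initial toRGB, then two convs and one toRGB per resolution from 8x8 up to the output size. The helper name layer_indices is hypothetical and not part of the repository.

```python
from math import log2

def layer_indices(output_size):
    """Return (total layer count, non-toRGB indices) under the ASSUMED scheme:
    conv, toRGB, then (conv, conv, toRGB) for each resolution above 4x4."""
    n_blocks = int(log2(output_size)) - 2      # 8x8 ... output_size
    total = 2 + 3 * n_blocks
    to_rgb = set(range(1, total, 3))           # assumed toRGB positions
    return total, [i for i in range(total) if i not in to_rgb]

requested = [int(i) for i in
             "0,2,3,5,6,8,9,11,12,14,15,17,18,20,21,23,24".split(",")]

total_1024, tunable_1024 = layer_indices(1024)
total_256, tunable_256 = layer_indices(256)

# Indices from the requested list that would not exist at 256x256
out_of_range = [i for i in requested if i >= total_256]
print(total_1024, total_256, out_of_range)
```

Under this assumption the list matches the tunable layers of a 1024 generator, while a 256 generator would only have indices 0 to 19, so 20, 21, 23 and 24 would refer to layers that do not exist.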

from hyperstyle.

tsshubhamv commented on August 22, 2024

Error 1:

dataset_type="sg2-ada-smile_256"
encoder_type="SharedWeightsHyperNetResNetSeparable"
exp_dir="experiments/smile_256"
workers=8
batch_size=8
test_batch_size=4
test_workers=4
save_interval=2000
lpips_lambda=0.8
l2_lambda=0.8
id_lambda=0
n_iters_per_batch=5
max_val_batches=150
output_size=256
layers_to_tune="0,2,3,5,6,8,9,11,12,14,15,17,18,20,21,23,24"

cmd = f'''python scripts/train.py \
  --dataset_type={dataset_type} \
  --encoder_type={encoder_type} \
  --exp_dir={exp_dir} \
  --workers={workers} \
  --batch_size={batch_size} \
  --test_batch_size={test_batch_size} \
  --test_workers={test_workers} \
  --save_interval={save_interval} \
  --lpips_lambda={lpips_lambda} \
  --l2_lambda={l2_lambda} \
  --id_lambda={id_lambda} \
  --n_iters_per_batch={n_iters_per_batch} \
  --max_val_batches={max_val_batches} \
  --stylegan_weights={"artifacts/sg2-ada-smile_256.pt"} \
  --output_size={output_size} \
  # --load_w_encoder \
  # --w_encoder_checkpoint_path={"artifacts/e2e.pth"} \
  --layers_to_tune={layers_to_tune}'''

!{cmd}
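One thing worth checking in the command above: inside a non-raw Python string, a backslash followed by a newline joins the lines, so cmd is a single shell line, and an unquoted # in a shell line comments out everything after it. That would include --layers_to_tune, not just the two flags that were meant to be disabled. The sketch below reproduces the effect with a shortened, illustrative command and uses shlex only to mimic the shell's comment handling.

```python
import shlex

layers_to_tune = "0,2,3"  # shortened value, for illustration only

# Same structure as the command above: the backslash-newlines are
# removed by Python, so this string contains no line breaks.
cmd = f'''python scripts/train.py \
  --output_size=256 \
  # --load_w_encoder \
  --layers_to_tune={layers_to_tune}'''

# comments=True drops everything after an unquoted '#',
# mirroring what the shell does with this line.
args = shlex.split(cmd, comments=True)
print(args)  # ['python', 'scripts/train.py', '--output_size=256']

# One alternative: keep the flags in a list and comment out entries
# at the Python level, so later flags are unaffected.
flags = [
    "--output_size=256",
    # "--load_w_encoder",
    f"--layers_to_tune={layers_to_tune}",
]
fixed_cmd = "python scripts/train.py " + " ".join(flags)
fixed_args = shlex.split(fixed_cmd, comments=True)
print(fixed_args)
```

If --layers_to_tune still shows up in the printed options after such a truncation, the value would be coming from the script's default rather than from the command line; that is an assumption about the defaults, not something verified here.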

Running the training command with these options fails with a size mismatch:

{'batch_size': 8,
 'board_interval': 50,
 'checkpoint_path': None,
 'dataset_type': 'sg2-ada-smile_256',
 'encoder_type': 'SharedWeightsHyperNetResNetSeparable',
 'exp_dir': 'experiments/smile_256',
 'id_lambda': 0.0,
 'image_interval': 100,
 'input_nc': 6,
 'l2_lambda': 0.8,
 'layers_to_tune': '0,2,3,5,6,8,9,11,12,14,15,17,18,20,21,23,24',
 'learning_rate': 0.0001,
 'load_w_encoder': False,
 'lpips_lambda': 0.8,
 'max_steps': 500000,
 'max_val_batches': 150,
 'moco_lambda': 0,
 'n_iters_per_batch': 5,
 'optim_name': 'ranger',
 'output_size': 256,
 'save_interval': 2000,
 'stylegan_weights': 'artifacts/sg2-ada-smile_256.pt',
 'test_batch_size': 4,
 'test_workers': 4,
 'train_decoder': False,
 'val_interval': 1000,
 'w_encoder_checkpoint_path': None,
 'w_encoder_type': 'WEncoder',
 'workers': 8}
/usr/local/lib/python3.9/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.9/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100% 83.3M/83.3M [00:00<00:00, 263MB/s]
Loading hypernet weights from resnet34!
Loading decoder weights from pretrained path: artifacts/sg2-ada-smile_256.pt
/usr/local/lib/python3.9/dist-packages/torchvision/models/_utils.py:135: UserWarning: Using 'weights' as positional parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) instead.
  warnings.warn(
/usr/local/lib/python3.9/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100% 233M/233M [00:02<00:00, 95.1MB/s]
Downloading: "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/master/lpips/weights/v0.1/alex.pth" to /root/.cache/torch/hub/checkpoints/alex.pth
100% 5.87k/5.87k [00:00<00:00, 15.2MB/s]
Loading dataset for sg2-ada-smile_256
Number of training samples: 2014
Number of test samples: 502
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
Traceback (most recent call last):
  File "/content/hyperstyle/scripts/train.py", line 32, in <module>
    main()
  File "/content/hyperstyle/scripts/train.py", line 20, in main
    coach.train()
  File "/content/hyperstyle/./training/coach_hyperstyle.py", line 135, in train
    x, y, y_hat, loss_dict, id_logs, w_inversion = self.perform_forward_on_batch(batch, train=True)
  File "/content/hyperstyle/./training/coach_hyperstyle.py", line 103, in perform_forward_on_batch
    y_hat, latent, weights_deltas, codes, w_inversion = self.net.forward(x,
  File "/content/hyperstyle/./models/hyperstyle.py", line 86, in forward
    images, result_latent = self.decoder([codes],
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/hyperstyle/./models/stylegan2/model.py", line 492, in forward
    styles = [self.style(s) for s in styles]
  File "/content/hyperstyle/./models/stylegan2/model.py", line 492, in <listcomp>
    styles = [self.style(s) for s in styles]
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/hyperstyle/./models/stylegan2/model.py", line 150, in forward
    out = F.linear(input, weight * self.scale)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x256 and 512x512)


tsshubhamv commented on August 22, 2024

When using different or default options:

a = f'''python scripts/train.py \
  --dataset_type={dataset_type} \
  --encoder_type={encoder_type} \
  --exp_dir={exp_dir}\
  --stylegan_weights={"artifacts/sg2-ada-smile_256.pt"}
  '''
!{a}

I get the following error:

Loading hypernet weights from resnet34!
Loading decoder weights from pretrained path: artifacts/sg2-ada-smile_256.pt
Traceback (most recent call last):
  File "/content/hyperstyle/scripts/train.py", line 32, in <module>
    main()
  File "/content/hyperstyle/scripts/train.py", line 19, in main
    coach = Coach(opts)
  File "/content/hyperstyle/./training/coach_hyperstyle.py", line 35, in __init__
    self.net = HyperStyle(self.opts).to(self.device)
  File "/content/hyperstyle/./models/hyperstyle.py", line 26, in __init__
    self.load_weights()
  File "/content/hyperstyle/./models/hyperstyle.py", line 59, in load_weights
    self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "convs.12.conv.weight", "convs.12.conv.blur.kernel", "convs.12.conv.modulation.weight", "convs.12.conv.modulation.bias", "convs.12.noise.weight", "convs.12.activate.bias", "convs.13.conv.weight", "convs.13.conv.modulation.weight", "convs.13.conv.modulation.bias", "convs.13.noise.weight", "convs.13.activate.bias", "convs.14.conv.weight", "convs.14.conv.blur.kernel", "convs.14.conv.modulation.weight", "convs.14.conv.modulation.bias", "convs.14.noise.weight", "convs.14.activate.bias", "convs.15.conv.weight", "convs.15.conv.modulation.weight", "convs.15.conv.modulation.bias", "convs.15.noise.weight", "convs.15.activate.bias", "to_rgbs.6.bias", "to_rgbs.6.upsample.kernel", "to_rgbs.6.conv.weight", "to_rgbs.6.conv.modulation.weight", "to_rgbs.6.conv.modulation.bias", "to_rgbs.7.bias", "to_rgbs.7.upsample.kernel", "to_rgbs.7.conv.weight", "to_rgbs.7.conv.modulation.weight", "to_rgbs.7.conv.modulation.bias", "noises.noise_13", "noises.noise_14", "noises.noise_15", "noises.noise_16".


tsshubhamv commented on August 22, 2024

@yuval-alaluf I have tried both with explicit options and with the default options. It would be great if you could explain what I'm doing wrong.


yuval-alaluf commented on August 22, 2024

For the first error, RuntimeError: mat1 and mat2 shapes cannot be multiplied (6144x256 and 512x512), it seems like your images are not resized to the correct size. Try to make sure that all your input images are square-shaped.
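As a side note on reading the shapes in that message: a linear layer collapses all leading dimensions of its input into rows, so "6144x256 and 512x512" says the input's last dimension was 256 where the 512x512 weight expects 512, and that the leading dimensions multiply to 6144. The sketch below only does this arithmetic; the candidate shape is a guess based on batch_size=8 and 256x256 RGB inputs from the options above, not something confirmed by the traceback.

```python
from math import prod

def flattened_mat1(shape):
    """Rows x columns that a linear layer reports for an input of `shape`:
    every dimension except the last is collapsed into the row count."""
    return (prod(shape[:-1]), shape[-1])

# Hypothetical candidate: a batch of 8 RGB images at 256x256
image_batch = (8, 3, 256, 256)
print(flattened_mat1(image_batch))   # (6144, 256)

# For comparison, a batch of 8 latent vectors of width 512
latent_batch = (8, 512)
print(flattened_mat1(latent_batch))  # (8, 512)
```

If this reading is right, a tensor shaped like an image batch, rather than a 512-dimensional code, reached the layer at model.py line 150. Other shapes with the same product are possible, so this is only a hint about where to look.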

For the second error, you need to make sure that you set --output_size=256.
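The missing keys in the second traceback line up with the difference between a 1024 and a 256 generator, which is what one would expect if the model was built at a larger size than the checkpoint. The sketch below assumes the generator follows the common PyTorch StyleGAN2 layout (two convs and one toRGB per resolution above 4x4, plus one noise buffer per layer), and that the model was built at 1024; both are assumptions, and the helper names are hypothetical.

```python
from math import log2

def layer_counts(size):
    """Module counts per generator size under the ASSUMED StyleGAN2 layout."""
    log_size = int(log2(size))
    return {
        "convs": 2 * (log_size - 2),
        "to_rgbs": log_size - 2,
        "noises": 2 * (log_size - 2) + 1,
    }

def missing_indices(model_size, ckpt_size):
    """Indices present in a model built at model_size but absent from
    a checkpoint trained at ckpt_size."""
    big, small = layer_counts(model_size), layer_counts(ckpt_size)
    return {name: list(range(small[name], big[name])) for name in big}

missing = missing_indices(1024, 256)
print(missing)
# {'convs': [12, 13, 14, 15], 'to_rgbs': [6, 7], 'noises': [13, 14, 15, 16]}
```

These are the same indices listed in the error (convs.12 to convs.15, to_rgbs.6 and to_rgbs.7, noises.noise_13 to noises.noise_16), which is consistent with the advice to pass --output_size=256.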

Try making these changes to see if this helps solve your issues.


tsshubhamv commented on August 22, 2024

I'm pretty sure that all the images are square and of size 256x256. Also, output_size is 256. I have tried these parameters once already; I'll try again.

