
Introduction

Santa: Unpaired Image-to-Image Translation With Shortest Path Regularization (CVPR2023)

Abstract

Unpaired image-to-image translation aims to learn proper mappings that can map images from one domain to another domain while preserving the content of the input image. However, with large enough capacities, the network can learn to map the inputs to any random permutation of images in another domain. Existing methods treat two domains as discrete and propose different assumptions to address this problem. In this paper, we start from a different perspective and consider the paths connecting the two domains. We assume that the optimal path length between the input and output image should be the shortest among all possible paths. Based on this assumption, we propose a new method to allow generating images along the path and present a simple way to encourage the network to find the shortest path without pair information. Extensive experiments on various tasks demonstrate the superiority of our approach.
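The shortest-path assumption can be illustrated with a toy computation (a conceptual sketch only, not the paper's actual loss): discretize a path between two points and score it by the sum of squared segment differences. With the endpoints fixed, the straight line minimizes this energy, so penalizing it pushes intermediate points toward the shortest path.

```python
import numpy as np

def path_energy(points):
    """Discrete path energy: sum of squared differences between consecutive
    samples along the path. With fixed endpoints, the straight line
    minimizes this quantity."""
    diffs = np.diff(points, axis=0)
    return float((diffs ** 2).sum())

# Straight-line path between two points in 4-D, sampled at 6 steps
a, b = np.zeros(4), np.ones(4)
ts = np.linspace(0.0, 1.0, 6)[:, None]
straight = a + ts * (b - a)

# Same endpoints, but with the interior points perturbed off the line
bent = straight.copy()
bent[1:-1] += 0.3

# path_energy(straight) = 0.8  (5 segments x 4 dims x 0.2^2),
# strictly smaller than path_energy(bent)
```

Minimizing this energy over the intermediate generations, without touching the endpoints, is the intuition behind the shortest-path regularizer described in the abstract.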


Basic Usage

  • Training:
python train.py --dataroot=datasets/cityscapes --direction=BtoA --lambda_path=0.1 --tag=santa 
  • Test: put the trained checkpoints into the folder checkpoints/cityscapes
python test.py --dataroot=datasets/cityscapes --name=cityscapes --direction=BtoA
  • Hyper-parameters: The default hyper-parameters should lead to good results. For better performance, try playing with --lambda_path, --path_layers, --path_interval_min and --path_interval_max.
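As a sketch, a hyper-parameter sweep over --lambda_path could look like this (the values and tag names below are arbitrary examples, not tuned recommendations):

```shell
# Illustrative sweep over --lambda_path; echoed as a dry run so the commands
# are printed rather than executed. Drop "echo" to actually train.
for lam in 0.05 0.1 0.2; do
  echo python train.py --dataroot=datasets/cityscapes --direction=BtoA \
    --lambda_path=$lam --tag=santa_lam$lam
done
```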

Pretrained Models

Dataset

The dataset is constructed from the UTKFace dataset. I then apply a super-resolution model and divide the output images into old/young according to age. The dataset contains 1500 training and 500 testing images for each domain.

The following shows the first six training images in each domain.

Citation

If you use this code for your research, please consider citing our paper:

@inproceedings{xie2023unpaired,
  title={Unpaired Image-to-Image Translation With Shortest Path Regularization},
  author={Xie, Shaoan and Xu, Yanwu and Gong, Mingming and Zhang, Kun},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={10177--10187},
  year={2023}
}

People

Contributors

mid-push


Issues

TypeError: SANTAModel.translate() takes 2 positional arguments but 3 were given

How do I handle this problem?

(epoch: 5, iters: 5000, time: 0.348, data: 0.002) G_GAN: 0.222 D_real: 0.411 D_fake: 0.198 G_rec: 0.092 G_idt: 0.042 G_kl: 1.145 G_path: 0.391 d1: 0.402 d2: 0.252 energy_0: 0.002 energy_3: 0.007 energy_6: 0.111 energy_10: 0.198 energy_14: 1.638 
saving the latest model (epoch 5, total_iters 25000)
santa/FLIR_AtoB/lam0.1_layers0,3,6,10,14_dim8_rec5_idt5.0_pool0_noise1.0_kl0.01
saving the model at the end of epoch 5, iters 25000
Traceback (most recent call last):
  File "/home/customer/Desktop/LT/gan/santa-main/train.py", line 86, in <module>
    results = eval_loader(model, test_loader_a, test_loader_b, opt.run_dir, opt)
  File "/root/usr/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/customer/Desktop/LT/gan/santa-main/models/utils.py", line 88, in eval_loader
    fake = model.translate(data['A'].cuda(), acc_data['A'].cuda())
  File "/root/usr/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
TypeError: SANTAModel.translate() takes 2 positional arguments but 3 were given
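A hypothetical sketch of the mismatch behind this TypeError, assuming translate() was defined to take a single input tensor while eval_loader passes two. Making the second argument optional is one way to satisfy both call sites; the real fix depends on what the repository's translate() should do with the extra tensor.

```python
class SANTAModel:
    def translate(self, x, acc_x=None):
        # A real model would run the generator here; this stub just echoes
        # the input. acc_x is the optional second argument that eval_loader
        # passes (its name is illustrative, not taken from the repo).
        return x

model = SANTAModel()
model.translate([1, 2])          # original one-argument call still works
model.translate([1, 2], [3, 4])  # eval_loader-style call no longer raises
```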

Training process error

First of all, thank you for your outstanding work. When I was training, I encountered the following error:
Traceback (most recent call last):
  File "train.py", line 18, in <module>
    fix_b = torch.stack([test_loader_b.dataset[i]['A'] for i in range(opt.display_size)]).cuda()
  File "train.py", line 18, in <listcomp>
    fix_b = torch.stack([test_loader_b.dataset[i]['A'] for i in range(opt.display_size)]).cuda()
  File "E:\Github project\santa-main\data\single_dataset.py", line 33, in __getitem__
    A_path = self.A_paths[index]
IndexError: list index out of range

The training command I used was 'python train.py --dataroot=datasets/multidataset --direction=AtoB --lambda_path=0.1 --tag=santa'

There are four subfolders under the multidataset folder: trainA, trainB, testA, and testB. The images in these folders are unaligned.

Can you tell me what the problem is?
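One hypothetical cause of this IndexError: if --display_size exceeds the number of images in the test folder, dataset[i] indexes past the end of the path list. Clamping the size to the dataset length avoids the crash (the names below are illustrative stand-ins, not the repo's actual variables):

```python
# Stand-in for test_loader_b.dataset, which here holds only 3 images
dataset = ["img_%d" % i for i in range(3)]
display_size = 8  # larger than the dataset, as in the reported crash

# Clamp before indexing so range() never exceeds the dataset length
safe_size = min(display_size, len(dataset))
fix_b = [dataset[i] for i in range(safe_size)]  # no IndexError
```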

Training error when setting the batch size greater than 1

Hello, thanks for your great work.
I am trying to run your code on my own dataset, but there is an error: when I set the batch size to a number greater than 1, the training process crashes, whereas when I set it to 1 it runs fine.

Do you know what the problem is?

Validation error

Error found during validation

saving the latest model (epoch 4, total_iters 5000)
santa/lotus_us_BtoA/lam0.1_layers0,3,6,10,14_dim8_rec5_idt5.0_pool0_noise1.0_kl0.01
(epoch: 4, iters: 858, time: 0.057, data: 0.001) G_GAN: 0.492 D_real: 0.166 D_fake: 0.182 G_rec: 0.063 G_idt: 0.089 G_kl: 1.086 G_path: 0.287 d1: 0.677 d2: 0.487 energy_0: 0.014 energy_3: 0.097 energy_6: 0.327 energy_10: 0.099 energy_14: 0.900
(epoch: 4, iters: 958, time: 0.057, data: 0.001) G_GAN: 0.407 D_real: 0.171 D_fake: 0.218 G_rec: 0.054 G_idt: 0.033 G_kl: 1.418 G_path: 0.226 d1: 0.820 d2: 0.658 energy_0: 0.014 energy_3: 0.116 energy_6: 0.328 energy_10: 0.105 energy_14: 0.566
(epoch: 4, iters: 1058, time: 0.057, data: 0.001) G_GAN: 0.344 D_real: 0.301 D_fake: 0.129 G_rec: 0.056 G_idt: 0.058 G_kl: 1.502 G_path: 0.176 d1: 0.525 d2: 0.371 energy_0: 0.014 energy_3: 0.107 energy_6: 0.362 energy_10: 0.105 energy_14: 0.292
(epoch: 4, iters: 1158, time: 0.057, data: 0.001) G_GAN: 0.407 D_real: 0.256 D_fake: 0.160 G_rec: 0.048 G_idt: 0.036 G_kl: 1.273 G_path: 0.161 d1: 0.188 d2: 0.039 energy_0: 0.016 energy_3: 0.119 energy_6: 0.362 energy_10: 0.103 energy_14: 0.207
(epoch: 4, iters: 1258, time: 0.057, data: 0.001) G_GAN: 0.605 D_real: 0.082 D_fake: 0.069 G_rec: 0.052 G_idt: 0.038 G_kl: 1.347 G_path: 0.389 d1: 0.986 d2: 0.884 energy_0: 0.018 energy_3: 0.198 energy_6: 0.760 energy_10: 0.137 energy_14: 0.831
(epoch: 4, iters: 1358, time: 0.057, data: 0.001) G_GAN: 0.649 D_real: 0.287 D_fake: 0.209 G_rec: 0.046 G_idt: 0.045 G_kl: 1.388 G_path: 0.317 d1: 0.922 d2: 0.772 energy_0: 0.016 energy_3: 0.171 energy_6: 0.618 energy_10: 0.150 energy_14: 0.631
End of epoch 4 / 400 Time Taken: 82 sec
learning rate = 0.0002000
(epoch: 5, iters: 44, time: 0.057, data: 0.001) G_GAN: 0.399 D_real: 0.188 D_fake: 0.151 G_rec: 0.046 G_idt: 0.045 G_kl: 1.519 G_path: 0.190 d1: 0.161 d2: 0.000 energy_0: 0.016 energy_3: 0.111 energy_6: 0.393 energy_10: 0.127 energy_14: 0.305
(epoch: 5, iters: 144, time: 0.057, data: 0.001) G_GAN: 0.418 D_real: 0.354 D_fake: 0.280 G_rec: 0.051 G_idt: 0.038 G_kl: 1.374 G_path: 0.198 d1: 0.237 d2: 0.060 energy_0: 0.016 energy_3: 0.132 energy_6: 0.455 energy_10: 0.137 energy_14: 0.252
(epoch: 5, iters: 244, time: 0.057, data: 0.001) G_GAN: 0.291 D_real: 0.194 D_fake: 0.380 G_rec: 0.035 G_idt: 0.044 G_kl: 1.517 G_path: 0.255 d1: 0.396 d2: 0.240 energy_0: 0.017 energy_3: 0.134 energy_6: 0.423 energy_10: 0.158 energy_14: 0.545
(epoch: 5, iters: 344, time: 0.057, data: 0.001) G_GAN: 0.346 D_real: 0.205 D_fake: 0.207 G_rec: 0.047 G_idt: 0.030 G_kl: 1.311 G_path: 0.199 d1: 0.384 d2: 0.271 energy_0: 0.015 energy_3: 0.122 energy_6: 0.334 energy_10: 0.130 energy_14: 0.395
(epoch: 5, iters: 444, time: 0.057, data: 0.001) G_GAN: 0.343 D_real: 0.340 D_fake: 0.134 G_rec: 0.053 G_idt: 0.043 G_kl: 1.672 G_path: 0.280 d1: 0.786 d2: 0.633 energy_0: 0.015 energy_3: 0.129 energy_6: 0.433 energy_10: 0.157 energy_14: 0.666
(epoch: 5, iters: 544, time: 0.057, data: 0.001) G_GAN: 0.412 D_real: 0.159 D_fake: 0.238 G_rec: 0.057 G_idt: 0.044 G_kl: 1.481 G_path: 0.177 d1: 0.273 d2: 0.113 energy_0: 0.015 energy_3: 0.092 energy_6: 0.274 energy_10: 0.136 energy_14: 0.370
(epoch: 5, iters: 644, time: 0.057, data: 0.001) G_GAN: 0.410 D_real: 0.121 D_fake: 0.418 G_rec: 0.040 G_idt: 0.027 G_kl: 1.325 G_path: 0.220 d1: 0.187 d2: 0.047 energy_0: 0.014 energy_3: 0.092 energy_6: 0.284 energy_10: 0.151 energy_14: 0.559
(epoch: 5, iters: 744, time: 0.057, data: 0.001) G_GAN: 0.265 D_real: 0.302 D_fake: 0.245 G_rec: 0.040 G_idt: 0.056 G_kl: 1.390 G_path: 0.331 d1: 0.837 d2: 0.703 energy_0: 0.014 energy_3: 0.145 energy_6: 0.588 energy_10: 0.154 energy_14: 0.754
(epoch: 5, iters: 844, time: 0.057, data: 0.001) G_GAN: 0.487 D_real: 0.140 D_fake: 0.250 G_rec: 0.054 G_idt: 0.078 G_kl: 1.474 G_path: 0.368 d1: 0.398 d2: 0.211 energy_0: 0.015 energy_3: 0.094 energy_6: 0.308 energy_10: 0.148 energy_14: 1.276
(epoch: 5, iters: 944, time: 0.057, data: 0.001) G_GAN: 0.258 D_real: 0.215 D_fake: 0.233 G_rec: 0.052 G_idt: 0.068 G_kl: 1.583 G_path: 0.251 d1: 0.726 d2: 0.529 energy_0: 0.015 energy_3: 0.111 energy_6: 0.358 energy_10: 0.152 energy_14: 0.619
(epoch: 5, iters: 1044, time: 0.057, data: 0.001) G_GAN: 0.585 D_real: 0.095 D_fake: 0.179 G_rec: 0.049 G_idt: 0.032 G_kl: 1.484 G_path: 0.327 d1: 0.753 d2: 0.610 energy_0: 0.013 energy_3: 0.130 energy_6: 0.434 energy_10: 0.212 energy_14: 0.846
(epoch: 5, iters: 1144, time: 0.057, data: 0.001) G_GAN: 0.445 D_real: 0.186 D_fake: 0.252 G_rec: 0.051 G_idt: 0.030 G_kl: 1.357 G_path: 0.208 d1: 0.984 d2: 0.819 energy_0: 0.013 energy_3: 0.108 energy_6: 0.325 energy_10: 0.157 energy_14: 0.435
(epoch: 5, iters: 1244, time: 0.057, data: 0.001) G_GAN: 0.487 D_real: 0.263 D_fake: 0.048 G_rec: 0.038 G_idt: 0.037 G_kl: 1.577 G_path: 0.239 d1: 0.797 d2: 0.629 energy_0: 0.013 energy_3: 0.097 energy_6: 0.327 energy_10: 0.130 energy_14: 0.627
(epoch: 5, iters: 1344, time: 0.057, data: 0.001) G_GAN: 0.266 D_real: 0.465 D_fake: 0.124 G_rec: 0.033 G_idt: 0.047 G_kl: 1.320 G_path: 0.143 d1: 0.654 d2: 0.502 energy_0: 0.012 energy_3: 0.094 energy_6: 0.294 energy_10: 0.112 energy_14: 0.205
saving the model at the end of epoch 5, iters 7070
Traceback (most recent call last):
  File "santa/train.py", line 86, in <module>
    results = eval_loader(model, test_loader_a, test_loader_b, opt.run_dir, opt)
  File "torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "C1_ML_Analysis/src/santa/models/utils.py", line 88, in eval_loader
    fake = model.translate(data['A'].cuda(), acc_data['A'].cuda())
  File "torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
TypeError: SANTAModel.translate() takes 2 positional arguments but 3 were given
