JizhiziLi / AIM
[IJCAI'21] Deep Automatic Natural Image Matting
License: MIT License
Amazing work.
Hi @JizhiziLi
After downloading the JSON file, I found that the filenames for the DIM and HAttMatting datasets differ from those in the original datasets.
Could you please provide a filename mapping file?
Thank you!
like 4k 60fps?
I tested the model you provided on the aim-net500 dataset and found that the SAD was only 47. Can you check that the model you uploaded to GitHub is correct? In addition, did you do any other processing on the DIM dataset? In your dim_hatt_am2k_type.json file, the keys of the DIM dataset are the numbers 1, 2, 3, 4, ..., but the names in the DIM dataset are different. How did you rename them? I used the training_fg_names.txt file to rename them and retrained the model following the steps you provided on GitHub, and the SAD is only 55.
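The renaming step described in this issue can be sketched roughly as follows. It assumes, without confirmation from the authors, that key N in the JSON file corresponds to the N-th non-empty line of training_fg_names.txt (1-indexed). `build_key_to_name` and the sample filenames are hypothetical, introduced only for illustration.

```python
def build_key_to_name(name_lines):
    """Map 1-based string keys ("1", "2", ...) to filenames, in line order.

    ASSUMPTION: the numeric keys follow the line order of the names file.
    The thread does not confirm this.
    """
    # Drop blank lines and trailing newlines before numbering.
    names = [line.strip() for line in name_lines if line.strip()]
    return {str(index): name for index, name in enumerate(names, start=1)}


# Hypothetical stand-in for the lines of training_fg_names.txt.
sample_lines = ["first_image.jpg\n", "second_image.png\n", "\n"]
mapping = build_key_to_name(sample_lines)
```

If the ordering assumption does not hold, images would be paired with the wrong type labels, so it is worth spot-checking a few entries by eye before retraining.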
Thank you for the great new model! It produces really cool results.
It would be very interesting to play with the model and train it on my own datasets.
Could you please give an estimate of when you will release the training code, so I don't have to refresh the page several times a day :)
Where can I get this model? Thanks!
I trained the model on the DUTS dataset for 100 epochs (the default), but the loss and accuracy seem to be low:
INFO:root:AIM-Epoch[100/100](1/660) Lr:0.00010000 Loss:0.82403 Global:0.66851 Local:0.11801 Fusion-alpha:0.03752 Speed:4.63853s/iter Exa(h:m:s):00:50:56
INFO:root:AIM-Epoch[100/100](2/660) Lr:0.00010000 Loss:0.83574 Global:0.67843 Local:0.10950 Fusion-alpha:0.04781 Speed:3.84507s/iter Exa(h:m:s):00:42:10
INFO:root:AIM-Epoch[100/100](3/660) Lr:0.00010000 Loss:0.78562 Global:0.64933 Local:0.11512 Fusion-alpha:0.02116 Speed:3.55392s/iter Exa(h:m:s):00:38:54
INFO:root:AIM-Epoch[100/100](4/660) Lr:0.00010000 Loss:0.80663 Global:0.65181 Local:0.12981 Fusion-alpha:0.02502 Speed:3.40979s/iter Exa(h:m:s):00:37:16
INFO:root:Checkpoint saved to models/trained/aim_transfer_duts/ckpt_epoch100.pth
If our dataset contains SO only, how long should we train the network, i.e., if we are not planning to train further on synthetic datasets that contain STM and NS images?
Is there a benchmark or comparison with other models on the DUTS dataset?
Also, is the ResNet backbone frozen during training?
Hi, Jizhizi Li.
Many thanks for this amazing repo! It works very nicely on my images of clothes!
But it seems to me that the results can be further improved by additional training on individual cases.
Do you have plans to release the training code?
I have already seen the previous question on this topic, but I can't wait.
Thank you for the answer! Have a very nice day.
Hi
Are you planning to share some code here? (training/inference)
Thanks
How many iterations does it do in each epoch? From this line: https://github.com/JizhiziLi/AIM/blob/master/core/train.py#L81, it looks like it does only 4 iterations per epoch. Does this approach train the model on all images in one epoch?
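As a rough aid for this question: a PyTorch `DataLoader` yields one batch per iteration, so the number of iterations per epoch follows from the dataset size and batch size. The sketch below only reproduces that arithmetic. `iterations_per_epoch` is a hypothetical helper, and the values 10553 and 16 are illustrative assumptions, not settings confirmed by the repository.

```python
import math


def iterations_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a data loader yields in one pass over the dataset."""
    if drop_last:
        # The final incomplete batch is discarded.
        return num_samples // batch_size
    # The final incomplete batch is kept as a smaller batch.
    return math.ceil(num_samples / batch_size)


# Illustrative values only (assumptions, not taken from the repo config).
full = iterations_per_epoch(10553, 16)
trimmed = iterations_per_epoch(10553, 16, drop_last=True)
```

A counter such as `(1/660)` in a training log is consistent with this kind of calculation: the second number is the total number of iterations in the epoch, so every image is visited once per epoch as long as the loop runs over the whole loader.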
Can this model be converted to Core ML or TFLite, and can it run on an iOS or other mobile device?
Hi, I've been testing your model quite a bit and have noticed that some vertical images tend to be rotated to a horizontal orientation. I wonder why, and whether there is a way to turn this off, since they are rotated randomly clockwise or counterclockwise. An alternative would be to only analyse images that are already horizontal; however, I'm not sure whether that would have an impact on performance.
Hello, the following error occurred when I reproduced the code. How can I solve it?
Traceback (most recent call last):
  File "core/train.py", line 156, in <module>
    main()
  File "core/train.py", line 150, in main
    train(args, model, optimizer, train_loader, epoch)
  File "core/train.py", line 78, in train
    for iteration, batch in enumerate(train_loader, 1):
  File "/home/anaconda/envs/wxt-env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/anaconda/envs/wxt-env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "/home/anaconda/envs/wxt-env/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anaconda/envs/wxt-env/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/wxt/AIM-master/core/data.py", line 101, in __getitem__
    ori, fg, bg = generate_composite(fg, bg, mask)
  File "/home/wxt/AIM-master/core/util.py", line 111, in generate_composite
    composite = alpha * fg + (1 - alpha) * bg
ValueError: operands could not be broadcast together with shapes (900,1600,1) (900,1600)
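The shapes in this error, (900,1600,1) and (900,1600), suggest that the alpha matte has a trailing channel axis while one of the images is a 2-D grayscale array, which NumPy cannot broadcast against each other. A minimal sketch of one possible workaround is shown below. `to_three_channels` and `composite_safe` are hypothetical helpers, not functions from the repository, and whether a grayscale foreground or background is the actual cause here is an assumption.

```python
import numpy as np


def to_three_channels(img):
    """Return an (H, W, 3) array from an (H, W), (H, W, 1) or (H, W, C>=3) image."""
    img = np.asarray(img)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]       # add the missing channel axis
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)   # replicate gray into 3 channels
    return img[:, :, :3]                  # drop any extra channel (e.g. alpha)


def composite_safe(alpha, fg, bg):
    """Alpha-blend fg over bg after normalising all shapes."""
    alpha = np.asarray(alpha, dtype=np.float32)
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]   # (H, W) -> (H, W, 1)
    fg = to_three_channels(fg).astype(np.float32)
    bg = to_three_channels(bg).astype(np.float32)
    return alpha * fg + (1.0 - alpha) * bg


# Small stand-in arrays: a grayscale foreground triggers the original error.
alpha = np.full((4, 6, 1), 0.25, dtype=np.float32)
fg = np.ones((4, 6), dtype=np.float32)
bg = np.zeros((4, 6, 3), dtype=np.float32)
out = composite_safe(alpha, fg, bg)
```

An alternative under the same assumption is to convert every image to RGB at load time, so that the compositing code never sees a 2-D array.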
First: thanks for this amazing repository and research and for sharing it.
So I did test some images with the code/pre-trained model you released. I share some results below.
I have some very good results and some more disappointing ones (I am unreasonably demanding, sorry for that).
One question: I get some differences from the repo's results on the 3 default sample images (i.e. when running inference on 1.png, 2.png and 3.png). Is that due to differences between the model trained for the paper and the newly trained model?
Another question: do you know why sometimes there is 100% background detection, e.g. with the last two images below?
I am on CUDA 10.2 and PyTorch 1.9, FYI.
thanks
Hey there! Love the paper!
Can we access the pre-trained model?
Thanks!
I've tested my trimap-based model on AIM-500. However, the results are much better than those reported in your paper. It is reasonable that trimap-based models have advantages over trimap-free ones.
I wonder whether you did not fuse the foreground and background labels provided by the trimap during testing.
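For readers unfamiliar with the fusion step mentioned here, a common convention in trimap-based matting is to overwrite the prediction with the trimap's definite labels and keep the network output only in the unknown region. The sketch below assumes a trimap encoded as 0 (background), 255 (foreground) and any other value (unknown). `fuse_with_trimap` is a hypothetical helper, and the encoding is an assumption rather than something stated in this thread.

```python
import numpy as np


def fuse_with_trimap(pred_alpha, trimap):
    """Keep predictions only where the trimap is unknown.

    ASSUMPTION: trimap uses 0 = background, 255 = foreground,
    anything else = unknown region.
    """
    fused = np.asarray(pred_alpha, dtype=np.float32).copy()
    fused[trimap == 0] = 0.0      # definite background
    fused[trimap == 255] = 1.0    # definite foreground
    return fused


pred = np.array([[0.3, 0.6, 0.9]], dtype=np.float32)
tri = np.array([[0, 128, 255]], dtype=np.uint8)
fused = fuse_with_trimap(pred, tri)
```

With fusion applied, errors outside the unknown region drop to zero, which is one reason whole-image metrics from trimap-based and trimap-free methods are not directly comparable.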
Hi,
It's an amazing work.
Do you plan to share the code? (Training and inference code)
Thanks
Hello,
Thanks for the great work, I believe your research is really interesting!
I'd like to report issues with reproducing the results displayed in the README.
First, when running in the same environment as described in the README (with torch=1.4.0), I run into the following exception:
Traceback (most recent call last):
  File "core/test.py", line 321, in <module>
    load_model_and_deploy(args)
  File "core/test.py", line 298, in load_model_and_deploy
    ckpt = torch.load(args.model_path,map_location=torch.device('cpu'))
  File "/home/malrick/anaconda3/envs/aim-env/lib/python3.6/site-packages/torch/serialization.py", line 527, in load
    with _open_zipfile_reader(f) as opened_zipfile:
  File "/home/malrick/anaconda3/envs/aim-env/lib/python3.6/site-packages/torch/serialization.py", line 224, in __init__
    super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
RuntimeError: version_ <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /opt/conda/conda-bld/pytorch_1579022034529/work/caffe2/serialize/inline_container.cc:132, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 2. Your PyTorch installation may be too old. (init at /opt/conda/conda-bld/pytorch_1579022034529/work/caffe2/serialize/inline_container.cc:132)
So it seems that the pre-trained model can't be read with such an old version of PyTorch. The export of the conda environment is available here: aim-env.txt (just replace .txt extension with .yml, GitHub doesn't support hosting yml files)
I then ran the script with the same environment that @JizhiziLi described in #7 (PyTorch 1.7.1, and python 3.7.7).
Conda export is available here: aim-env-2.txt
I have set the parameter test_choice=Hybrid in file core/scripts/test_samples.sh, and I have the ratios set to global_ratio=1/4, local_ratio=1/2 in file core/test.py.
Also, as my RTX3090 is not supported by these older versions of PyTorch, I used CPU for inference.
To do so, I had to adjust certain parts of the code, first at line 292:
if torch.cuda.device_count()==0:
    print(f'Running on CPU...')
    args.cuda = False
    ckpt = torch.load(args.model_path,map_location=torch.device('cpu'))
else:
    print(f'Running on GPU with CUDA as {args.cuda}...')
    ckpt = torch.load(args.model_path)
to:
if not args.cuda or torch.cuda.device_count()==0:
    print(f'Running on CPU...')
    args.cuda = False
    ckpt = torch.load(args.model_path,map_location=torch.device('cpu'))
else:
    print(f'Running on GPU with CUDA as {args.cuda}...')
    ckpt = torch.load(args.model_path)
and line 42:
tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1).cuda()
to:
tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1)
if args.cuda:
    tensor_img = tensor_img.cuda()
Removing the --cuda parameter in scripts/test_samples.sh then allowed me to run inference on CPU.
The results I'm obtaining are the same as in issue #7.
I also ran inference in the same way on the AIM-500 dataset, and I obtained the following results in logs/test_logs/DEBUG.log:
INFO:root:Testing numbers: 500
INFO:root:SAD: 70.16285879981845
INFO:root:MSE: 0.033924558738880374
INFO:root:MAD: 0.041689636977158
INFO:root:SAD TRIMAP: 48.64482158505621
INFO:root:MSE TRIMAP: 0.09556289661012443
INFO:root:MAD TRIMAP: 0.13346166896791076
INFO:root:SAD FG: 17.866437117307104
INFO:root:SAD BG: 3.651600097455164
INFO:root:CONN: 67.30147147906628
INFO:root:GRAD: 58.72859108352662
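To help interpret the whole-image numbers in this log, here is a minimal sketch of how SAD, MSE and MAD are commonly computed for alpha mattes with values in [0, 1]. Dividing SAD by 1000 is a widespread convention in the matting literature. Whether the repository's evaluation code normalises in exactly this way is an assumption, and `whole_image_metrics` is a hypothetical helper.

```python
import numpy as np


def whole_image_metrics(pred, gt):
    """Return (SAD, MSE, MAD) for two alpha mattes with values in [0, 1].

    ASSUMPTION: SAD is reported in thousands, MSE and MAD are per-pixel means.
    """
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    abs_diff = np.abs(pred - gt)
    sad = abs_diff.sum() / 1000.0        # sum of absolute differences, in thousands
    mse = np.mean((pred - gt) ** 2)      # mean squared error per pixel
    mad = abs_diff.mean()                # mean absolute difference per pixel
    return sad, mse, mad


# Worst case on a 10x10 image: every pixel is off by 1.
sad, mse, mad = whole_image_metrics(np.zeros((10, 10)), np.ones((10, 10)))
```

Under this convention SAD grows with image resolution while MSE and MAD do not, so SAD values are only comparable when evaluated at the same image sizes.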
The artifacts in the results are really surprising, and they seem to be introduced at quite a low resolution.
I hope this helps, and that we can find a way to get this great repository working!