
Neural Manifold Clustering and Embedding (NMCE)

Code for the paper Neural Manifold Clustering and Embedding (NMCE).


NMCE uses the MCR2 (Maximal Coding Rate Reduction) objective to perform unsupervised clustering and embedding learning on non-linear manifolds. Data augmentation is used to enforce a constraint that makes the problem tractable. In the toy example (a double spiral), the locality constraint is enforced with Gaussian noise augmentation; for image clustering, standard self-supervised learning augmentations are used.
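For reference, the MCR2 objective from the original MCR2 work can be sketched in NumPy. Each row of Z is one d-dimensional embedding, and Pi is an n x k membership matrix whose column j weights how strongly each sample belongs to cluster j. The objective Delta R = R(Z) - R_c(Z | Pi) rewards embeddings whose overall coding rate is large (the representation is spread out) while the coding rate of each cluster is small (each cluster is compressed). This is a minimal sketch, not the code in this repo; the function names are hypothetical, and reading the --eps flag used in the commands below as the epsilon here is an assumption.

```python
import numpy as np


def coding_rate(Z: np.ndarray, eps: float) -> float:
    """R(Z) = 1/2 * logdet(I + d / (n * eps^2) * Z^T Z) for Z of shape (n, d)."""
    n, d = Z.shape
    _, logdet = np.linalg.slogdet(np.eye(d) + (d / (n * eps**2)) * (Z.T @ Z))
    return 0.5 * logdet


def compressed_rate(Z: np.ndarray, Pi: np.ndarray, eps: float) -> float:
    """R_c(Z | Pi): membership-weighted sum of the coding rates of each cluster."""
    n, d = Z.shape
    total = 0.0
    for j in range(Pi.shape[1]):
        w = Pi[:, j]                     # membership weights of cluster j
        size = w.sum()                   # tr(Pi_j), the (soft) cluster size
        if size < 1e-12:                 # skip empty clusters
            continue
        gram = Z.T @ (w[:, None] * Z)    # Z^T diag(w) Z
        _, logdet = np.linalg.slogdet(np.eye(d) + (d / (size * eps**2)) * gram)
        total += size / (2 * n) * logdet
    return total


def mcr2(Z: np.ndarray, Pi: np.ndarray, eps: float) -> float:
    """Coding rate reduction Delta R = R(Z) - R_c(Z | Pi); larger is better."""
    return coding_rate(Z, eps) - compressed_rate(Z, Pi, eps)


# Two clusters lying on orthogonal 1-D subspaces of R^4
Z = np.array([[1.0, 0, 0, 0], [1.0, 0, 0, 0], [0, 1.0, 0, 0], [0, 1.0, 0, 0]])
Pi_true = np.array([[1.0, 0], [1.0, 0], [0, 1.0], [0, 1.0]])
Pi_single = np.ones((4, 1))
print(mcr2(Z, Pi_true, eps=0.5))    # ln 9 - 0.5 ln 17, about 0.78
print(mcr2(Z, Pi_single, eps=0.5))  # 0 up to rounding: a single cluster compresses nothing
```

With a hard assignment every row of Pi is one-hot; soft memberships (rows summing to 1) fit the same formula, which lets Pi be produced by a differentiable model.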

Requirements

scipy
numpy
torch
sklearn
matplotlib
torchvision
tqdm

Experiments

The double spiral toy example is included in NMCE_toy.
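The role of Gaussian noise augmentation in the toy setting can be illustrated with a hypothetical data generator. Each augmented view is a small random perturbation of a point, and requiring all views of a point to land in the same cluster only ties together points that are close in input space, which is the locality constraint described above. The spiral parametrization and the noise, sigma and n_views values are assumptions; the actual NMCE_toy code may generate its data differently.

```python
import numpy as np


def double_spiral(n_per_arm: int, noise: float = 0.0, seed: int = 0):
    """Two interleaved spiral arms in 2-D with labels 0 and 1 (hypothetical generator)."""
    rng = np.random.default_rng(seed)
    t = np.sqrt(rng.uniform(0.0, 1.0, n_per_arm)) * 3 * np.pi   # angle along the arm
    arm = np.stack([t * np.cos(t), t * np.sin(t)], axis=1)
    X = np.concatenate([arm, -arm], axis=0)                       # second arm = first rotated by 180 degrees
    y = np.concatenate([np.zeros(n_per_arm, dtype=int), np.ones(n_per_arm, dtype=int)])
    X = X + noise * rng.standard_normal(X.shape)
    return X, y


def gaussian_views(X: np.ndarray, sigma: float, n_views: int, seed: int = 0) -> np.ndarray:
    """Return n_views independently perturbed copies of X, shape (n_views, n, dim)."""
    rng = np.random.default_rng(seed)
    return X[None, :, :] + sigma * rng.standard_normal((n_views, *X.shape))


X, y = double_spiral(500, noise=0.05)
views = gaussian_views(X, sigma=0.1, n_views=4)
print(X.shape, views.shape)  # (1000, 2) (4, 1000, 2)
```

A larger sigma links points that are farther apart, so the noise scale effectively sets how local the constraint is.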

COIL-20 and COIL-100:

First, decompress the original datasets, then run the following command to convert them:

python NMCE/convert_imgs.py --COIL_20_path /path_to_coil-20 --COIL_100_path /path_to_coil-100
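The internals of convert_imgs.py are not described here. As a hedged illustration of one step such a conversion typically needs, the sketch below reads class labels off the image file names, assuming the objK__N.png naming of the COIL downloads (check your extracted folders). Both function names and the regular expression are hypothetical.

```python
import re
from pathlib import Path

# Assumed naming scheme: "obj<object id>__<image index>.png", e.g. "obj12__35.png"
COIL_NAME = re.compile(r"^obj(\d+)__(\d+)\.png$")


def parse_coil_name(name: str) -> tuple[int, int]:
    """Return (object id, image index) parsed from a COIL file name."""
    match = COIL_NAME.match(name)
    if match is None:
        raise ValueError(f"unexpected COIL file name: {name!r}")
    return int(match.group(1)), int(match.group(2))


def list_coil_images(root: str) -> list[tuple[Path, int]]:
    """List (path, 0-based class label) pairs for every COIL image under root."""
    items = []
    for path in sorted(Path(root).glob("obj*__*.png")):
        obj_id, _ = parse_coil_name(path.name)
        items.append((path, obj_id - 1))   # object ids start at 1
    return items


print(parse_coil_name("obj12__35.png"))  # (12, 35)
```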

Train COIL-20:

python NMCE/train_COIL20.py

Train COIL-100 (will take a while):

python NMCE/train_COIL100.py

CIFAR-10 with ResNet-18:

To select which GPU(s) to use, pass a list of GPU indices to the --gpu_ids argument (for example, --gpu_ids [0] as in the commands below). Specify the path to the dataset with the --data_dir argument.
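Because --gpu_ids takes a list written as a single token such as [0], one way to accept it with argparse is to parse that token as a Python literal. The sketch below is hypothetical (the repo's own argument parsing may differ), and parse_id_list and the default values are illustrative only.

```python
import argparse
import ast


def parse_id_list(text: str) -> list[int]:
    """Parse a token such as "[0]" or "[0,1]" into a list of ints (hypothetical helper)."""
    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        raise argparse.ArgumentTypeError(f"cannot parse {text!r} as a list") from None
    if isinstance(value, int):
        value = [value]
    if not isinstance(value, (list, tuple)) or not all(isinstance(v, int) for v in value):
        raise argparse.ArgumentTypeError(f"expected a list of ints, got {text!r}")
    return list(value)


parser = argparse.ArgumentParser()
parser.add_argument("--gpu_ids", type=parse_id_list, default=[0])
parser.add_argument("--data_dir", type=str, default="./data")

args = parser.parse_args(["--gpu_ids", "[0,1]", "--data_dir", "../../data/"])
print(args.gpu_ids, args.data_dir)  # [0, 1] ../../data/
```

With a parser like this, keep the list as one shell word: an unquoted [0, 1] is split by the shell into two arguments, and an unquoted [0] can be expanded as a glob pattern if a file named 0 exists in the current directory, so quoting it (for example '[0,1]') is the safest form.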

Stage 1: Self-supervised learning with TCR (Total Coding Rate) objective:

python NMCE/train_selfsup.py --arch resnet18cifar --data cifar10 --data_dir ../../data/ --aug_name cifar_simclr_norm --loss totalcodingrate --z_dim 128 --epo 600 --bs 1024 --lr 0.3 --wd 1e-4 --eps 0.2 --z_weight 30. --gpu_ids [0] --fp16 --doc tcr_zw30
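Stage 1 maximizes the total coding rate R(Z) of the embeddings. Because R(Z) grows as the embeddings spread over more directions, maximizing it counteracts representational collapse. Below is a minimal NumPy sketch of that anti-collapse term, assuming unit-normalized embeddings; it leaves out any additional terms of the repo's loss (for example, whatever the --z_weight flag weights), and the function names are hypothetical.

```python
import numpy as np


def total_coding_rate(Z: np.ndarray, eps: float) -> float:
    """R(Z) = 1/2 * logdet(I + d / (n * eps^2) * Z^T Z) on row-normalized Z of shape (n, d)."""
    Z = Z / np.maximum(np.linalg.norm(Z, axis=1, keepdims=True), 1e-12)  # project onto the unit sphere
    n, d = Z.shape
    _, logdet = np.linalg.slogdet(np.eye(d) + (d / (n * eps**2)) * (Z.T @ Z))
    return 0.5 * logdet


def tcr_loss(Z: np.ndarray, eps: float = 0.2) -> float:
    """Anti-collapse part of the stage-1 objective: minimizing -R(Z) maximizes R(Z)."""
    return -total_coding_rate(Z, eps)


collapsed = np.tile([1.0, 0.0, 0.0, 0.0], (8, 1))   # every sample mapped to the same point
spread = np.tile(np.eye(4), (2, 1))                 # samples spread over all 4 directions
print(tcr_loss(collapsed), tcr_loss(spread))        # about -2.31 and -6.52
```

In a training loop the same expression would be written with a differentiable log-determinant (for example torch.logdet) so that the loss can be back-propagated.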

Evaluate the network trained with TCR:

python NMCE/evaluate.py --arch resnet18cifar --data cifar10 --aug_name cifar_simclr_norm --feature_type proj --z_dim 128 --load_ep 600 --gpu_ids [0] --aug_avg 16 --doc tcr_zw30 --svm --knn --nearsub
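The --nearsub option presumably refers to a nearest-subspace classifier, a standard evaluation in the MCR2 line of work: fit a low-dimensional subspace to each class's training features and assign a test feature to the class whose subspace leaves the smallest projection residual. The sketch below works under that assumption; the function names and the choice of uncentered subspaces are illustrative, not taken from evaluate.py.

```python
import numpy as np


def fit_subspaces(Z: np.ndarray, y: np.ndarray, n_components: int) -> dict:
    """Fit an uncentered n_components-dimensional subspace to each class's features."""
    bases = {}
    for c in np.unique(y):
        # Top right singular vectors span the dominant directions of class c
        _, _, Vt = np.linalg.svd(Z[y == c], full_matrices=False)
        bases[c] = Vt[:n_components].T          # (d, n_components), orthonormal columns
    return bases


def predict_nearest_subspace(Z: np.ndarray, bases: dict) -> np.ndarray:
    """Assign each row of Z to the class whose subspace gives the smallest residual."""
    classes = sorted(bases)
    residuals = np.stack(
        [np.linalg.norm(Z - Z @ bases[c] @ bases[c].T, axis=1) for c in classes], axis=1
    )
    return np.array(classes)[np.argmin(residuals, axis=1)]


rng = np.random.default_rng(0)
Z_train = np.zeros((40, 4))
Z_train[:20, :2] = rng.standard_normal((20, 2))   # class 0 lives in span(e1, e2)
Z_train[20:, 2:] = rng.standard_normal((20, 2))   # class 1 lives in span(e3, e4)
y_train = np.array([0] * 20 + [1] * 20)

bases = fit_subspaces(Z_train, y_train, n_components=2)
print(predict_nearest_subspace(np.array([[1.0, 1.0, 0, 0], [0, 0, 1.0, -1.0]]), bases))  # [0 1]
```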

Stage 2: Clustering with the backbone frozen:

python NMCE/train_clustering.py --arch resnet18cifar --data cifar10 --data_dir ../../data/ --aug_name cifar_simclr_norm --z_dim 128 --epo 100 --bs 1024 --lr 0.3 --eps 0.2 --z_weight 0. --wd1 0.005 --wd2 0.005 --gpu_ids [0] --doc tcr_zw30 --load_ep 600 --seed 42

Stage 3: Fine-tune the backbone with the full NMCE objective:

python NMCE/train_clustering.py --arch resnet18cifar --data cifar10 --data_dir ../../data/ --aug_name cifar_simclr_norm --z_dim 128 --epo 100 --bs 1024 --lr 0.003 --eps 0.2 --z_weight 0. --wd1 0.0001 --wd2 0.0001 --gpu_ids [0] --doc tcr_zw30 --load_ep 700 --train_backbone 

The final clustering accuracy should be about 83%.
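Because cluster ids are arbitrary, clustering accuracy is conventionally computed after finding the best one-to-one mapping between predicted clusters and ground-truth classes with the Hungarian algorithm. A minimal sketch using SciPy's linear_sum_assignment (scipy is already a requirement); cluster_accuracy is a hypothetical name, and the repo's own metric code may differ in details.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def cluster_accuracy(y_true, y_pred) -> float:
    """Accuracy after the best one-to-one mapping of cluster ids to class labels."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    k = int(max(y_true.max(), y_pred.max())) + 1
    counts = np.zeros((k, k), dtype=int)
    np.add.at(counts, (y_pred, y_true), 1)       # counts[p, t] = samples in cluster p with label t
    rows, cols = linear_sum_assignment(counts, maximize=True)  # Hungarian matching
    return counts[rows, cols].sum() / y_true.size


print(cluster_accuracy([0, 0, 1, 1, 2, 2], [2, 2, 0, 0, 1, 1]))  # 1.0 (same partition, relabeled)
print(cluster_accuracy([0, 0, 1, 1], [0, 1, 1, 1]))              # 0.75
```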

Acknowledgment:

This repo borrows significantly from the MCR2 and solo-learn repos.

