Convolutional Adaptive Logic Networks (CALNs)

Introduction

This repository provides the minimum source code necessary for creating, training, and evaluating CALNs, as documented in the paper: Convolutional Adaptive Logic Networks: A First Approach.

Instructions for usage

Requirements

The code in this repository is tested under the following requirements:

  • python 3.9
  • pytorch: 1.11.0 (stable)
  • cuda: 10.2
  • tensorboard: 2.0.0
  • scikit-learn: 1.0.2
  • matplotlib: 3.5.1
  • prettytable: 3.2.0
  • wandb: 0.12.15 (only required if WandB is used for logging and monitoring)

Installation

This is a development codebase, so it is suggested to install the code in the caln subdirectory as an editable (development) Python package:

cd caln
pip install -e .

General usage and guidelines

The main class that implements CALN is ConvolutionALNet in caln/alntorch/core/caln.py. A sample usage is given in caln/alntorch/trainings/models.py. In short, ConvolutionALNet receives a backbone and attaches an ALN to its output. Note that the ALN weights are not updated by gradient descent; however, gradients are still propagated through the ALN back to the backbone weights.

  • The forward method receives a tensor and runs the entire network to produce the output.
  • The adapt method should be called at each training iteration to update the ALN weights.
  • The grow method should be called at a split iteration.
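
The training-loop pattern these methods imply can be sketched as follows. Note that ToyCALN, its pieces attribute, and all shapes here are made-up stand-ins for illustration; only the forward/adapt/grow call pattern mirrors the description above, not the repo's actual API.

```python
import torch
import torch.nn as nn

class ToyCALN(nn.Module):
    """Illustrative stand-in for ConvolutionALNet (not the real class)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)  # gradient-trained part
        self.pieces = 1                  # hypothetical ALN size counter

    def forward(self, x):
        # In the real class this runs the backbone and the attached ALN.
        return self.backbone(x)

    def adapt(self, x, y):
        # In the real class this updates ALN weights without gradient descent.
        pass

    def grow(self):
        # Called at a split iteration to grow the ALN.
        self.pieces += 1

model = ToyCALN()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
split_step = 15  # mirrors the --split_step option below

for it in range(30):
    x, y = torch.randn(2, 8), torch.randn(2, 4)
    out = model(x)                  # forward: run the entire network
    loss = ((out - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()                 # gradients flow through the ALN to the backbone
    opt.step()
    model.adapt(x, y)               # adapt: non-gradient ALN update each iteration
    if it > 0 and it % split_step == 0:
        model.grow()                # grow: called at a split iteration
```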

Experiments

Running the experiments from the paper using the commands below assumes that the current directory is caln/alntorch/trainings.

The main script that trains a variety of CALNs on the CIFAR-10 dataset is train_cifar10.py. The command-line options for this script are described in common_utils.py. You can also see the list by running $ python train_cifar10.py at the command line.

Note that train_cifar10.py can also be used to train a couple of ResNet architectures. You can use the following commands to approximately reproduce the results reported in the paper (all experiments run on a GPU):

  • ResNet13+ALN: python train_cifar10.py --name CALN_ResNet13_Cifar10 --model CALN --optimizer SGD --epochs 1000 --lr 0.1 --aln_lr 0.01 --init_pieces 3 --root_op min --split_step 15 --max_splits 1 --split_step_increment 2 --device cuda:0

  • ResNet14: python train_cifar10.py --name ResNet14_Cifar10 --model ResNet14 --optimizer SGD --epochs 1000 --device cuda:0

  • ResNet18: python train_cifar10.py --name ResNet18_Cifar10 --model ResNet18 --optimizer SGD --epochs 1000 --device cuda:0

The training logs are stored under the default (run) folder at ./[NAME]/; e.g., for ResNet13+ALN they are stored in ./CALN_ResNet13_Cifar10/. Use --logdir to change the default ./. To prevent losing past experiments, existing directories cannot be overwritten. The only directory that can be overwritten automatically is test; that is, you can use test as the experiment name, and previous test experiments will be overwritten.

The CIFAR-10 dataset is looked for in the root folder ./; if it does not exist, it is downloaded automatically. To change the path for the CIFAR-10 dataset, use the --cifar10_path [NEW_PATH] argument.

Monitoring training with Tensorboard or WandB

To monitor the training process, the default tool is Tensorboard. However, WandB is also supported, and we strongly suggest using it, since hyperparameters and additional details are logged with WandB. An example usage with WandB is:

python train_cifar10.py --name CALN_ResNet13_Cifar10 --model CALN --optimizer SGD --epochs 1000 --lr 0.1 --aln_lr 0.01 --init_pieces 3 --root_op min --split_step 15 --max_splits 1 --split_step_increment 2 --device cuda:0 --logger wandb --wandb_project [PROJECT_NAME] --wandb_entity [ENTITY_NAME]

where [PROJECT_NAME] and [ENTITY_NAME] must be set appropriately (refer to the WandB documentation).

Known issues and suggestions for improvements

The main bottleneck of the current implementation is the use of for-loops when evaluating and adapting the ALNs. Because the ALNs are independent, it is in principle straightforward to run them in parallel. This would significantly reduce training and evaluation time, and would also make it possible to run CALNs on datasets with a larger number of classes.
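
As a rough illustration of the kind of vectorization this suggests: since the per-class ALNs are independent, a per-ALN Python loop can be replaced by a single batched evaluation. The linear-pieces-plus-min structure below is a deliberate simplification for illustration, not the repo's actual ALN layout.

```python
import torch

# Hypothetical setup: one ALN per class, each with a few linear pieces
# combined by a min (a simplified ALN, not the repo's real structure).
num_classes, num_pieces, feat = 10, 3, 64
W = torch.randn(num_classes, num_pieces, feat)  # piece weights per ALN
b = torch.randn(num_classes, num_pieces)        # piece biases per ALN

x = torch.randn(32, feat)  # a batch of backbone features

# Loop version: evaluate each class's ALN separately (the bottleneck pattern).
loop_out = torch.stack(
    [(x @ W[c].T + b[c]).min(dim=1).values for c in range(num_classes)],
    dim=1,
)

# Batched version: one einsum evaluates all ALNs at once.
batched_out = (torch.einsum("bf,cpf->bcp", x, W) + b).min(dim=-1).values
```

Both versions produce a (batch, num_classes) tensor with identical values; the batched form avoids the per-class Python loop entirely.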

We have not tested any form of training and evaluation across multiple GPUs (model or data parallel).

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.
