This repository implements the minimum source code necessary for creating, training, and evaluating CALNs, as documented in the paper "Convolutional Adaptive Logic Networks: A First Approach".
The code in this repository is tested under the following requirements:
- python 3.9
- pytorch: 1.11.0 (stable)
- cuda: 10.2
- tensorboard: 2.0.0
- scikit-learn: 1.0.2
- matplotlib: 3.5.1
- prettytable: 3.2.0
- wandb: 0.12.15 (only required if WandB is used for logging and monitoring)
This is a development codebase. Therefore, it is suggested to install the code in the caln subdirectory as a Python development package:
cd caln
pip install -e .
The main class that implements CALN is ConvolutionALNet in caln/alntorch/core/caln.py; a sample usage is given in caln/alntorch/trainings/models.py. Basically, ConvolutionALNet receives a backbone and attaches an ALN at its end. Note that the ALN weights are not updated by a gradient-descent method; however, gradients are propagated through the ALN back to the backbone weights. The forward method receives a tensor and runs the entire network to produce the output. adapt should be called at each training iteration to update the ALN weights, and grow should be called at a split iteration.
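To make the call pattern concrete, the toy sketch below mimics the training-loop usage just described. Only the method names forward, adapt, and grow come from ConvolutionALNet; the class body, the placeholder logic, and the split schedule are illustrative assumptions, not the actual implementation.

```python
# Toy, framework-free stand-in for ConvolutionALNet's call pattern.
# All internals here are placeholders; only forward/adapt/grow are real names.
class ToyCALN:
    def __init__(self, init_pieces=3):
        self.pieces = init_pieces   # linear pieces per ALN (cf. --init_pieces)

    def forward(self, x):
        # Placeholder for: backbone features -> ALN evaluation.
        return sum(x) / len(x)

    def adapt(self, x, y):
        # Placeholder for the ALN update rule (not gradient descent).
        pass

    def grow(self):
        # Placeholder split; we assume each linear piece splits in two.
        self.pieces *= 2


net = ToyCALN(init_pieces=3)
for it in range(1, 31):
    x, y = [float(it), 1.0], float(it)
    out = net.forward(x)      # run backbone + ALN
    net.adapt(x, y)           # update ALN weights every iteration
    if it % 15 == 0:          # hypothetical split schedule (cf. --split_step 15)
        net.grow()            # split at a split iteration
```

The point is only the ordering: forward and adapt run every iteration, while grow fires only at split iterations.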
Running the experiments from the paper using the commands below assumes that the current directory is caln/alntorch/trainings.
The main code that trains a variety of CALNs on the CIFAR-10 dataset is in train_cifar10.py. The command-line input options to this script are described in common_utils.py. You can also see the list by entering $ python train_cifar10.py at the command line.
Note that train_cifar10.py can also be used to train a couple of ResNet architectures. You can use the following command lines to approximately reproduce the results reported in the paper (all experiments run on a GPU):
- ResNet13+ALN:
  python train_cifar10.py --name CALN_ResNet13_Cifar10 --model CALN --optimizer SGD --epochs 1000 --lr 0.1 --aln_lr 0.01 --init_pieces 3 --root_op min --split_step 15 --max_splits 1 --split_step_increment 2 --device cuda:0
- ResNet14:
  python train_cifar10.py --name ResNet14_Cifar10 --model ResNet14 --optimizer SGD --epochs 1000 --device cuda:0
- ResNet18:
  python train_cifar10.py --name ResNet18_Cifar10 --model ResNet18 --optimizer SGD --epochs 1000 --device cuda:0
The training logs are stored under the default (running) folder at ./[NAME]/; e.g., for ResNet13+ALN they are stored in ./CALN_ResNet13_Cifar10/. Use --logdir to change the default ./. To prevent losing past experiments, existing directories cannot be overwritten. The only directory that is overwritten automatically is test; that is, one can use test as the experiment name, and previous test experiments are overwritten.
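The overwrite rule above can be summarized with a minimal sketch; this is not the repository's actual code, only one way the described behavior could be implemented, with prepare_logdir a hypothetical helper name.

```python
import os

def prepare_logdir(logdir, name):
    """Refuse an existing run directory unless it is the reusable name 'test'."""
    path = os.path.join(logdir, name)
    if os.path.isdir(path) and name != "test":
        # Existing experiments are protected from accidental overwriting.
        raise FileExistsError(f"{path} exists; choose a new --name")
    os.makedirs(path, exist_ok=True)
    return path
```

Any fresh name creates a new directory; rerunning with the same name fails, while the name test can be reused freely.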
The CIFAR-10 dataset is sought in the root folder ./; if it does not exist, it is automatically downloaded. To change the path for the CIFAR-10 dataset, use the --cifar10_path [NEW_PATH] argument.
To monitor the training process, the default tool is TensorBoard. However, WandB is also supported, and we strongly suggest using WandB because hyperparameters and more details are logged with it. An example usage with WandB is:
python train_cifar10.py --name CALN_ResNet13_Cifar10 --model CALN --optimizer SGD --epochs 1000 --lr 0.1 --aln_lr 0.01 --init_pieces 3 --root_op min --split_step 15 --max_splits 1 --split_step_increment 2 --device cuda:0 --logger wandb --wandb_project [PROJECT_NAME] --wandb_entity [ENTITY_NAME]
where [PROJECT_NAME] and [ENTITY_NAME] must be set properly (refer to the WandB documentation).
The main bottleneck of the current implementation is the use of a for-loop while evaluating and adapting the ALNs. Because the ALNs are independent, in principle it is straightforward to run them in parallel, which would significantly reduce training and evaluation time. It would also make it possible to run CALNs on datasets with a larger number of classes.
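As an illustration of how the per-ALN loop could collapse into one batched operation, the NumPy sketch below assumes each ALN evaluates a min over linear pieces (as with --root_op min); the shapes, sizes, and variable names are illustrative assumptions, not taken from the codebase.

```python
import numpy as np

# C independent ALNs, each the min over P linear pieces (root_op = min).
rng = np.random.default_rng(0)
C, P, D, N = 10, 3, 8, 4           # classes, pieces, features, batch size
W = rng.normal(size=(C, P, D))     # piece weights, one (P, D) block per ALN
b = rng.normal(size=(C, P))        # piece biases per ALN
x = rng.normal(size=(N, D))        # a batch of backbone features

# Loop version: evaluate each ALN separately (the current bottleneck).
loop_out = np.stack(
    [np.min(x @ W[c].T + b[c], axis=1) for c in range(C)], axis=1
)

# Batched version: one einsum over all ALNs at once, then min over pieces.
batched_out = np.min(np.einsum('nd,cpd->ncp', x, W) + b, axis=2)

assert np.allclose(loop_out, batched_out)
```

Because the ALNs share no state, the batched form computes exactly what the loop does, and the same idea would extend the adaptation step.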
We have not tested any form of training and evaluation across multiple GPUs (model or data parallel).
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.