
Transfer Learning with Kernel Ridge Regression (TL-KRR)


This repository contains the implementation of our paper on transfer learning with kernel ridge regression (TL-KRR); the method does not require finetuning the base model.

In our experiments, we primarily tested transferring from ResNet models pretrained on ImageNet to six downstream tasks: CIFAR10, CIFAR100, STL10, CUB200, SVHN, and Kuzushiji49.

The four steps of our method are detailed in the paper:

Transfer Learning with Kernel Ridge Regression

by Shuai Tang and Virginia R. de Sa

Brief Introduction

The implementation relies on the following files (a minimal end-to-end sketch of the four steps follows the list):

TLKRR.py is the main file that conducts our four-step transfer learning method with kernel ridge regression.

sketched_kernels.py sketches the feature vectors at individual layers into a fixed number of buckets. It now supports feature hashing as a first step to reduce the dimensionality of the feature vectors, which makes deeper models feasible, and then applies either a competitive learning algorithm or SJLT to find a subset of samples for the Nyström method.

lowrank_feats.py applies the Nyström method on top of feature vectors to compute low-rank approximations.

learning_kernel_alignment.py computes the optimal convex combination of feature vectors from individual layers that gives the highest alignment score with the target in a downstream task.

utils.py has helper functions.

cub200.py and kuzushiji49.py implement the PyTorch vision dataset class for CUB200 and Kuzushiji49, respectively.
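
For intuition, here is a minimal, self-contained sketch of the four steps. It is not the repository's actual code: the helper names are hypothetical, it assumes per-layer feature matrices have already been extracted from a frozen pretrained model, it uses randomly chosen landmarks in place of the SJLT / competitive-learning subset, and it weights layers by a simple centred-alignment score as a proxy for the learned convex combination.

```python
# A minimal, self-contained sketch of the four steps, NOT the repository's
# exact code: helper names and defaults are illustrative.
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics.pairwise import rbf_kernel


def nystrom_features(X, landmarks, gamma=None):
    """Step 2: low-rank map Phi such that Phi @ Phi.T approximates the RBF kernel."""
    C = rbf_kernel(X, landmarks, gamma=gamma)          # n x M cross-kernel
    W = rbf_kernel(landmarks, landmarks, gamma=gamma)  # M x M landmark kernel
    U, s, _ = np.linalg.svd(W)
    return C @ U / np.sqrt(np.maximum(s, 1e-12))       # n x M features


def alignment(Phi, Y):
    """Step 3 (proxy): centred alignment between Phi @ Phi.T and the label kernel."""
    Phi = Phi - Phi.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    return (np.linalg.norm(Phi.T @ Y) ** 2
            / (np.linalg.norm(Phi.T @ Phi) * np.linalg.norm(Y.T @ Y)))


def tl_krr(train_feats, y_train, test_feats, n_landmarks=2048, seed=0):
    """train_feats / test_feats: dict mapping layer name -> (n, d) numpy array."""
    rng = np.random.default_rng(seed)
    Y = np.eye(int(y_train.max()) + 1)[y_train]        # one-hot targets
    train_blocks, test_blocks, weights = [], [], []
    for name, X in train_feats.items():
        # Step 1 (simplified): random landmarks stand in for the subset found
        # by SJLT or competitive learning in sketched_kernels.py.
        idx = rng.choice(len(X), size=min(n_landmarks, len(X)), replace=False)
        Phi_tr = nystrom_features(X, X[idx])
        Phi_te = nystrom_features(test_feats[name], X[idx])
        train_blocks.append(Phi_tr)
        test_blocks.append(Phi_te)
        weights.append(alignment(Phi_tr, Y))           # per-layer weight
    weights = np.asarray(weights) / np.sum(weights)    # convex combination
    Phi_tr = np.hstack([np.sqrt(w) * B for w, B in zip(weights, train_blocks)])
    Phi_te = np.hstack([np.sqrt(w) * B for w, B in zip(weights, test_blocks)])
    # Step 4: ridge regression on the combined low-rank features.
    clf = Ridge(alpha=1.0).fit(Phi_tr, Y)
    return clf.predict(Phi_te).argmax(axis=1)
```

The actual code learns the convex combination via learning_kernel_alignment.py and uses the sketched subset from sketched_kernels.py; the random landmarks and normalised per-layer alignment weights above are only stand-ins.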

Requirements

python >= 3.5
torch >= 1.0
torchvision
numpy
scipy
sklearn
pandas

Simple Example

CUDA_VISIBLE_DEVICES=0 python3 -u TLKRR.py \
    --datapath data/ \
    --modelname resnet18 \
    --task cifar100 \
    --bsize 800

To reduce memory consumption, one could do the following:

CUDA_VISIBLE_DEVICES=0 python3 -u TLKRR.py \
    --datapath data/ \
    --modelname wide_resnet50_2 \
    --task cifar100 \
    --bsize 400 \
    --feature_hashing --factor 4

One could set factor and/or M to a larger value to get decent performance.
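
For reference, below is a minimal sketch of the hashing trick behind --feature_hashing and --factor, under the assumption that each feature dimension is assigned a random bucket and a random sign and that dimensions sharing a bucket are summed; the function name and details are illustrative, not the repository's exact implementation.

```python
# Minimal feature-hashing sketch (count-sketch on feature dimensions);
# an assumption of how --factor shrinks the feature width, not the exact code.
import numpy as np


def feature_hash(X, factor=4, seed=0):
    """Hash the d columns of X (n x d) into d // factor buckets with random signs."""
    n, d = X.shape
    k = max(d // factor, 1)
    rng = np.random.default_rng(seed)
    buckets = rng.integers(0, k, size=d)      # bucket index for each dimension
    signs = rng.choice([-1.0, 1.0], size=d)   # random sign for each dimension
    H = np.zeros((n, k), dtype=X.dtype)
    np.add.at(H.T, buckets, (X * signs).T)    # scatter-add signed columns into buckets
    return H
```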

The learning rate of the competitive learning algorithm used to find the subset of samples can also be tuned for better performance. The default value is a personal guess, but it already works reasonably well.
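
For intuition, here is a minimal sketch of the competitive learning step, i.e. online winner-take-all clustering of the feature vectors into M centroids; the lr argument below is the hyperparameter mentioned above, and all names are illustrative rather than the repository's flags.

```python
# Minimal online competitive-learning sketch (winner-take-all updates);
# illustrative only, parameter names need not match the repository.
import numpy as np


def competitive_learning(X, M=2048, lr=0.05, epochs=1, seed=0):
    """Learn M centroids from the rows of X (n x d) with winner-take-all updates."""
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=M, replace=False)].astype(np.float64)
    for _ in range(epochs):
        for i in rng.permutation(len(X)):
            x = X[i]
            # the winner is the centroid closest to the current sample
            j = np.argmin(np.sum((centroids - x) ** 2, axis=1))
            # move only the winner a small step towards the sample
            centroids[j] += lr * (x - centroids[j])
    return centroids
```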

Results

Hyperparameter settings:

{
    "subsampling": "sjlt",
    "seed": 0,
    "bsize": 800,
    "M": 2048,
    "T": 4,
    "feature_hashing": true,
    "factor": 4,
}
| Models/Tasks   | CIFAR10 | CIFAR100 | STL10 | SVHN  | KUZUSHIJI49 |
|----------------|---------|----------|-------|-------|-------------|
| ResNet18       | 90.86   | 71.59    | 96.14 | 87.93 | 89.15       |
| ResNet34       | 91.98   | 74.27    | 97.21 | 84.23 | 86.46       |
| ResNet50       | 91.39   | 73.87    | 97.66 | 88.69 | 88.42       |
| ResNet101      | 93.18   | 76.43    | 98.36 | 90.66 | 84.12       |
| ResNet152      | 93.81   | 77.48    | 98.60 | 90.26 | 86.08       |
| ResNeXt50      | 92.42   | 75.23    | 98.00 | 88.53 | 84.00       |
| ResNeXt101     | 93.84   | 78.06    | 98.48 | 88.15 | 85.70       |
| Wide-ResNet50  | 92.41   | 74.40    | 98.39 | 91.58 | 90.25       |
| Wide-ResNet101 | 93.77   | 75.74    | 98.51 | 90.85 | 90.42       |

Hyperparameter settings:

{
    "subsampling": "competitive_learning",
    "seed": 0,
    "bsize": 800,
    "M": 2048,
    "T": 4,
    "feature_hashing": true,
    "factor": 4,
}
| Models/Tasks   | CIFAR10 | CIFAR100 | STL10 | SVHN  | KUZUSHIJI49 |
|----------------|---------|----------|-------|-------|-------------|
| ResNet18       | 90.89   | 70.59    | 96.39 | 89.62 | 92.23       |
| ResNet34       | 92.64   | 73.56    | 97.25 | 87.33 | 90.64       |
| ResNet50       | 91.42   | 74.74    | 97.76 | 90.29 | 91.89       |
| ResNet101      | 93.09   | 75.29    | 98.34 | 92.00 | 90.51       |
| ResNet152      | 93.50   | 76.94    | 98.54 | 91.23 |             |
| ResNeXt50      | 92.51   | 73.92    | 98.10 | 90.16 | 89.53       |
| ResNeXt101     | 94.22   | 76.75    | 98.55 | 89.74 | 90.69       |
| Wide-ResNet50  | 92.42   | 73.85    | 98.31 | 91.53 | 92.11       |
| Wide-ResNet101 | 93.54   | 75.92    | 98.48 | 92.40 |             |

Authors

Shuai Tang

Acknowledgements

We gratefully thank Charlie Dickens and Wesley J. Maddox for fruitful discussions, and Richard Gao, Mengting Wan, and Shi Feng for comments on the draft. Many thanks to my advisor, Virginia de Sa, for basically allowing me to do whatever I am interested in. :)

Rant

If the number of data samples is extremely small, one should skip the first step of approximating low-rank features; otherwise undersampling will hurt performance.

If memory is a concern, one could store the generated feature vectors in half-precision floating point, at the cost of a minor performance drop.
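
For instance, one could cast the cached features to float16 and promote them back to float32 right before numerically sensitive steps such as the SVD in the Nyström computation; the snippet below is a generic PyTorch illustration, not the repository's code.

```python
import torch

feats = torch.randn(50000, 2048)   # stand-in for extracted feature vectors
feats = feats.half()               # store at half precision, roughly 2x less memory
feats = feats.float()              # promote again before SVD / linear solves
```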

Competitive learning is an online clustering algorithm that produces M centroids as the subset for the Nyström approximation. Since the subset is learnt rather than data-oblivious, it sometimes outperforms data-oblivious SJLT-based sketching, e.g. on SVHN and Kuzushiji49. However, its learning rate becomes an extra hyperparameter. 🤷‍♂️
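
For contrast, here is a minimal sketch of the data-oblivious alternative: a block SJLT hashes each of the n samples into one random bucket per block with a random sign, and the M sketched rows then serve as the Nyström subset. The construction and parameter names below are illustrative and may differ from the repository's in detail.

```python
# Minimal block-SJLT sketch applied to the rows of a feature matrix;
# illustrative only, the repository's construction may differ in detail.
import numpy as np


def sjlt_sketch(X, M=2048, s=4, seed=0):
    """Return an (M, d) sketch S @ X of the (n, d) matrix X, s nonzeros per column of S."""
    n, d = X.shape
    rng = np.random.default_rng(seed)
    block = M // s                              # rows per block (assumes s divides M)
    sketch = np.zeros((M, d))
    for b in range(s):
        rows = b * block + rng.integers(0, block, size=n)    # one bucket per block
        signs = rng.choice([-1.0, 1.0], size=n) / np.sqrt(s)
        np.add.at(sketch, rows, signs[:, None] * X)          # signed scatter-add
    return sketch
```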
