Recursions Are All You Need

This repository holds the official code used in the paper:

Recursions Are All You Need: Towards Efficient Deep Unfolding Networks

The code was run on a Linux-based system (Ubuntu 22.04) on a single Nvidia RTX 3090 GPU, and was written using PyTorch.

Abstract

The use of deep unfolding networks in compressive sensing (CS) has seen wide success, as they provide both simplicity and interpretability. However, because most deep unfolding networks are iterative, they incur significant redundancy. In this work, we propose a novel recursion-based framework to enhance the efficiency of deep unfolding models. First, recursions are used to effectively eliminate the redundancies in deep unfolding networks. Second, we randomize the number of recursions during training to decrease the overall training time. Finally, to effectively utilize the power of recursions, we introduce a learnable unit that modulates the features of the model based on both the total number of iterations and the current iteration index. To evaluate the proposed framework, we apply it to both ISTA-Net+ and COAST. Extensive testing shows that our proposed framework allows the network to cut as much as 75% of its learnable parameters while mostly maintaining its performance; at the same time, it reduces training time by around 21% for ISTA-Net+ and 42% for COAST. Moreover, when trained on a limited dataset, the recursive models match or even outperform their respective non-recursive baselines.
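The two efficiency levers in the abstract (weight sharing through recursion, and randomized recursion counts during training) can be made concrete with a little arithmetic. This is a minimal sketch under stated assumptions, not the repository's code: it assumes the non-recursive baseline has one distinct recovery block per iteration (`layer_num × IPL` blocks of equal size) and, for the randomized schedule, that each block's recursion count is drawn uniformly from 1 to `IPL`. The paper's actual sampling scheme may differ, and all function names here are hypothetical.

```python
import random


def block_parameter_reduction(layer_num, ipl):
    """Fraction of recovery-block parameters saved by weight sharing.

    Assumes the non-recursive baseline uses a distinct block for each of the
    layer_num * ipl iterations and that all blocks have the same size.
    """
    baseline_blocks = layer_num * ipl  # one block per unfolded iteration
    recursive_blocks = layer_num       # each block is reused ipl times
    return 1.0 - recursive_blocks / baseline_blocks


def sample_recursions(layer_num, max_ipl, rng=random):
    """Hypothetical training schedule: draw R_i uniformly from 1..max_ipl per block."""
    return [rng.randint(1, max_ipl) for _ in range(layer_num)]


def expected_iterations(layer_num, max_ipl):
    # Mean unfolded depth under the uniform draw above
    # (compare with layer_num * max_ipl at full depth)
    return layer_num * (1 + max_ipl) / 2
```

With 5 blocks and 4 iterations per block, `block_parameter_reduction(5, 4)` returns 0.75, and with 3 and 3 it returns about 0.67; any extra learnable unit would add a few parameters back. Under the assumed uniform schedule, `expected_iterations(5, 4)` is 12.5 block applications per training step instead of 20, which illustrates where a training-time saving can come from without predicting the measured figures.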

Figure 1: General architecture of the recursive framework. Unlike general deep unfolding models such as COAST and ISTA-Net+, each recovery block $i$ in the recovery subnet is applied recursively $R_i$ times.
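To illustrate the recursion in Figure 1, here is a toy NumPy sketch that assumes plain ISTA with one scalar step size and one threshold per block. ISTA-Net+ and COAST use learned convolutional transforms instead, so this shows only the control flow: block $i$ stores one parameter set and is applied $R_i$ times with the same weights. The class name `RecursiveISTA` and the `modulate` hook are hypothetical; the hook is merely a placeholder for a unit that, like the one described in the abstract, sees the current iteration index and the total number of iterations.

```python
import numpy as np


def soft_threshold(v, theta):
    # Proximal operator of theta * ||x||_1 (the ISTA shrinkage step)
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)


class RecursiveISTA:
    """Toy ISTA-style unfolding where block i is reused R_i times.

    block_params[i] = (step_i, theta_i) are the only parameters of block i;
    recursions[i] = R_i. `modulate` is a hypothetical hook standing in for a
    feature-modulation unit: it receives the current estimate, the global
    iteration index k, and the total number of iterations.
    """

    def __init__(self, block_params, recursions, modulate=None):
        if len(block_params) != len(recursions):
            raise ValueError("need one recursion count per block")
        self.block_params = list(block_params)
        self.recursions = list(recursions)
        self.modulate = modulate

    @property
    def num_iterations(self):
        # Effective depth of the unfolded network
        return sum(self.recursions)

    @property
    def num_parameter_sets(self):
        # Distinct (step, theta) pairs actually stored
        return len(self.block_params)

    def forward(self, A, y):
        x = A.T @ y  # simple back-projection initialization
        total = self.num_iterations
        k = 0
        for (step, theta), r in zip(self.block_params, self.recursions):
            for _ in range(r):  # same weights on every recursion of this block
                grad = A.T @ (A @ x - y)  # gradient of 0.5 * ||A x - y||^2
                x = soft_threshold(x - step * grad, theta)
                if self.modulate is not None:
                    x = self.modulate(x, k, total)
                k += 1
        return x
```

For example, `RecursiveISTA([(1.0, 0.5)] * 5, [4] * 5)` unfolds 20 iterations while storing only 5 parameter sets. Passing a different `recursions` list at each training step is one way to vary the recursion count without touching the stored weights.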

Training Setup

Download the training data from here into the data directory, then run COAST/TRAIN_COAST.py.

Arguments:

  • --start_epoch: starting epoch number (default: 0)
  • --end_epoch: final epoch number (default: 400 for COAST, 200 for ISTA-Net+)
  • --RFMU: adds the RFMU unit (default: True)
  • --layer_num: number of recovery blocks (default: 5 for COAST, 3 for ISTA-Net+)
  • --IPL: number of iterations per layer (default: 4 for COAST, 3 for ISTA-Net+)
  • --learning_rate: learning rate of the Adam optimizer (default: 1e-4)
  • --gpu_list: GPUs to use during training; not tested with more than one GPU (default: '0')
  • --num_workers: number of workers used by the data loader (default: 10)
  • --matrix_dir: path to the sampling matrices (default: 'sampling_matrix')
  • --model_dir: path to the trained model; currently not working (default: N/A)
  • --data_dir: path to the directory holding the training or validation data (default: 'data')
  • --validation_name: validation dataset, one of Set11, BSD68, BSD100, or Urban100 (default: 'Set11')
  • --save_cycle: period at which the model weights are saved; models achieving the best PSNR or SSIM scores are always saved immediately, regardless of the save cycle (default: 10)
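For reference, the flags above can be mirrored with `argparse`. This is a hypothetical reconstruction from the documented names and COAST defaults, not the parser in TRAIN_COAST.py; in particular, how the real script parses the boolean `--RFMU` is an assumption, handled here by a made-up `str2bool` helper, and `total_iterations` assumes IPL counts the recursions of each block.

```python
import argparse


def str2bool(value):
    # Hypothetical helper: accept common spellings of a boolean flag value
    if isinstance(value, bool):
        return value
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # COAST defaults as documented; ISTA-Net+ uses end_epoch=200, layer_num=3, IPL=3
    p = argparse.ArgumentParser(description="Sketch of the documented training flags")
    p.add_argument("--start_epoch", type=int, default=0)
    p.add_argument("--end_epoch", type=int, default=400)
    p.add_argument("--RFMU", type=str2bool, default=True)
    p.add_argument("--layer_num", type=int, default=5)
    p.add_argument("--IPL", type=int, default=4)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--gpu_list", type=str, default="0")
    p.add_argument("--num_workers", type=int, default=10)
    p.add_argument("--matrix_dir", type=str, default="sampling_matrix")
    p.add_argument("--model_dir", type=str, default=None)  # documented as not working
    p.add_argument("--data_dir", type=str, default="data")
    p.add_argument("--validation_name", type=str, default="Set11",
                   choices=["Set11", "BSD68", "BSD100", "Urban100"])
    p.add_argument("--save_cycle", type=int, default=10)
    return p


def total_iterations(args):
    # Assumption: each of the layer_num blocks is applied IPL times
    return args.layer_num * args.IPL
```

Under this reading, the number of unfolded iterations is `layer_num × IPL`: 20 with the COAST defaults and 9 with the ISTA-Net+ defaults. A run would then look like `python COAST/TRAIN_COAST.py --IPL 4 --validation_name BSD68`.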

Testing

Run COAST/TEST_COAST.py to print the results for every COAST configuration used in the paper. Running the script cell by cell is recommended.
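Model selection during training uses PSNR and SSIM (see `--save_cycle`), so a minimal PSNR sketch may help when sanity-checking numbers. It assumes images with a peak value of 255 and averages the squared error over all pixels; the evaluation protocol in TEST_COAST.py (data range, color channel, any cropping) is not described here and may differ.

```python
import numpy as np


def psnr(reference, estimate, max_value=255.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_value**2 / MSE)."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    if reference.shape != estimate.shape:
        raise ValueError("images must have the same shape")
    mse = np.mean((reference - estimate) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))
```

An error of 1 gray level on every pixel gives 20·log10(255), about 48.13 dB; images in the range [0, 1] need `max_value=1.0`.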

Results

Tables

Acknowledgement

  • The author(s) would like to acknowledge the support received from the Saudi Data and AI Authority (SDAIA) and King Fahd University of Petroleum and Minerals (KFUPM) under the SDAIA-KFUPM Joint Research Center for Artificial Intelligence.
  • In addition, we would like to thank the authors of ISTA-Net and COAST for open-sourcing their code, which was very helpful in our work; our code borrows heavily from theirs.

Contributors

rawwad-alhejaili
