
csgm-mri-langevin

NOTE: Please run all commands from the root directory of the repository, i.e., from mri-langevin/

Setup environment

  1. python -m venv env
  2. source env/bin/activate
  3. pip install -U pip
  4. pip install -r requirements.txt
  5. git submodule update --init --recursive
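
A quick sanity check after these steps (a minimal, illustrative sketch; it assumes the PyTorch from requirements.txt and a CUDA-capable GPU, which the sampling scripts expect):

    # Illustrative environment check, not part of the repository.
    import torch

    print(torch.__version__)
    print("CUDA available:", torch.cuda.is_available())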

Install BART for sensitivity map estimation

BART provides tools for processing MRI data. Our experiments require BART to estimate sensitivity maps; it can be installed with the following commands.

  1. sudo apt-get install make gcc libfftw3-dev liblapacke-dev libpng-dev libopenblas-dev
  2. wget https://github.com/mrirecon/bart/archive/v0.6.00.tar.gz
  3. tar xzvf v0.6.00.tar.gz
  4. cd bart-0.6.00
  5. make
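
To check that the build succeeded, you can call the binary and print its version (an illustrative sketch; the path assumes the bart-0.6.00/ directory created by the steps above sits in the repository root):

    # Illustrative check that the BART binary built correctly.
    import subprocess

    result = subprocess.run(["./bart-0.6.00/bart", "version"],
                            capture_output=True, text=True)
    print(result.stdout.strip())  # should print the BART version, e.g. v0.6.00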

Download data and checkpoints

  1. gdown https://drive.google.com/uc?id=1vAIXf8n67yEAPmH2I9qiDWzmq9fGKPYL
  2. tar -zxvf checkpoint.tar.gz
  3. gdown https://drive.google.com/uc?id=1mpnV1iXid1PG0RaJswM6t9yI76b2IPxc
  4. tar -zxvf datasets.tar.gz

Script for estimating sensitivity maps from data

The script estimate_maps.py estimates sensitivity maps. Example usage:

python estimate_maps.py --input-dir=datasets/brain_T2 --output-dir=datasets/brain_T2_maps
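
For reference, sensitivity-map estimation in BART is done with ESPIRiT calibration (the ecalib tool). Below is a minimal sketch using BART's Python wrapper; it is illustrative only and may differ from what estimate_maps.py actually does, and it assumes TOOLBOX_PATH points at your BART build with bart-0.6.00/python on your PYTHONPATH:

    # Illustrative ESPIRiT calibration via the BART Python wrapper
    # (not the repository's estimate_maps.py itself).
    import os
    os.environ["TOOLBOX_PATH"] = "/path/to/bart-0.6.00"  # hypothetical location

    import numpy as np
    from bart import bart  # shipped in bart-0.6.00/python/

    # k-space in BART's dimension order: (readout, phase, slice, coil)
    ksp = np.random.randn(384, 384, 1, 8) + 1j * np.random.randn(384, 384, 1, 8)
    maps = bart(1, "ecalib -m1", ksp)  # one set of ESPIRiT maps per coil
    print(maps.shape)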

Example commands

We provide configuration files in configs/ that contain hyper-parameters used in our experiments. Here are example commands for using the configuration files.

  1. T2-Brains: python main.py +file=brain_T2
  2. T1-Brains: python main.py +file=brain_T1
  3. FLAIR-Brains: python main.py +file=brain_FLAIR
  4. fastMRI Knees: python main.py +file=knees
  5. Abdomens: python main.py +file=abdomen
  6. Stanford knees: python main.py +file=stanford_knees
  7. To run with horizontal measurements: python main.py +file=brain_T2 orientation=horizontal
  8. To run with random measurements: python main.py +file=brain_T2 pattern=random
  9. To change acceleration factor: python main.py +file=brain_T2 R=8
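
The +file=... and key=value arguments are Hydra-style overrides of the YAML files in configs/. As a rough illustration of how such overrides compose with a base configuration, here is a sketch using OmegaConf, which Hydra builds on (main.py itself wires this up through Hydra):

    # Illustrative only: composing a config file with command-line-style overrides.
    from omegaconf import OmegaConf

    base = OmegaConf.load("configs/brain_T2.yaml")
    overrides = OmegaConf.from_dotlist(["orientation=horizontal", "R=8"])
    cfg = OmegaConf.merge(base, overrides)
    print(OmegaConf.to_yaml(cfg))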

Plotting results

We use CometML to save results. Please see plot-demo.ipynb for example reconstructions.
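
If you want to log your own reconstructions to CometML, a minimal sketch looks like the following (illustrative only; the project name is hypothetical, and a valid COMET_API_KEY must be configured):

    # Illustrative CometML logging, not the repository's exact logging code.
    import numpy as np
    from comet_ml import Experiment

    experiment = Experiment(project_name="csgm-mri-langevin")  # hypothetical name
    reconstruction = np.abs(np.random.randn(384, 384))  # stand-in image
    experiment.log_image(reconstruction, name="reconstruction")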

Citations

If you find this repository useful, please consider citing the following papers:

@article{jalal2021robust,
  title={Robust Compressed Sensing MRI with Deep Generative Priors},
  author={Jalal, Ajil and Arvinte, Marius and Daras, Giannis and Price, Eric and Dimakis, Alexandros G and Tamir, Jonathan I},
  journal={Advances in Neural Information Processing Systems},
  year={2021}
}

@article{jalal2021instance,
  title={Instance-Optimal Compressed Sensing via Posterior Sampling},
  author={Jalal, Ajil and Karmalkar, Sushrut and Dimakis, Alexandros G and Price, Eric},
  journal={International Conference on Machine Learning},
  year={2021}
}

Our code uses prior work from the following papers, which must be cited:

@inproceedings{song2019generative,
  title={Generative modeling by estimating gradients of the data distribution},
  author={Song, Yang and Ermon, Stefano},
  booktitle={Advances in Neural Information Processing Systems},
  pages={11918--11930},
  year={2019}
}

@article{song2020improved,
  title={Improved Techniques for Training Score-Based Generative Models},
  author={Song, Yang and Ermon, Stefano},
  journal={arXiv preprint arXiv:2006.09011},
  year={2020}
}

We use data from the NYU fastMRI dataset, which must also be cited:

@inproceedings{zbontar2018fastMRI,
  title={{fastMRI}: An Open Dataset and Benchmarks for Accelerated {MRI}},
  author={Jure Zbontar and Florian Knoll and Anuroop Sriram and Tullie Murrell and Zhengnan Huang and Matthew J. Muckley and Aaron Defazio and Ruben Stern and Patricia Johnson and Mary Bruno and Marc Parente and Krzysztof J. Geras and Joe Katsnelson and Hersh Chandarana and Zizhao Zhang and Michal Drozdzal and Adriana Romero and Michael Rabbat and Pascal Vincent and Nafissa Yakubova and James Pinkerton and Duo Wang and Erich Owens and C. Lawrence Zitnick and Michael P. Recht and Daniel K. Sodickson and Yvonne W. Lui},
  journal={ArXiv e-prints},
  archivePrefix={arXiv},
  eprint={1811.08839},
  year={2018}
}

@article{knoll2020fastmri,
  title={fastMRI: A publicly available raw k-space and DICOM dataset of knee images for accelerated MR image reconstruction using machine learning},
  author={Knoll, Florian and Zbontar, Jure and Sriram, Anuroop and Muckley, Matthew J and Bruno, Mary and Defazio, Aaron and Parente, Marc and Geras, Krzysztof J and Katsnelson, Joe and Chandarana, Hersh and others},
  journal={Radiology: Artificial Intelligence},
  volume={2},
  number={1},
  pages={e190007},
  year={2020},
  publisher={Radiological Society of North America}
}


csgm-mri-langevin's Issues

Training the score network

Hi, congratulations on this great work! I am trying to train the score network as described in the paper, but I am running into out-of-memory issues (I am using a Quadro RTX 8000 with 45GB of memory). My dataset consists of 2D image slices of shape (208, 240). The only way I can fit training into memory is to reduce ngf in the configuration file from 128 to 30 (I believe this is related to the number of channels of the convolutions in the blocks of the network?). I was hoping you could shed some light on this, and on how you trained your network. Best regards, Thomas Yu

Compute FID

I noticed in the config file there is a langevin_config called fast_fid. Is there code implemented in this repository to compute FID? If I am missing something obvious, I apologize. Also, do you know approximately how long it takes to generate those 1000 samples for a given set of measurements?

On calculation of the score for the Gaussian likelihood

Hi, I am writing to ask about the computation of the likelihood score $\nabla_{x_{t}}\log\mu(y\mid x_{t})=\frac{A^{H}(y-Ax_{t})}{\gamma_{t}^{2}+\sigma^{2}}$ in eq. (4) of the corresponding paper.

Specifically, I do not understand how the code implements $\nabla_{x_{t}}\log\mu(y\mid x_{t})=\frac{A^{H}(y-Ax_{t})}{\gamma_{t}^{2}+\sigma^{2}}$. I believe the relevant code is lines 136-152 of main.py, which I copy below for reference:

                # get measurements for current estimate
                meas = forward_operator(normalize(samples, estimated_mvue))
                # compute gradient, i.e., gradient = A_adjoint * ( y - Ax_hat )
                # here A_adjoint also involves the sensitivity maps, hence the pointwise multiplication
                # also convert to real value since the ``complex'' image is a real-valued two channel image
                meas_grad = torch.view_as_real(torch.sum(self._ifft(meas-ref) * torch.conj(maps), axis=1)).permute(0,3,1,2)
                # re-normalize, since measurements are from a normalized estimate
                meas_grad = unnormalize(meas_grad, estimated_mvue)
                # convert to float in case it somehow became double
                meas_grad = meas_grad.type(torch.cuda.FloatTensor)
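                # note: the next three lines do not apply the analytic
                # 1/(gamma_t^2 + sigma^2) weighting of eq. (4); instead they rescale
                # the measurement gradient to match the prior gradient's norm and
                # scale it by the 'mse' hyper-parameter (a norm-matching heuristic)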
                meas_grad /= torch.norm( meas_grad )
                meas_grad *= torch.norm( p_grad )
                meas_grad *= self.config['mse']

                # combine measurement gradient, prior gradient and noise
                samples = samples + step_size * (p_grad - meas_grad) + noise

However, from the above code, I do not quite understand the correspondence between meas_grad and the formula $\nabla_{x_{t}}\log\mu(y\mid x_{t})=\frac{A^{H}(y-Ax_{t})}{\gamma_{t}^{2}+\sigma^{2}}$. For example, it seems that only the FFT matrix is considered in A, and I do not see the values of $\sigma^{2}$, and in particular of $\gamma_{t}^{2}$. Why are there operations like meas_grad /= torch.norm( meas_grad ) and meas_grad *= torch.norm( p_grad ), and why is there something like meas_grad = unnormalize(meas_grad, estimated_mvue)?

I find the code difficult to understand. I tried to write my own implementation of eq. (4), but it does not work: the resulting $\nabla_{x_{t}}\log\mu(y\mid x_{t})=\frac{A^{H}(y-Ax_{t})}{\gamma_{t}^{2}+\sigma^{2}}$ is very large even if I choose a large $\gamma_{t}^{2}$ at the very beginning.

Could you please add some explanation of how you realize eq. (4), step by step, in your code? It would be much appreciated if you could illustrate it with some mathematical equations.

Thank you very much.

Best
Meng

Metrics script

Hi

Can you please provide the script to compute PSNR and SSIM scores from the pre-trained model?
Specifically, how do you run inference with, say, the 1.5GB Brain T2 model: take undersampled (compressed) MRI measurements, reconstruct the images, and then compute SSIM/PSNR scores? How exactly is this done?

Thanks.
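
A minimal sketch of computing these metrics with scikit-image (illustrative only, not the authors' evaluation script; it compares normalized magnitude images):

    # Illustrative PSNR/SSIM computation, not the authors' script.
    import numpy as np
    from skimage.metrics import peak_signal_noise_ratio, structural_similarity

    def evaluate(gt: np.ndarray, recon: np.ndarray):
        # compare magnitude images, normalized to [0, 1]
        gt = np.abs(gt) / np.abs(gt).max()
        recon = np.abs(recon) / np.abs(recon).max()
        psnr = peak_signal_noise_ratio(gt, recon, data_range=1.0)
        ssim = structural_similarity(gt, recon, data_range=1.0)
        return psnr, ssim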

checkpoint_100000.pth

Hello Dear team,
I was looking at your code and tried to run it on my own dataset, which has fewer samples. However, I could not find the checkpoint_100000.pth weights needed to train my model. Did I make a mistake, or are these weights not publicly available?

Thanks in advance,
Mahdi

Invariance to image shapes

Hi, nice work! But I don't understand "Invariance to image shapes".
In other words, I don't understand how a model trained on T2-weighted images of size 384x384 can sample knee data of size 320x320. I can't find the answer in your code.
Can you explain this?
Thanks a lot.
