
stem's Introduction

Welcome!

1. Introduction and installation

STEM is a tool for building single-cell-level spatial transcriptomic landscapes from single-cell (SC) data combined with spatial transcriptomics (ST) data. STEM extracts spatial information from gene expression and eliminates the domain gap between spatial transcriptomics and single-cell RNA-seq data.

Installation

conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia ## more info: https://pytorch.org/get-started/locally/
conda install -c conda-forge scanpy python-igraph leidenalg
conda install seaborn

Since STEM is a light model, you can use the code from the STEM folder directly. You can also install STEM from PyPI with pip install scSTEM. To verify the installation, run python test.py.

Next, we demonstrate the workflow that generates the results shown in our Figure 2.

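%matplotlib inline
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # expose only GPU 3; inside this process it appears as cuda:0
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc
import pandas as pd
import torch
import torch.nn.functional as F  # used for F.softmax when building the mapping matrices
import scipy
import time
from STEM.model import *
from STEM.utils import *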

2. Load and Preprocess data

First, we load the simulated ST data as the ST data and the raw SeqFISH data as the SC data. Then we normalize and log-transform both datasets.

data_dir = './data/mousedata_2020/E1z2/'
# count matrices are stored as genes x cells/spots; transpose to cells/spots x genes
scdata = pd.read_csv(data_dir + 'simu_sc_counts.csv', index_col=0).T
stdata = pd.read_csv(data_dir + 'simu_st_counts.csv', index_col=0).T
stgtcelltype = pd.read_csv(data_dir + 'simu_st_celltype.csv', index_col=0)
spcoor = pd.read_csv(data_dir + 'simu_st_metadata.csv', index_col=0)
scmetadata = pd.read_csv(data_dir + 'metadata.csv', index_col=0)

# SC data: scale each cell to the median total count, then log1p-transform
adata = sc.AnnData(scdata, obs=scmetadata)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
scdata = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)
# ST data: same preprocessing for the spots
stadata = sc.AnnData(stdata)
sc.pp.normalize_total(stadata)
sc.pp.log1p(stadata)
stdata = pd.DataFrame(stadata.X, index=stadata.obs_names, columns=stadata.var_names)

# spatial coordinates: SeqFISH cell positions for SC, spot positions for ST
adata.obsm['spatial'] = scmetadata[['x_global','y_global']].values
stadata.obsm['spatial'] = spcoor
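To make the two preprocessing calls concrete, here is a NumPy sketch that roughly mirrors scanpy's defaults: normalize_total without a target_sum rescales every row to the median total count (taken over rows with a non-zero total), and log1p applies log(1 + x). The helper name normalize_log1p is hypothetical and only for illustration; the workflow above uses scanpy itself.

```python
import numpy as np

def normalize_log1p(counts: np.ndarray) -> np.ndarray:
    """Hypothetical helper: scale each row to the median row total, then log1p.

    Rows are cells (SC) or spots (ST); columns are genes.
    """
    totals = counts.sum(axis=1, keepdims=True)            # total counts per row
    target = np.median(totals[totals > 0])                # median over non-empty rows
    safe_totals = np.where(totals == 0, 1.0, totals)      # avoid dividing empty rows by 0
    scaled = counts / safe_totals * target                # every non-empty row now sums to target
    return np.log1p(scaled)                               # natural log of (1 + x)

toy = np.array([[1.0, 3.0, 0.0],
                [2.0, 2.0, 4.0]])
print(normalize_log1p(toy))  # row totals 4 and 8, so both rows are rescaled to sum to 6
```

Zeros stay zeros under both steps, which matters for the dropout estimate in the next step.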

Next, we estimate the dropout rate from the median number of genes detected per SC cell and per ST spot: dp = 1 - (median genes per SC cell) / (median genes per ST spot).

# median number of genes detected per SC cell
sc.pp.calculate_qc_metrics(adata,percent_top=None, log1p=False, inplace=True)
adata.obs['n_genes_by_counts'].median()

# median number of genes detected per ST spot
sc.pp.calculate_qc_metrics(stadata,percent_top=None, log1p=False, inplace=True)
stadata.obs['n_genes_by_counts'].median()

dp = 1 - adata.obs['n_genes_by_counts'].median()/stadata.obs['n_genes_by_counts'].median()
# dp = 0.5836734693877551 for this dataset
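The dropout estimate depends only on how many genes are non-zero in each row, which is what n_genes_by_counts reports, so it can be sketched directly in NumPy. The function estimate_dropout and the toy matrices below are illustrative and not part of the STEM package:

```python
import numpy as np

def estimate_dropout(sc_expr: np.ndarray, st_expr: np.ndarray) -> float:
    """dp = 1 - median(#genes detected per SC cell) / median(#genes detected per ST spot)."""
    sc_genes = (sc_expr > 0).sum(axis=1)   # genes detected in each cell
    st_genes = (st_expr > 0).sum(axis=1)   # genes detected in each spot
    return 1.0 - float(np.median(sc_genes) / np.median(st_genes))

# toy data: 3 cells and 2 spots measured over 4 genes
sc_toy = np.array([[1, 0, 0, 2],
                   [0, 3, 0, 0],
                   [1, 1, 0, 0]])
st_toy = np.array([[2, 1, 1, 3],
                   [0, 4, 2, 1]])
print(estimate_dropout(sc_toy, st_toy))  # 1 - 2 / 3.5, about 0.4286
```

Because normalize_total and log1p never turn a zero into a non-zero value (or vice versa), computing dp after preprocessing, as the workflow does, gives the same value as computing it on raw counts.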

3. Train the STEM model

We first configure the STEM model and then train it. Empirically, we found that STEM performs best when sigma is set to half the distance between adjacent ST spots.
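The sigma heuristic can be turned into a quick estimate from the spot coordinates. This sketch is an assumption, not part of STEM: it uses half the median nearest-neighbour distance between spots as "half the adjacent spot distance", and it shows a generic unnormalised Gaussian weight exp(-d^2 / (2 sigma^2)) to illustrate what sigma controls. STEM's exact kernel form may differ, and the helper names are hypothetical.

```python
import numpy as np

def half_spot_spacing(coords: np.ndarray) -> float:
    """Hypothetical helper: half the median nearest-neighbour distance between spots."""
    diff = coords[:, None, :] - coords[None, :, :]   # pairwise coordinate differences
    dist = np.sqrt((diff ** 2).sum(axis=-1))         # Euclidean distance matrix
    np.fill_diagonal(dist, np.inf)                   # ignore each spot's distance to itself
    return 0.5 * float(np.median(dist.min(axis=1)))

def gaussian_weight(d, sigma: float) -> np.ndarray:
    """Generic unnormalised Gaussian kernel exp(-d^2 / (2 sigma^2)) (assumed form)."""
    d = np.asarray(d, dtype=float)
    return np.exp(-d ** 2 / (2.0 * sigma ** 2))

# toy 3 x 3 grid of spots spaced 6 units apart
xs, ys = np.meshgrid(np.arange(3) * 6.0, np.arange(3) * 6.0)
grid = np.column_stack([xs.ravel(), ys.ravel()])
sigma = half_spot_spacing(grid)
print(sigma)                               # 3.0
print(gaussian_weight([0.0, 6.0], sigma))  # weights 1.0 and exp(-2), about 0.135
```

With the real data you would pass spcoor.values. The dense distance matrix grows quadratically with the number of spots; for large arrays, scipy.spatial.cKDTree(coords).query(coords, k=2) returns each spot's nearest-neighbour distance in its second column more cheaply.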

class setting(object):
    pass
seed_all(2022)  # fix the random seeds for reproducibility
opt = setting()
setattr(opt, 'device', 'cuda:0') # device
setattr(opt, 'outf', 'log/test') # folder to save log files
setattr(opt, 'n_genes', 351) # number of input genes; should equal scdata.shape[1]
setattr(opt, 'no_bn', False) # deprecated
setattr(opt, 'lr', 0.002) # learning rate
setattr(opt, 'sigma', 3)  # spatial variance parameter of the Gaussian function (about half the ST spot spacing)
setattr(opt, 'alpha', 0.8) # MMD loss weight (default: 0.8)
setattr(opt, 'verbose', True) # verbose
setattr(opt, 'mmdbatch', 1000) # batch size for the MMD loss
setattr(opt, 'dp', dp) # dropout rate for ST data

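testmodel = SOmodel(opt)
testmodel.togpu()
sc_tensor = torch.tensor(scdata.values).float()     # SC expression, cells x genes
st_tensor = torch.tensor(stdata.values).float()     # ST expression, spots x genes
coord_tensor = torch.tensor(spcoor.values).float()  # ST spot coordinates
loss_curve = testmodel.train_wholedata(400, sc_tensor, st_tensor, coord_tensor)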
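The alpha and mmdbatch settings refer to a maximum mean discrepancy (MMD) loss, presumably the term that aligns the SC and ST embedding distributions to close the domain gap. For readers unfamiliar with MMD, here is a generic biased MMD² estimator with a Gaussian kernel in NumPy. It is an illustration only, not STEM's own loss; the kernel choice, bandwidth, and toy data are assumptions.

```python
import numpy as np

def gaussian_mmd2(x: np.ndarray, y: np.ndarray, bandwidth: float = 1.0) -> float:
    """Biased estimate of squared MMD between samples x (n, d) and y (m, d),
    using the Gaussian kernel k(a, b) = exp(-||a - b||^2 / (2 * bandwidth^2))."""
    def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)  # pairwise squared distances
        return np.exp(-sq_dist / (2.0 * bandwidth ** 2))
    return float(kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean())

rng = np.random.default_rng(0)
sc_emb = rng.normal(0.0, 1.0, size=(200, 2))    # toy stand-in for SC embeddings
st_same = rng.normal(0.0, 1.0, size=(200, 2))   # drawn from the same distribution
st_shift = rng.normal(3.0, 1.0, size=(200, 2))  # drawn from a shifted distribution

print(gaussian_mmd2(sc_emb, st_same))   # close to 0
print(gaussian_mmd2(sc_emb, st_shift))  # clearly larger
```

Minimising such a term on batches of SC and ST embeddings pushes the two distributions together; the pairwise kernel matrices are why the loss is computed on batches rather than on all samples at once.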

The training loss curve should look like this: (figure: training loss curve)

4. Get embeddings and reconstruct spatial adjacency

We first compute the embeddings of the SC and ST data and then build the mapping matrices.

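testmodel.modeleval()
# encode SC cells and ST spots with the same encoder into a shared embedding space
scembedding = testmodel.netE(torch.tensor(scdata.values,dtype=torch.float32).cuda())
stembedding = testmodel.netE(torch.tensor(stdata.values,dtype=torch.float32).cuda())
# row-wise softmax over dot-product similarities; each row is a probability distribution
netst2sc = F.softmax(stembedding.mm(scembedding.t()),dim=1).detach().cpu().numpy()  # spots x cells
netsc2st = F.softmax(scembedding.mm(stembedding.t()),dim=1).detach().cpu().numpy()  # cells x spots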

The matrices netst2sc and netsc2st are the ST-to-SC and SC-to-ST mapping matrices, respectively. In netst2sc, each row gives the probabilities of one spot mapping to every cell and sums to 1. In netsc2st, each row gives the probabilities of one cell mapping to every spot and sums to 1.
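To make the row normalisation concrete, this NumPy sketch mirrors the softmax step with small random embeddings (sizes and values are made up; row_softmax is a hypothetical helper standing in for F.softmax(..., dim=1)):

```python
import numpy as np

def row_softmax(scores: np.ndarray) -> np.ndarray:
    """Softmax along each row, shifted by the row maximum for numerical stability."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    expd = np.exp(shifted)
    return expd / expd.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
sc_emb = rng.normal(size=(5, 4))   # 5 toy cells with 4-dimensional embeddings
st_emb = rng.normal(size=(3, 4))   # 3 toy spots

st2sc = row_softmax(st_emb @ sc_emb.T)   # (3 spots, 5 cells), like netst2sc
sc2st = row_softmax(sc_emb @ st_emb.T)   # (5 cells, 3 spots), like netsc2st
print(st2sc.shape, sc2st.shape)          # (3, 5) (5, 3)
print(st2sc.sum(axis=1), sc2st.sum(axis=1))  # every row sums to 1
```

The two score matrices are transposes of each other, but the softmax normalises along different axes, so netsc2st is in general not the transpose of netst2sc.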

Then we can obtain a spatial coordinate for every single cell.

adata.obsm['spatialDA'] = all_coord(pd.DataFrame(netsc2st,index=adata.obs_names,columns=stadata.obs_names),spcoor)
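The exact rule inside all_coord is not described here. One plausible reading, which is an assumption shown with a hypothetical helper expected_coordinates, is that each cell's coordinate is the expectation of the spot coordinates under that cell's row of netsc2st:

```python
import numpy as np

def expected_coordinates(sc2st: np.ndarray, spot_xy: np.ndarray) -> np.ndarray:
    """Hypothetical helper: probability-weighted mean of spot coordinates per cell.

    sc2st: (n_cells, n_spots) with rows summing to 1; spot_xy: (n_spots, 2).
    """
    return sc2st @ spot_xy

spot_xy = np.array([[0.0, 0.0],
                    [10.0, 0.0],
                    [0.0, 10.0]])
sc2st = np.array([[1.0, 0.0, 0.0],    # cell fully assigned to spot 0
                  [0.5, 0.5, 0.0],    # cell split evenly between spots 0 and 1
                  [0.2, 0.4, 0.4]])   # cell spread over all three spots
print(expected_coordinates(sc2st, spot_xy))  # rows approximately (0, 0), (5, 0), (4, 4)
```

Under this assumption a cell whose probability is spread over several spots lands between them, whereas taking the argmax spot would snap every cell onto a spot position.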

Compared with other methods, STEM is the only method that preserves the original topological structure of all single cells.

More demos can be found in the Demo folder. Code for reproducing the results is in the SourceforFigure folder.

Data

The processed data and trained models used for reproducing the results are deposited in Figshare.
