mims-harvard / shepherd
SHEPHERD: Deep learning for diagnosing patients with rare genetic diseases
Home Page: https://zitniklab.hms.harvard.edu/projects/SHEPHERD
License: MIT License
CombinedGPAligner - training without a checkpoint (best_ckpt arg) causes an exception
When running train.py in train mode (not do_inference), the method get_model is called with load_from_checkpoint=False, which passes node_ckpt=None to CombinedGPAligner.
CombinedGPAligner then loads the NodeEmbeder from checkpoint, calling load_from_checkpoint with node_ckpt=None, which causes an exception!
This differs from CombinedPatientNCA, where NodeEmbeder reads the checkpoint from hparams['saved_checkpoint_path'], so I think it's a bug.
```python
def get_model(args, hparams, node_hparams, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=False):
    print("setting up model", hparams['model_type'])
    # get patient model
    if hparams['model_type'] == 'aligner':
        if load_from_checkpoint:
            comb_patient_model = CombinedGPAligner.load_from_checkpoint(
                checkpoint_path=str(Path(project_config.PROJECT_DIR / args.best_ckpt)),
                edge_attr_dict=edge_attr_dict, all_data=all_data, n_nodes=n_nodes,
                node_ckpt=hparams["saved_checkpoint_path"], node_hparams=node_hparams)
        else:
            comb_patient_model = CombinedGPAligner(edge_attr_dict=edge_attr_dict, all_data=all_data,
                                                   n_nodes=n_nodes, hparams=hparams, node_hparams=node_hparams)
```

```python
class CombinedGPAligner(pl.LightningModule):
    def __init__(self, edge_attr_dict, all_data, n_nodes=None, node_ckpt=None, hparams=None,
                 node_hparams=None, spl_pca=[], spl_gate=[]):
        super().__init__()
        print('Initializing Model')
        self.save_hyperparameters('hparams', ignore=["spl_pca", "spl_gate"])  # spl_pca and spl_gate never get used
        print("Node checkpoint:", node_ckpt)
        print('Saved combined model hyperparameters: ', self.hparams)
        self.all_data = all_data
        self.all_train_nodes = {}
        self.train_patient_nodes = {}
        self.train_sparse_nodes = {}
        self.train_target_batch = {}
        self.train_corr_gene_nid = {}
        # print(f"Loading Node Embedder from {self.hparams.hparams['saved_checkpoint_path']}")
        print(f"Loading Node Embedder from {node_ckpt}")
        # NOTE: loads in saved hyperparameters
        self.node_model = NodeEmbeder.load_from_checkpoint(checkpoint_path=node_ckpt,  # self.hparams.hparams['saved_checkpoint_path']
                                                           all_data=all_data, edge_attr_dict=edge_attr_dict,
                                                           num_nodes=n_nodes)
        # num_nodes=n_nodes, combined_training=self.hparams.hparams['combined_training'])
```
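A possible workaround, assuming the intent matches CombinedPatientNCA, is to fall back to hparams['saved_checkpoint_path'] when node_ckpt is None before loading the NodeEmbeder. This is a sketch of the fallback logic, not a confirmed fix from the maintainers:

```python
def resolve_node_ckpt(node_ckpt=None, hparams=None):
    """Hypothetical helper: fall back to hparams['saved_checkpoint_path'] when
    node_ckpt is None, mirroring how CombinedPatientNCA resolves the
    NodeEmbeder checkpoint, and fail early with a clear message otherwise."""
    if node_ckpt is None and hparams is not None:
        node_ckpt = hparams.get('saved_checkpoint_path')
    if node_ckpt is None:
        raise ValueError(
            "CombinedGPAligner needs a NodeEmbeder checkpoint: pass node_ckpt "
            "or set hparams['saved_checkpoint_path']"
        )
    return node_ckpt
```

Calling this at the top of CombinedGPAligner.__init__ would make the no-checkpoint training path behave like CombinedPatientNCA instead of crashing inside load_from_checkpoint.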
Hi!
Great work! I just wanted to let you know that environment setup using conda (as in the tutorial) fails on a local Mac with the following output:
I'm still trying to understand why it happens, but I'm raising an issue in case it's a general problem.
```
Solving environment: failed

ResolvePackageNotFound:
```
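ResolvePackageNotFound on macOS often means the environment.yml was exported with Linux-specific build strings pinned. One common workaround is to drop the build component from each dependency spec before retrying `conda env create`; here is a minimal sketch (assuming the dependencies are the usual `name=version=build` strings, with a nested pip entry):

```python
def strip_build_pins(dependencies):
    """Drop the third '=build' component from conda dependency specs like
    'numpy=1.21.2=py39h...' so the solver can pick a platform-native build.
    Non-string entries (e.g. the nested {'pip': [...]} mapping) pass through."""
    cleaned = []
    for dep in dependencies:
        if isinstance(dep, str):
            parts = dep.split('=')
            cleaned.append('='.join(parts[:2]))  # keep 'name=version' only
        else:
            cleaned.append(dep)
    return cleaned
```

Equivalently, re-exporting the environment on the original machine with `conda env export --no-builds` avoids the problem at the source.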
Also, which code creates it? (I'm trying to learn about your cohort-graph mapping.)
Hello,
I am currently evaluating the pre-trained SHEPHERD model for causal gene discovery on the myGene2 dataset. However, after computing the SPL matrix for this dataset using the add_spl_to_patients.py script, I encounter an issue: when I subsequently run the train.py script (with the --do_inference flag for evaluation), an error occurs. Interestingly, I have discovered that this error can be resolved by swapping the names of the saved files, namely spl_index_fname and spl_matrix_fname, as follows:
```python
with open(str(project_config.PROJECT_DIR / 'patients' / spl_index_fname), 'wb') as handle:
    pickle.dump(spl_indexing, handle, protocol=pickle.HIGHEST_PROTOCOL)

np.save(str(project_config.PROJECT_DIR / 'patients' / spl_matrix_fname), patients_spl_matrix)
Is there a specific reason for this behavior? I'm wondering if there is something I might be doing incorrectly.
Hello :)
In the NeighborSampler's sample method, the batch is initialized with the following code:
```python
# sample nodes to form positive edges. we will try to predict these edges
row, col, e_id = self.adj_t_sample.coo()
# NOTE: only does self loops when no edges in the current partition of the dataset
target_batch = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1]
batch = torch.cat([source_batch, target_batch], dim=0)
batch_size: int = len(batch)

adjs = []
n_id = batch
for size in self.sizes:
    adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)
    ...
```
For an isolated node u, random_walk returns u itself in a walk of length 1, while for any other node the walk returns a random neighbor.
Either way, the new batch, which now contains twice as many nodes, is fed to the adjacency sampler, which thereby samples twice from the same nodes, or samples neighbors of nodes that did not appear in the original batch.
I don't fully understand how this achieves self-loops for isolated nodes; wouldn't it make more sense to simply add the missing self-loop edges?
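To make the behavior I'm describing concrete, here is a minimal pure-Python sketch of the walk-length-1 step (an illustration only, not the torch_cluster implementation): an isolated node maps to itself, while a connected node maps to a random neighbor.

```python
import random


def one_step_walk(adjacency, source_batch, seed=None):
    """Mimic random_walk(row, col, source_batch, walk_length=1)[:, 1]:
    for each source node return a random neighbor, or the node itself
    when it has no neighbors (the implicit self-loop for isolated nodes)."""
    rng = random.Random(seed)
    targets = []
    for u in source_batch:
        neighbors = adjacency.get(u, [])
        targets.append(rng.choice(neighbors) if neighbors else u)
    return targets
```

So only isolated nodes ever produce the (u, u) "self-loop" pair, which is what the NOTE comment in the sampler seems to refer to.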
Thank you :)