
gnn_tracking's Introduction

GNNs for Charged Particle Tracking


This repository holds the main Python package for the GNN Tracking project. See the readme of the organization for an overview of the task. Detailed write-ups of our progress are available in arXiv:2309.16754 and arXiv:2312.03823. More resources are provided in the reading list.

  • 🔋 Batteries included: This repository implements a whole pipeline, from preprocessing to models to the evaluation of the final performance metrics.
  • ⚡ Built around PyTorch Lightning, our models are easy to train and to restore. By using hooks and callbacks, everything remains modular and maintainable.
  • ✅ Tested: Most of the code is guaranteed to run.

🔥 Installation

  1. Install micromamba (installation instructions). Conda works as well, but will be slow to solve the environment, so it's not recommended.
  2. Set up your environment with one of the environment/*.yml files (see the readme in that folder)
  3. Run pip3 install -e '.[testing,dev]' from this directory.
  4. Run pytest from this directory to check if everything worked
  5. For development: Install pre-commit hooks: pre-commit install (from this directory)

The demo notebooks are a good place to get started. This package is versioned as CalVer YY.0M.MICRO.

🧰 Development guidelines

If you open a PR and pre-commit fails because of formatting, comment pre-commit.ci autofix on the PR to trigger a fixup commit from pre-commit.ci.

To skip the slowest tests with pytest, run pytest --no-slow.

💚 Contributing, contact, citation

You can reach us by email. You can cite this software with the Zenodo DOI. Please also cite our latest preprint.

The issues marked with 'good first issue' are a good place to start contributing. It is always best to have the issue assigned to you before you start working on it.

Core developers (emoji key):

  • Gage DeZoort 💻 🤔
  • Kilian Lieret 💻 ⚠️

Thanks also goes to these wonderful people:

  • Shubhanshu Saxena 💻
  • Geo Jolly ⚠️
  • Jian Park 💻 🤔
  • Devdoot Chatterjee 💻 🔬

This project follows the all-contributors specification. Contributions of any kind welcome!

gnn_tracking's Issues

Remove `device` argument of loss functions

If we really need to know the device, we can get it from the tensors that are passed in. Right now, however, the argument is never used at all, because we only create tensors from tensors that are passed in.

Not everything gets moved to GPU

(TCNTrainable pid=14192) INFO: Loading data
(TCNTrainable pid=14192) INFO: ---- Epoch 1 ----
2022-09-23 10:18:07,853 ERROR trial_runner.py:980 -- Trial TCNTrainable_8139afc9: Error processing event.
ray.exceptions.RayTaskError(RuntimeError): ray::ResourceTrainable.train() (pid=14192, ip=10.36.22.16, repr=<ray.tune.trainable.util.TCNTrainable object at 0x2ba27f95dd60>)
  File "/scratch/gpfs/kl5675/miniconda3/envs/gnn/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 347, in train
    result = self.step()
  File "/home/kl5675/Documents/22/git_sync/gnn_tracking_experiments//scripts/tune.py", line 106, in step
    return self.trainer.step(max_batches=self.config.get("max_batches", None))
  File "/home/kl5675/Documents/22/git_sync/gnn_tracking/src/gnn_tracking/training/tcn_trainer.py", line 330, in step
    self.train_step(max_batches=max_batches)
  File "/home/kl5675/Documents/22/git_sync/gnn_tracking/src/gnn_tracking/training/tcn_trainer.py", line 254, in train_step
    batch_loss, batch_losses = self.get_batch_losses(model_output)
  File "/home/kl5675/Documents/22/git_sync/gnn_tracking/src/gnn_tracking/training/tcn_trainer.py", line 174, in get_batch_losses
    loss = loss_func(**model_output)
  File "/scratch/gpfs/kl5675/miniconda3/envs/gnn/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kl5675/Documents/22/git_sync/gnn_tracking/src/gnn_tracking/utils/losses.py", line 130, in forward
    return self.condensation_loss(beta=beta, x=x, particle_id=particle_id)
  File "/home/kl5675/Documents/22/git_sync/gnn_tracking/src/gnn_tracking/utils/losses.py", line 109, in condensation_loss
    diff = x[:, :, None] - x_alphas[None, :, :]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
wandb: Waiting for W&B process to finish... (success).

Refactor trainer

  1. Loss functions should take their parameters as keyword-only arguments
  2. The trainer should take loss functions directly as objects
  3. Ideally, the trainer takes the loss functions as a dictionary, to make it easier to add or remove loss terms
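
The three points above could be sketched as follows, in plain Python without torch. Everything here is hypothetical and for illustration only: the class name Trainer, the toy losses edge_loss and potential_loss, and the keys w, y and beta are assumptions, not the package's actual API.

```python
from typing import Callable, Dict

LossFn = Callable[..., float]


def edge_loss(*, w: list, y: list, **_) -> float:
    """Toy loss (hypothetical): mean squared error between edge weights and labels.

    The bare * makes w and y keyword-only; **_ swallows model outputs that
    this particular loss does not need.
    """
    return sum((wi - yi) ** 2 for wi, yi in zip(w, y)) / len(w)


def potential_loss(*, beta: list, **_) -> float:
    """Toy stand-in (hypothetical) for a condensation term: mean of 1 - beta."""
    return sum(1.0 - b for b in beta) / len(beta)


class Trainer:
    """Hypothetical trainer that receives its loss functions as a dictionary."""

    def __init__(self, loss_functions: Dict[str, LossFn]):
        self.loss_functions = dict(loss_functions)

    def get_batch_losses(self, model_output: dict) -> Dict[str, float]:
        # Every loss is called with keywords only, so argument order cannot
        # be mixed up; adding or removing a term only changes the dictionary.
        return {
            name: loss_func(**model_output)
            for name, loss_func in self.loss_functions.items()
        }


trainer = Trainer({"edge": edge_loss, "potential": potential_loss})
losses = trainer.get_batch_losses(
    {"w": [0.5, 1.0], "y": [0.0, 1.0], "beta": [0.5, 1.0]}
)
total_loss = sum(losses.values())
```

With this layout, dropping a loss term means removing one dictionary entry, and the per-term values stay available under their names for logging.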

Does pruning of clustering HPO make sense?

Basically, we're forcing the clustering to perform well even on the first few sectors. This might mean that we prune good hyperparameters for clustering. We should use a longer warmup time or prune a lower fraction of trials.
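
To illustrate what a longer warmup would change, here is a minimal sketch of a median-style pruning rule with a warmup period. It is not the pruner used in this project; the function should_prune and its parameters are hypothetical, and a higher metric value is assumed to be better.

```python
from statistics import median


def should_prune(step, value, history, *, n_warmup_steps):
    """Median-rule pruning with a warmup (hypothetical helper).

    step: 0-based index of the current report, e.g. the sector number.
    value: metric of the current trial at this step (higher is better).
    history: values that earlier trials reported at the same step.
    """
    # During warmup (or without any reference trials) nothing is pruned.
    if step < n_warmup_steps or not history:
        return False
    # After warmup, prune trials that are below the median of earlier trials.
    return value < median(history)


# A trial that starts badly survives the first sectors ...
early = should_prune(0, 0.1, [0.5, 0.6, 0.7], n_warmup_steps=3)
# ... and is only judged once the warmup is over.
late = should_prune(3, 0.1, [0.5, 0.6, 0.7], n_warmup_steps=3)
```

Raising n_warmup_steps delays the first pruning decision, so hyperparameters that only pay off on later sectors are not discarded early.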

Alter validation procedure

The validation procedure needs to be updated to select the best model (schematically, best_epoch=argmax(val_acc)). We would then test this model on the test data downstream, perhaps even outside of the training loop.
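
The selection step best_epoch=argmax(val_acc) could look like the sketch below. The helper select_best_epoch and the accuracy values are made up for illustration; how the matching checkpoint is stored and reloaded is left out.

```python
def select_best_epoch(val_acc):
    """Return (epoch, accuracy) for the highest validation accuracy.

    Ties are resolved in favor of the earliest epoch.
    """
    if not val_acc:
        raise ValueError("val_acc must contain at least one epoch")
    best_epoch = max(range(len(val_acc)), key=val_acc.__getitem__)
    return best_epoch, val_acc[best_epoch]


# Made-up validation accuracies, one entry per epoch.
val_acc = [0.71, 0.83, 0.80, 0.83]
best_epoch, best_acc = select_best_epoch(val_acc)
```

The model saved at best_epoch is then the one to evaluate on the test data, which keeps the test set out of the model selection.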

Undefined variable

def empty_graph():
    singles = {k: np.array([]) for k in ["x", "hit_id", "particle_id", "y"]}
    doubles = {k: np.array([[], []]) for k in ["edge_index", "edge_hit_id"]}
    quadrouples = {"edge_attr": np.array([[], [], [], []])}
    graph = {**singles, **doubles, **quadrouples}
    graph["s"] = s
    graph["n_incorrect"] = 0
    return graph

s, used in the third-to-last line, is not defined (src/gnn_tracking/utils/graph_construction.py).

Increase GPU utilization: Data loader parallelization?

Current statistics show GPU utilization at only 33%. We probably have to parallelize the data loader more (and assign it more cores).

================================================================================
                              Slurm Job Statistics
================================================================================
         Job ID: 9446770
  NetID/Account: kl5675/physics
       Job Name: tune.sh
          State: RUNNING
          Nodes: 2
      CPU Cores: 56
     CPU Memory: 8GB (142.9MB per CPU-core)
           GPUs: 8
  QOS/Partition: gpu-test/gpu
        Cluster: tiger
     Start Time: Fri Oct 7, 2022 at 12:48 PM
       Run Time: 00:16:10 (in progress)
     Time Limit: 01:00:00

                              Overall Utilization
================================================================================
  CPU utilization  [||||||                                         12%]
  CPU memory usage [||||||||||||||||||||||||||||||                 60%]
  GPU utilization  [||||||||||||||||                               33%]
  GPU memory usage [|||||||||||||||||||||||||||||||||||||          74%]

                              Detailed Utilization
================================================================================
  CPU utilization per node (CPU time used/run time)
      tiger-i19g14: 00:51:49/07:32:45 (efficiency=11.4%)
      tiger-i19g3: 00:52:56/07:32:45 (efficiency=11.7%)
  Total used/runtime: 01:44:45/15:05:30, efficiency=11.6%

  CPU memory usage per node - used/allocated
      tiger-i19g14: 64.5GB/109.4GB (2.3GB/3.9GB per core of 28)
      tiger-i19g3: 66.0GB/109.4GB (2.4GB/3.9GB per core of 28)
  Total used/allocated: 130.4GB/218.8GB (2.3GB/3.9GB per core of 56)

  GPU utilization per node
      tiger-i19g14 (GPU 0): 43.1%
      tiger-i19g14 (GPU 1): 44.4%
      tiger-i19g14 (GPU 2): 16.2%
      tiger-i19g14 (GPU 3): 27.1%
      tiger-i19g3 (GPU 0): 29.9%
      tiger-i19g3 (GPU 1): 47.0%
      tiger-i19g3 (GPU 2): 30.6%
      tiger-i19g3 (GPU 3): 27.7%

  GPU memory usage per node - maximum used/total
      tiger-i19g14 (GPU 0): 15.1GB/15.9GB (95.0%)
      tiger-i19g14 (GPU 1): 13.3GB/15.9GB (83.6%)
      tiger-i19g14 (GPU 2): 6.2GB/15.9GB (38.9%)
      tiger-i19g14 (GPU 3): 12.0GB/15.9GB (75.4%)
      tiger-i19g3 (GPU 0): 13.2GB/15.9GB (83.3%)
      tiger-i19g3 (GPU 1): 13.2GB/15.9GB (82.9%)
      tiger-i19g3 (GPU 2): 6.3GB/15.9GB (39.7%)
      tiger-i19g3 (GPU 3): 15.1GB/15.9GB (94.8%)

                                     Notes
================================================================================
  * This job ran in the gpu-test QOS. Each user can only run a small number of
    jobs simultaneously in this QOS. For more info:
    https://researchcomputing.princeton.edu/support/knowledge-base/job-priority#test-queue

  * For additional job metrics including metrics plotted against time:
    https://stats.rc.princeton.edu  (VPN required off-campus)

hidden_size parameter

The track_condenssation_network files use the hidden_size parameter of IN, but that parameter doesn't exist anymore (it is probably now split into relational_hidden_size and object_hidden_size).

To-Do Items

  • Add focal loss and relevant hyperparameters.
  • Remove predict_track_params flag from trainer and models.
  • Fix the point cloud TCN model.
  • Add separate encoders / GNN layers for edge classification and object condensation in the graph TCN model.
  • Add loss function weights (fixed by default, but we could add them as learnable parameters later) to the trainer.
  • Add trainer plots to the plotter class.
  • Separate the attractive and repulsive potentials if possible.
  • Fix the numerator and/or denominator in the graph building efficiency measurements; they seem to max out at 99% right now.
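
For the first item, a scalar sketch of the standard binary focal loss, FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t), shows which hyperparameters would be added (alpha and gamma). The function name and the default values are assumptions for illustration; a real implementation would operate on torch tensors.

```python
import math


def binary_focal_loss(p, y, *, alpha=0.25, gamma=2.0, eps=1e-12):
    """Binary focal loss for one prediction (illustrative sketch).

    p: predicted probability of the positive class, in [0, 1].
    y: true label, 0 or 1.
    alpha: weight of the positive class (negatives get 1 - alpha).
    gamma: focusing parameter; gamma=0 gives alpha-weighted cross entropy.
    """
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    p_t = min(max(p_t, eps), 1.0)  # avoid log(0)
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


easy = binary_focal_loss(0.9, 1)  # well-classified positive
hard = binary_focal_loss(0.1, 1)  # badly classified positive
```

The factor (1 - p_t)^gamma shrinks the contribution of well-classified examples, so hard has a much larger value than easy.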

Relying more on ray?

If we want to stop tracking the per-epoch losses ourselves and use ray's report function instead, we also have to use its training call, which would probably look something like this:

scaling_config = ScalingConfig(num_workers=3)
# If using GPUs, use the below scaling config instead.
# scaling_config = ScalingConfig(num_workers=3, use_gpu=True)
trainer = TorchTrainer(
    train_loop_per_worker=tcn_trainer.train,
    scaling_config=scaling_config,
)
result = trainer.fit()

The advantage is that we might get something out of the scaling config. It also apparently becomes very easy to hook up TensorBoard, MLflow and more by simply adding

run_config=RunConfig(
        callbacks=[
            MLflowLoggerCallback(experiment_name="train_experiment"),
            TBXLoggerCallback(),
        ],
    ),

to the above.

Training edge classification head separately

  • Do we want to keep the BCE loss, which only covers the preprocessing step, as prominent in the loss function? (Is it obvious that minimizing the BCE loss will always help maximize how well the OC works?)
  • Do we want to speed up training by only training the edge classification at the beginning and keeping the OC fixed?

Add integration tests for graph builder / trainer

I have the code, but I still run into issues because the training set is very small and nothing remains after some hidden cuts are applied (I already relaxed pt cuts and everything else I found), causing exceptions...
