

Deciphering Spatio-Temporal Graph Forecasting: A Causal Lens and Treatment

This repo provides the implementation of our NeurIPS 2023 paper, Deciphering Spatio-Temporal Graph Forecasting: A Causal Lens and Treatment. The code is written in PyTorch 1.10.2 and was run on a server with an NVIDIA RTX A6000 GPU.


Description

We present CaST, a new framework that takes a causal view of Spatio-Temporal Graph (STG) forecasting and addresses two problems: temporal out-of-distribution (OOD) shifts and dynamic spatial causation. CaST employs two causal tools. The first is back-door adjustment, implemented through a disentanglement block that separates invariant parts from temporal environments. The second is front-door adjustment, which introduces a surrogate variable to emulate node information filtered according to the nodes' causal relationships.
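CaST realizes both adjustments with neural modules. The formulas underneath are the standard back-door and front-door adjustments from Pearl's do-calculus. The sketch below evaluates those textbook formulas on small discrete probability tables and is an illustration only, not the repo's code. It treats the temporal environment E as the confounder (back-door) and the surrogate Z as the mediator (front-door), as the description above suggests. The function names and the [X, E, Y] / [X, Z, Y] table layouts are our own assumptions.

```python
import numpy as np


def backdoor_adjust(p_e, p_y_given_xe):
    """Back-door adjustment: P(y | do(x)) = sum_e P(y | x, e) P(e).

    p_e          -- shape [E], marginal of the confounder (the environment here)
    p_y_given_xe -- shape [X, E, Y], conditional table P(Y | X, E)
    Returns an [X, Y] table; row x is the interventional distribution P(Y | do(x)).
    """
    # Weight each environment by its marginal P(e), not by P(e | x).
    return np.einsum("xey,e->xy", p_y_given_xe, p_e)


def frontdoor_adjust(p_x, p_z_given_x, p_y_given_xz):
    """Front-door adjustment: P(y | do(x)) = sum_z P(z | x) sum_x' P(y | x', z) P(x').

    p_x          -- shape [X], marginal of the treatment
    p_z_given_x  -- shape [X, Z], surrogate (mediator) table P(Z | X)
    p_y_given_xz -- shape [X, Z, Y], conditional table P(Y | X, Z)
    Returns an [X, Y] table of interventional distributions.
    """
    # Average P(y | x', z) over x' ~ P(X), leaving a [Z, Y] table.
    inner = np.einsum("azy,a->zy", p_y_given_xz, p_x)
    # Mix over the surrogate with P(z | x).
    return p_z_given_x @ inner
```

In both functions, each row of the result is a proper distribution over Y. The back-door version averages over environments with their marginal P(e), which removes the spurious path through E.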

Requirements

CaST uses the following dependencies:

  • PyTorch 1.10.2 and its dependencies
  • NumPy and SciPy
  • CUDA 11.3 or later, and cuDNN
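To check an environment against these requirements, a small hypothetical helper (not part of the repo) can report the installed package versions and what PyTorch sees of CUDA and cuDNN. It relies only on the standard library's importlib.metadata and on documented PyTorch attributes.

```python
from importlib import metadata


def installed_versions(packages=("torch", "numpy", "scipy")):
    """Map each distribution name to its installed version string, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def cuda_summary():
    """Report what PyTorch sees of CUDA and cuDNN; returns None if torch is not installed."""
    try:
        import torch
    except ImportError:
        return None
    return {
        "torch_cuda_version": torch.version.cuda,  # None on CPU-only builds
        "cuda_available": torch.cuda.is_available(),
        "cudnn_version": torch.backends.cudnn.version(),
    }
```

Printing installed_versions() and cuda_summary() gives a quick comparison against PyTorch 1.10.2 and CUDA 11.3.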

Dataset

Overview

The performance of CaST was validated on three datasets: PEMS08, AIR-BJ, and AIR-GZ. AIR-BJ and AIR-GZ contain one year of PM2.5 readings from air quality monitoring stations in Beijing and Guangzhou, respectively. PEMS08 contains traffic flow data collected by sensors deployed on a road network. Traffic flow is a complex and challenging type of spatio-temporal data because many factors affect it, such as weather, time of day, and road conditions.

For proper execution, place each dataset at ./data/[dataset_name]/dataset.npy. Each array must have the shape (num_samples, num_nodes, input_dim).
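A raw array may carry a single reading per node, for example one PM2.5 value per station. Such an array is 2-D and is missing the trailing input_dim axis. The sketch below is a hypothetical helper that brings an array into the expected layout and rejects anything else. The name to_cast_layout is ours, not the repo's.

```python
import numpy as np


def to_cast_layout(data: np.ndarray) -> np.ndarray:
    """Return `data` in the (num_samples, num_nodes, input_dim) layout.

    A 2-D array of shape (num_samples, num_nodes), i.e. one feature per node,
    gets a trailing feature axis; 3-D arrays are returned unchanged.
    """
    if data.ndim == 2:
        data = data[..., np.newaxis]
    if data.ndim != 3:
        raise ValueError(
            f"expected shape (num_samples, num_nodes, input_dim), got {data.shape}"
        )
    return data
```

For example, `np.save('./data/AIR_BJ/dataset.npy', to_cast_layout(raw))` would save a hypothetical 2-D array `raw` in the required layout.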

For the PEMS08 dataset, the dataset.npy file can be generated with the following code:

import numpy as np
data = np.load('./data/PEMS08/pems08.npz')['data']  # shape (num_samples, num_nodes, input_dim)
np.save('./data/PEMS08/dataset.npy', data)

Edge Features

Appendix D of our paper describes in detail how we construct edge attributes. In our study we use the Pearson correlation and a time-delayed Dynamic Time Warping (DTW) method. You may replace either with your own method for creating edge attributes.
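For readers who want to experiment with the DTW alternative, here is a minimal sketch of classic DTW distance between node series, using an absolute-difference cost. It does not reproduce the time-delayed variant described in Appendix D. The names dtw_distance and dtw_matrix are hypothetical. How the distances become edge weights (for example by thresholding or a kernel) is left open.

```python
import numpy as np


def dtw_distance(a, b):
    """Classic dynamic time warping distance between two 1-D series."""
    n, m = len(a), len(b)
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = abs(a[i - 1] - b[j - 1])
            # Extend the cheapest of: insertion, deletion, or match.
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    return float(cost[n, m])


def dtw_matrix(series):
    """Symmetric pairwise DTW distances; `series` has shape (num_nodes, T)."""
    num_nodes = series.shape[0]
    dist = np.zeros((num_nodes, num_nodes))
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            dist[i, j] = dist[j, i] = dtw_distance(series[i], series[j])
    return dist
```

Each pair costs O(T²) in pure Python. On full-length series it is therefore practical only for shorter windows or summarized profiles.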

If you prefer to follow our approach, the following function generates the peacor_adj.npy file:

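import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

def get_peacor_adj(data_path, threshold, save=False, plot=True):
    # Load the training split, shape (num_samples, num_nodes, input_dim)
    data = np.load(os.path.join(data_path, 'train.npz'))['data']
    print("Data shape:", data.shape)

    # Pearson correlation between the nodes' first-feature series -> (num_nodes, num_nodes);
    # constant series yield NaN, which is mapped to 0
    peacor = torch.corrcoef(torch.Tensor(data[..., 0]).permute(1, 0))
    peacor = torch.nan_to_num(peacor, nan=0.0)

    # Apply threshold (with threshold >= 0 this also drops negative correlations) and remove self-loops
    peacor[peacor < threshold] = 0
    peacor[torch.eye(peacor.shape[0], dtype=torch.bool)] = 0

    # Min-max normalize the remaining coefficients; the weakest kept edge maps to 0
    mask = peacor != 0
    nonzero_peacor = peacor[mask]
    if nonzero_peacor.numel() > 0 and nonzero_peacor.max() > nonzero_peacor.min():
        p_min, p_max = nonzero_peacor.min(), nonzero_peacor.max()
        peacor[mask] = (nonzero_peacor - p_min) / (p_max - p_min)

    # Visualization
    if plot:
        plt.figure(dpi=100)
        sns.heatmap(peacor.numpy())
        plt.show()

    # Save the result as a NumPy array
    if save:
        np.save(os.path.join(data_path, 'peacor_adj.npy'), peacor.numpy())
    return peacor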
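get_peacor_adj reads a train.npz file whose 'data' array holds the training portion of the series. The README does not say how that file is produced. The sketch below assumes a chronological, unshuffled split along the time axis. Its default ratios of 0.6 and 0.2 are placeholders standing in for the train_ratio and val_ratio arguments, not values taken from the repo.

```python
import numpy as np


def chronological_split(data, train_ratio=0.6, val_ratio=0.2):
    """Split along the time axis (axis 0) into train/val/test, without shuffling.

    The 0.6 / 0.2 defaults are placeholders, not values taken from the repo.
    """
    n = data.shape[0]
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    return {
        "train": data[:n_train],
        "val": data[n_train:n_train + n_val],
        "test": data[n_train + n_val:],
    }

# Hypothetical usage: write the file that get_peacor_adj reads.
# splits = chronological_split(np.load('./data/PEMS08/dataset.npy'))
# np.savez('./data/PEMS08/train.npz', data=splits['train'])
```

Splitting by time rather than at random keeps future readings out of the correlation estimate.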

For reproducibility, we also provide peacor_adj.npy, sparse_adj.npy, and dist_adj.npy in the ./data/PEMS08/ directory for reference.

Arguments

The major arguments of our main function are described below.

Training settings:

  • mode: the run mode, e.g., train or test
  • gpu: the GPU used for training
  • seed: the random seed for experiments
  • dataset: the dataset to run on
  • base_lr: the initial learning rate
  • lr_decay_ratio: the learning-rate decay ratio
  • batch_size: the training or testing batch size
  • seq_len: the number of historical input steps
  • horizon: the number of future steps to predict
  • input_dim: the dimension of inputs
  • output_dim: the dimension of outputs
  • max_epochs: the maximum number of training epochs
  • patience: the patience for early stopping
  • save_preds: whether to save prediction results
  • train_ratio: the fraction of data used for training
  • val_ratio: the fraction of data used for validation
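The patience, base_lr, and lr_decay_ratio arguments follow common training conventions. Patience-based early stopping and step learning-rate decay can be sketched as below. This reflects an assumption about how the arguments interact, not code taken from main.py. The milestones argument is hypothetical. decayed_lr follows the same rule as PyTorch's torch.optim.lr_scheduler.MultiStepLR with gamma set to lr_decay_ratio.

```python
class EarlyStopping:
    """Signal a stop once the validation loss fails to improve `patience` times in a row."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, val_loss):
        """Record one validation loss and return True if training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


def decayed_lr(base_lr, lr_decay_ratio, epoch, milestones):
    """Step decay: scale base_lr by lr_decay_ratio once for every milestone already reached.

    `milestones` (a list of epoch indices) is a hypothetical knob, not a README argument.
    """
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * lr_decay_ratio ** passed
```

A training loop would call step() after each validation pass and break when it returns True, for at most max_epochs epochs.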

Model hyperparameters:

  • hid_dim: the hidden dimension in CaST
  • dropout: the dropout rate
  • n_envs: the number of environments
  • node_embed_dim: the dimensionality of node embeddings
  • K: the depth of the HL Deconfounder block
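For orientation, here is a sketch of how these flags could be declared with Python's argparse. The flag names come from the lists above. The types, the store_true form of --save_preds, and every default are placeholders, so consult ./experiments/cast/main.py for the real definitions.

```python
import argparse


def build_parser():
    # Hypothetical parser: flag names follow the README; all defaults are placeholders.
    p = argparse.ArgumentParser(description="CaST arguments (sketch)")
    # Training settings
    p.add_argument("--mode", choices=["train", "test"], default="train")
    p.add_argument("--gpu", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--dataset", default="PEMS08")
    p.add_argument("--base_lr", type=float, default=1e-3)
    p.add_argument("--lr_decay_ratio", type=float, default=0.1)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--seq_len", type=int, default=12)
    p.add_argument("--horizon", type=int, default=12)
    p.add_argument("--input_dim", type=int, default=1)
    p.add_argument("--output_dim", type=int, default=1)
    p.add_argument("--max_epochs", type=int, default=100)
    p.add_argument("--patience", type=int, default=10)
    p.add_argument("--save_preds", action="store_true")
    p.add_argument("--train_ratio", type=float, default=0.6)
    p.add_argument("--val_ratio", type=float, default=0.2)
    # Model hyperparameters
    p.add_argument("--hid_dim", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--n_envs", type=int, default=10)
    p.add_argument("--node_embed_dim", type=int, default=10)
    p.add_argument("--K", type=int, default=2)
    return p
```

With this sketch, the AIR-BJ command below corresponds to `build_parser().parse_args(["--dataset", "AIR_BJ", "--mode", "train", "--hid_dim", "64", "--n_envs", "10", "--node_embed_dim", "10", "--K", "3"])`.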

Training and Evaluation

The following commands train CaST on each dataset:

# PEMS08
python ./experiments/cast/main.py --dataset PEMS08 --mode 'train' --hid_dim 64 --n_envs 20 --node_embed_dim 5 --K 2
# AIR-BJ
python ./experiments/cast/main.py --dataset AIR_BJ --mode 'train' --hid_dim 64 --n_envs 10 --node_embed_dim 10 --K 3
# AIR-GZ
python ./experiments/cast/main.py --dataset AIR_GZ --mode 'train' --hid_dim 64 --n_envs 20 --node_embed_dim 5 --K 2

Code Reference

HL-HGAT: https://github.com/JH-415/HL-HGAT

VQVAE: https://github.com/ritheshkumar95/pytorch-vqvae

Citation

If you find our work useful in your research, please cite:

@article{xia2023deciphering,
  title={Deciphering Spatio-Temporal Graph Forecasting: A Causal Lens and Treatment},
  author={Xia, Yutong and Liang, Yuxuan and Wen, Haomin and Liu, Xu and Wang, Kun and Zhou, Zhengyang and Zimmermann, Roger},
  journal={arXiv preprint arXiv:2309.13378},
  year={2023}
}
