This repository contains a PaddlePaddle 2.2 implementation of STGCN based on the paper Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting, with a few modifications to the model architecture to tackle traffic jam forecasting problems.
Forecasting traffic jams is not quite the same as forecasting traffic flow, for which we have abundant data. Jams occur only in the largest cities and often only during peak hours, resulting in an unbalanced dataset. In order to study jam patterns, we select roads in Haidian District, Beijing that experience far more jams than others.
Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting https://arxiv.org/abs/1709.04875
Semi-Supervised Classification with Graph Convolutional Networks https://arxiv.org/abs/1609.02907 (GCN)
Inductive Representation Learning on Large Graphs https://arxiv.org/abs/1706.02216 (GraphSAGE)
Graph Attention Networks https://arxiv.org/abs/1710.10903 (GAT)
Bag of Tricks for Node Classification with Graph Neural Networks https://arxiv.org/pdf/2103.13355.pdf (BoT)
Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting https://ojs.aaai.org//index.php/AAAI/article/view/3881 (ASTGCN)
The original STGCN model uses first-order ChebConv and GCN as the graph operation. In our model we conducted experiments on one spectral method (GCN) and two spatial methods (GAT, GraphSAGE).
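For reference, the first-order GCN propagation rule computes D^{-1/2}(A + I)D^{-1/2} X W, i.e. symmetrically normalized aggregation over neighbors plus self-loops. A minimal NumPy sketch of that operation (illustrative only; the repository's layers are implemented in PaddlePaddle):

```python
import numpy as np

def gcn_layer(A, X, W):
    """First-order GCN propagation: D^-1/2 (A + I) D^-1/2 X W."""
    A_hat = A + np.eye(A.shape[0])           # add self-loops
    d = A_hat.sum(axis=1)                    # node degrees
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))   # symmetric normalization
    return D_inv_sqrt @ A_hat @ D_inv_sqrt @ X @ W

A = np.array([[0, 1], [1, 0]], dtype=float)  # toy 2-node graph
X = np.ones((2, 3))                          # node features
W = np.eye(3)                                # identity weights for illustration
out = gcn_layer(A, X, W)
```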
Graph neural networks often suffer from oversmoothing: as the model becomes deeper, the representations of nodes tend to become similar to each other because they are repeatedly aggregated. Adding a residual connection mitigates oversmoothing by adding the unsmoothed input features directly to the output of the graph convolution operation. The connection also helps against gradient instabilities.
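As an illustration of the idea (not the repository's PaddlePaddle code), a graph convolution wrapped with a residual connection can be sketched as:

```python
import numpy as np

def gcn_with_residual(A_norm, X, W):
    """Graph convolution followed by a residual (skip) connection.

    Adding the unsmoothed input X back to the aggregated output keeps
    node representations from collapsing together as depth grows.
    """
    h = np.maximum(A_norm @ X @ W, 0.0)   # graph conv + ReLU
    return h + X                          # residual connection

A_norm = np.full((3, 3), 1.0 / 3.0)       # toy normalized adjacency
X = np.arange(6, dtype=float).reshape(3, 2)
W = np.zeros((2, 2))                      # zero weights: conv output is all zero
out = gcn_with_residual(A_norm, X, W)     # so the residual passes X through intact
```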
Jam status often follows daily patterns. To let the model learn these historical patterns, we directly feed it historical jam data aligned to the same hour of day. For example, to predict the traffic status at 8 PM on Nov 30, 2021, we feed the model the 8 PM traffic status of the past 12 days through a graph convolution layer, then concatenate the result with the output of the S-T convolution blocks to form the input of the final classification layer.
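The same-hour alignment can be sketched as follows; the names and shapes here are hypothetical (series[t] holding the hourly jam status of all road segments at step t), not the repository's actual data layout:

```python
import numpy as np

def same_hour_history(series, t, days=12, steps_per_day=24):
    """Collect the values at the same hour of day over the past `days` days.

    For a target step t (e.g. 8 PM today), returns a (days, n) array with
    the 8 PM readings of the previous `days` days, oldest first.
    """
    idx = [t - d * steps_per_day for d in range(days, 0, -1)]
    return series[idx]

# Toy series: 14 days of hourly readings for 1 segment, where series[t] = t,
# so we can read the selected time indices straight off the result.
series = np.arange(24 * 14, dtype=float).reshape(24 * 14, 1)
hist = same_hour_history(series, t=24 * 13 + 20)  # 8 PM of day 13
```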
The model is implemented to predict the jam status of several future time steps. First we feed the input data into the model to generate the prediction for the first future time step. Then we concatenate the predicted status with the original input and feed it to the model to generate the prediction for the next time step, and so on.
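This autoregressive rollout can be sketched as below; `model` and the window shapes are stand-ins for illustration, not the repository's actual interface:

```python
import numpy as np

def rollout(model, window, horizon):
    """Predict `horizon` future steps by feeding each prediction back in.

    window: (T, n) array of the most recent T observed steps.
    Returns: (horizon, n) array of predictions.
    """
    preds = []
    for _ in range(horizon):
        nxt = model(window)                          # predict one step ahead
        preds.append(nxt)
        window = np.vstack([window[1:], nxt[None]])  # slide window forward
    return np.stack(preds)

toy_model = lambda w: w[-1] + 1.0   # dummy "model": last observed step + 1
window = np.zeros((4, 2))
preds = rollout(toy_model, window, horizon=3)
```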
The original STGCN model was a regression model optimizing a mean squared loss. Our traffic jam status has four classes: 1 -- smooth traffic; 2 -- temperate jam; 3 -- moderate jam; 4 -- heavy jam. So we changed it into a softmax-with-cross-entropy classification model. Because traffic is smooth in most cases, label 1 dominates the others, so we use a weighted cross entropy loss to punish incorrect classifications of classes 2, 3 and 4 more severely.
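In PaddlePaddle this corresponds to passing a per-class `weight` tensor to `paddle.nn.CrossEntropyLoss`. A framework-free NumPy sketch of the weighted loss (the class weights here are illustrative, not the values used in training):

```python
import numpy as np

def weighted_cross_entropy(logits, labels, class_weights):
    """Softmax cross entropy with per-class weights.

    Misclassifying a jammed segment (classes 1..3 here, i.e. statuses
    2-4) costs more than misclassifying smooth traffic (class 0).
    """
    z = logits - logits.max(axis=1, keepdims=True)   # numeric stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    w = class_weights[labels]                        # weight of each sample's true class
    nll = -log_probs[np.arange(len(labels)), labels]
    return (w * nll).sum() / w.sum()

# Upweight the three jam classes relative to smooth traffic (illustrative values).
weights = np.array([1.0, 4.0, 4.0, 4.0])
logits = np.array([[5.0, 0.0, 0.0, 0.0],    # confident "smooth", correct
                   [5.0, 0.0, 0.0, 0.0]])   # same logits, but true label: heavy jam
loss = weighted_cross_entropy(logits, np.array([0, 3]), weights)
```

The wrongly classified jam sample dominates the loss because its weight is four times that of the smooth-traffic class.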
You can use pip to install the requirements:
sh requirement.sh
Right now I am still updating the experiments, adding new blocks, and trying out new ideas.
Here's an example of what a training epoch currently looks like; the numbers are the cross entropy loss of the model. The current model achieves an accuracy of over 80% when predicting traffic jams for the next 4 hours.
I think self-distillation methods have potential for improving long-term prediction: the teacher model first learns historical patterns to make a raw prediction of whether a road segment is jammed, then the student model uses graph-structured data on nearby traffic conditions to make a more educated guess of the future jam status.
Pretraining techniques also have potential for generating spatial-temporal representations of road segments.