st-metanet's Introduction

ST-MetaNet: Urban Traffic Prediction from Spatio-Temporal Data Using Deep Meta Learning

Overview of ST-MetaNet

This is the MXNet implementation of ST-MetaNet from the KDD'19 paper "Urban Traffic Prediction from Spatio-Temporal Data Using Deep Meta Learning" (see Citation below).

Requirements for Reproducibility

System Requirements:

  • System: Ubuntu 16.04
  • Language: Python 3.5.2
  • Devices: a single GTX 1080 GPU

Library Requirements:

  • scipy == 1.2.1
  • numpy == 1.16.3
  • pandas == 0.24.2
  • mxnet-cu90 == 1.5.0b20190108
  • dgl == 0.2
  • tables == 3.5.1
  • pyyaml
  • h5py

Dependencies can be installed with the following command:

pip install -r requirements.txt

Then set MXNet as the backend deep learning framework for DGL:

echo 'export DGLBACKEND=mxnet' >> ~/.bashrc
. ~/.bashrc
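
In environments where ~/.bashrc is not sourced (for example, a notebook or a job scheduler), the same variable can be set from Python. This is a minimal sketch that assumes DGL picks its backend from DGLBACKEND when it is first imported, so it has to run before dgl is imported anywhere in the process. The helper name is hypothetical.

```python
import os


def use_mxnet_backend_for_dgl():
    """Select MXNet as DGL's backend for this process (hypothetical helper).

    DGL reads DGLBACKEND once, when it is first imported, so call this before
    any import of dgl, including imports inside the model code.
    """
    os.environ["DGLBACKEND"] = "mxnet"


use_mxnet_backend_for_dgl()
# import dgl  # from here on, dgl should use the MXNet backend
```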

Data Preparation

Unzip the data files in:

  • flow-prediction/data.zip
  • traffic-prediction/data.zip (the data was preprocessed in the DCRNN repository)

Description of Flow Data

The flow data is collected from a 32x32 grid over Beijing.

  • BJ_FEATURE.h5: the shape of the data is (32, 32, 989), i.e., (row, column, feature_id). This data represents the node feature of each grid cell. It consists of the POI features and road features saved in BJ_POI.h5 and BJ_ROAD.h5, respectively.
  • BJ_FLOW.h5: the shape of the data is (150, 24, 32, 32, 2), i.e., (date, hour, row, column, flow_type). flow_type indicates the inflow or outflow of the region.
  • BJ_GRAPH.h5: the shape is (1024, 1024, 32), i.e., (grid_1, grid_2, feature_id). This data represents the edge feature from grid_1 to grid_2.
  • BJ_POI.h5: POI features of each grid cell. This is an intermediate output of the preprocessing stage (not used in model training and testing).
  • BJ_ROAD.h5: road features of each grid cell. This is an intermediate output of the preprocessing stage (not used in model training and testing).
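
The shapes above imply that the 32x32 grid is flattened into 1,024 nodes (the first two axes of BJ_GRAPH.h5) and that 150 days x 24 hours give 3,600 hourly time steps. A minimal NumPy sketch of that flattening follows. The HDF5 dataset names inside the .h5 files are not documented here, so the sketch works on zero-filled in-memory arrays with the documented shapes; the row-major node order (node = row * 32 + column) is an assumption, not something the dataset description states. The helper names are hypothetical.

```python
import numpy as np

ROWS, COLS = 32, 32        # Beijing grid size (BJ_FEATURE.h5 / BJ_FLOW.h5)
NUM_NODES = ROWS * COLS    # 1024, the size of the first two axes of BJ_GRAPH.h5


def node_id(row, col):
    """Flatten a (row, column) cell into a node index.

    Assumption: row-major order, i.e. node = row * 32 + column.
    """
    return row * COLS + col


def flow_to_series(flow):
    """(date, hour, row, column, flow_type) -> (timestamp, node, flow_type)."""
    days, hours, rows, cols, types = flow.shape
    return flow.reshape(days * hours, rows * cols, types)


def features_to_nodes(feature):
    """(row, column, feature_id) -> (node, feature_id)."""
    rows, cols, dim = feature.shape
    return feature.reshape(rows * cols, dim)


# Usage with zero-filled stand-ins that have the documented shapes.
flow = np.zeros((150, 24, ROWS, COLS, 2), dtype=np.float32)
flow[1, 2, 3, 4, 0] = 7.0          # flow_type 0 of cell (3, 4), hour 2 of day 1
series = flow_to_series(flow)
print(series.shape)                # (3600, 1024, 2)
print(series[1 * 24 + 2, node_id(3, 4), 0])
```

Because reshape keeps NumPy's C (row-major) order, timestamp index = day * 24 + hour and node index = row * 32 + column line up with the original axes.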

Description of Traffic Data

For a description of the traffic data, please refer to the DCRNN repository.


Model Training & Testing

Taking the flow prediction task as an example (the procedure for the traffic prediction task is exactly the same):

  1. cd flow-prediction/.
  2. The model settings are stored in YAML format in the folder src/model_setting. Three models are provided: seq2seq, gat-seq2seq, and st-metanet. For the other baselines, please refer to the DCRNN and ST-ResNet repositories.
  3. All trained models are saved in param/. There are two types of files in this folder:
    1. model.yaml: the model training log (the result on the evaluation dataset at each epoch). This file also records the best epoch of the model.
    2. model-xxxx.params: the saved model parameters of the best evaluation epoch, where xxxx is the epoch number.
  4. Running the code:
    1. cd src/.
    2. python train.py --file model_setting/[model_name].yaml --gpus [gpu_ids] --epochs [num_epoch]. The code first loads the best saved epoch from param/ (if any) and then trains the model for [num_epoch] epochs. The code supports training on multiple GPUs; for example, [gpu_ids] is 0,1,2,3 if you have four GPUs. However, we recommend using a single GPU to train and evaluate the model if possible (currently, our implementation of meta graph attention with the DGL library is not efficient on multiple GPUs).
  5. Training from scratch:
    1. Remove the model records in param/; otherwise, the code will resume training from the best saved model.
    2. Train the model (examples):
      1. Single GPU: python train.py --file model_setting/st-metanet.yaml --gpus 0 --epochs 200.
      2. Multiple GPUs: python train.py --file model_setting/st-metanet.yaml --gpus 0,1 --epochs 200.
  6. Testing the model (example): python train.py --file model_setting/st-metanet.yaml --gpus 0 --epochs 0. With --epochs 0, the code skips training and directly reports results on the evaluation and test datasets.
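
The resume/train/test behaviour described in the steps above can be sketched in plain Python. This is not the repository's train.py: all helper names are hypothetical, and where the real code reads the best epoch from the model's .yaml log, this sketch infers it from checkpoint file names instead, assuming the "model-xxxx.params" pattern means the model name followed by a numeric epoch.

```python
import argparse
import re

# Hypothetical checkpoint naming "<model_name>-<epoch>.params".
PARAM_RE = re.compile(r"^(?P<name>.+)-(?P<epoch>\d+)\.params$")


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="ST-MetaNet training (sketch)")
    parser.add_argument("--file", required=True, help="model_setting/[model_name].yaml")
    parser.add_argument("--gpus", default="0", help="comma-separated GPU ids, e.g. 0,1,2,3")
    parser.add_argument("--epochs", type=int, default=0, help="epochs to train; 0 = test only")
    return parser.parse_args(argv)


def gpu_ids(spec):
    """Turn '0,1,2,3' into [0, 1, 2, 3]."""
    return [int(x) for x in spec.split(",") if x.strip()]


def best_saved_epoch(filenames, model_name):
    """Largest epoch among saved '<model_name>-<epoch>.params' files, or None.

    Assumes a checkpoint is only written when the evaluation result improves,
    so the newest checkpoint is also the best one.
    """
    epochs = []
    for fn in filenames:
        m = PARAM_RE.match(fn)
        if m and m.group("name") == model_name:
            epochs.append(int(m.group("epoch")))
    return max(epochs) if epochs else None


def plan(args, param_files):
    """Describe what a run with these arguments would do."""
    model_name = args.file.rsplit("/", 1)[-1].rsplit(".", 1)[0]  # e.g. "st-metanet"
    start = best_saved_epoch(param_files, model_name)
    where = "from scratch" if start is None else "from saved epoch {}".format(start)
    if args.epochs == 0:
        return "evaluate {} {} on the evaluation and test sets".format(model_name, where)
    return "train {} {} for {} epochs on GPUs {}".format(
        model_name, where, args.epochs, gpu_ids(args.gpus))


# Usage: mirrors "python train.py --file model_setting/st-metanet.yaml --gpus 0 --epochs 0"
args = parse_args(["--file", "model_setting/st-metanet.yaml", "--gpus", "0", "--epochs", "0"])
print(plan(args, ["st-metanet-0131.params"]))
```

The sketch uses str.format rather than f-strings so that it also runs on the Python 3.5 listed in the system requirements.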

Citation

If you find this repository, e.g., the code and the datasets, useful in your research, please cite the following paper:

  • Zheyi Pan, Yuxuan Liang, Weifeng Wang, Yong Yu, Yu Zheng, and Junbo Zhang. Urban Traffic Prediction from Spatio-Temporal Data Using Deep Meta Learning. 2019. In The 25th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD'19), August 4–8, 2019, Anchorage, AK, USA.

License

ST-MetaNet is released under the MIT License (refer to the LICENSE file for details).

st-metanet's People

Contributors

panzheyi

st-metanet's Issues

Run errors with "DeferredInitializationError"

"mxnet.gluon.parameter.DeferredInitializationError: Parameter 'seq2seq_encoder_g0_in_mlp0_dense0_weight' has not been initialized yet because initialization was deferred. Actual initialization happens during the first forward pass. Please pass one batch of data through the network before accessing Parameters. You can also avoid deferred initialization by specifying in_units, num_features, etc., for network layers."

When running the code, I ran into the above initialization problem. How can I debug it?

AttributeError: 'NDArray' object has no attribute 'device'

Hi, thank you for sharing your code!
When I run the cmd line:
python train.py --file model_setting/st-metanet.yaml --gpus 0 --epochs 200
The traceback is as follows:

/home/ypy/rec_sys/env_sb/lib/python3.6/site-packages/dgl/heterograph.py:72: DGLWarning: Recommend creating graphs by dgl.graph(data) instead of dgl.DGLGraph(data).
dgl_warning('Recommend creating graphs by dgl.graph(data)'
Traceback (most recent call last):
File "train.py", line 172, in <module>
main(args)
File "train.py", line 150, in main
metrics = [MAE(scaler), RMSE(scaler), IndexMAE(scaler, [0,1,2]), IndexRMSE(scaler, [0,1,2])],
File "train.py", line 72, in fit
self.process_data(epoch, train, metrics)
File "train.py", line 53, in process_data
outputs = [self.net(*x, is_training) for x in zip(*inputs)]
File "train.py", line 53, in <listcomp>
outputs = [self.net(*x, is_training) for x in zip(*inputs)]
File "/home/ypy/rec_sys/env_sb/lib/python3.6/site-packages/mxnet/gluon/block.py", line 540, in __call__
out = self.forward(*args)
File "/home/ypy/rec_sys/ST-MetaNet/traffic-prediction/src/model/seq2seq.py", line 233, in forward
states = self.encoder(feature, data)
File "/home/ypy/rec_sys/env_sb/lib/python3.6/site-packages/mxnet/gluon/block.py", line 540, in __call__
out = self.forward(*args)
File "/home/ypy/rec_sys/ST-MetaNet/traffic-prediction/src/model/seq2seq.py", line 52, in forward
_data = _data + g(data, feature)
File "/home/ypy/rec_sys/env_sb/lib/python3.6/site-packages/mxnet/gluon/block.py", line 540, in __call__
out = self.forward(*args)
File "/home/ypy/rec_sys/ST-MetaNet/traffic-prediction/src/model/graph.py", line 70, in forward
g = self.get_graph_on_ctx(state.context)
File "/home/ypy/rec_sys/ST-MetaNet/traffic-prediction/src/model/graph.py", line 66, in get_graph_on_ctx
self.build_graph_on_ctx(ctx)
File "/home/ypy/rec_sys/ST-MetaNet/traffic-prediction/src/model/graph.py", line 60, in build_graph_on_ctx
g.edata['dist'] = self.dist.as_in_context(ctx)
File "/home/ypy/rec_sys/env_sb/lib/python3.6/site-packages/dgl/view.py", line 227, in __setitem__
self._graph._set_e_repr(self._etid, self._edges, {key: val})
File "/home/ypy/rec_sys/env_sb/lib/python3.6/site-packages/dgl/heterograph.py", line 4135, in _set_e_repr
if F.context(val) != self.device:
File "/home/ypy/rec_sys/env_sb/lib/python3.6/site-packages/dgl/backend/pytorch/tensor.py", line 99, in context
return input.device
AttributeError: 'NDArray' object has no attribute 'device'

Could you please help me fix this~
Waiting for your reply. Thanks a lot~

Question regarding Seq2Seq input data

Hello, I have a question about the input of the Seq2Seq model. In the paper you mentioned that "The features of nodes are added as the additional inputs." From my understanding, the last dimension of the input should have a size of 9 (node features) + 8 (neighbors) * 2 (x, y coordinates). I searched through MyGRUCell and could not find such a shape; furthermore, the features do not seem to be used in the cell.

Thank you in advance and looking forward to hearing from you.

OSError: libcudart.so.9.0: cannot open shared object file: No such file or directory

Hello, thank you very much for sharing the code.
When running the example command
python train.py --file model_setting/st-metanet.yaml --gpus 0 --epochs 200
I get the following error:
Traceback (most recent call last):
File "train.py", line 7, in <module>
import mxnet as mx
File "/home/cll/anaconda3/envs/MetaNet/lib/python3.5/site-packages/mxnet/__init__.py", line 24, in <module>
from .context import Context, current_context, cpu, gpu, cpu_pinned
File "/home/cll/anaconda3/envs/MetaNet/lib/python3.5/site-packages/mxnet/context.py", line 24, in <module>
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
File "/home/cll/anaconda3/envs/MetaNet/lib/python3.5/site-packages/mxnet/base.py", line 213, in <module>
_LIB = _load_lib()
File "/home/cll/anaconda3/envs/MetaNet/lib/python3.5/site-packages/mxnet/base.py", line 204, in _load_lib
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_LOCAL)
File "/home/cll/anaconda3/envs/MetaNet/lib/python3.5/ctypes/__init__.py", line 347, in __init__
self._handle = _dlopen(self._name, mode)
OSError: libcudart.so.9.0: cannot open shared object file: No such file or directory
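For the environment, apart from mxnet_cu90, all other libraries strictly follow requirements.txt. For mxnet_cu90, I downloaded the two nearest whl files from https://dist.mxnet.io/python/cu90: mxnet_cu90-1.5.0b20190106-py2.py3-none-manylinux1_x86_64.whl and mxnet_cu90-1.5.0b20190113-py2.py3-none-manylinux1_x86_64.whl. I tested each of them, and both produce the error above. How should I debug this?
Thanks!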

Modeling and Training of EMK/NMK Learner

I have read the source code and the idea of EMK/NMK learner is interesting.

But I can't find how they are trained in the source. It seems that the training process of the two learners is not provided, and the training results have been saved in the "FEATURE.h5" file.

Could you provide the training strategy for the embedding vectors with edge/node meta knowledge, if convenient?

About BJ_FEATURE.h5

Hello, I have a question about the original data. In the grid feature data 'BJ_FEATURE.h5', which feature does each of the 989 dimensions correspond to? Also, 989 is not equal to 940 (BJ_POI) + 47 (BJ_ROAD); why?

Or, if possible, could you share the 'BJ_FEATURE.h5' file before it was embedded?

Thank you very much and look forward to your reply!

Questions regarding the setup

Hello, thank you very much for sharing the code. I'm receiving an error in metric.py, line 53: TypeError: unsupported operand type(s) for +: 'NoneType' and 'float'.

Regarding update function in MetaGat

Hi,

Thank you for sharing the code.

In the article, you said that the new state h(i) of MetaGAT is calculated by a linear combination of the previous hidden state and the new state calculated with the attention mechanism. In other words:

[Screenshot: update equation from the paper]

However, in the code, I can only find the right side of the equation, in the function msg_reduce. Is this enough, or should a msg_update function be added to apply the formula above?

Thank you in advance,
Best regards.
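
The update described in this question (a linear combination of the previous hidden state and the state produced by attention) can be sketched with NumPy as follows. The learned scalar gate passed through a sigmoid is an assumption made for illustration; the exact weighting is defined by the paper's equation, which is not reproduced here, and the function names are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def combine_states(h_prev, h_att, gate_logit):
    """Blend the previous hidden state with the attention output.

    gate_logit is a hypothetical learned scalar; the sigmoid keeps the blend
    weight in (0, 1), so the result is a convex combination of the two states.
    """
    w = sigmoid(gate_logit)
    return w * h_prev + (1.0 - w) * h_att


h_prev = np.ones((4, 8))    # 4 nodes, hidden size 8 (toy sizes)
h_att = np.zeros((4, 8))    # e.g. the aggregated output of a reduce step
h_new = combine_states(h_prev, h_att, gate_logit=0.0)
print(h_new[0, 0])          # 0.5
```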

Running error with sudden cancellation

I am deploying your model on Colab with several adaptations; the notebook is here: https://colab.research.google.com/drive/1BZ9PFWz-61KiKB91ET93iJkm35lklCr-?usp=sharing

The model ran smoothly with MXNet 1.4.0 and 1.5.1, but ended with a sudden cancellation (possibly due to a RAM surge).
Successfully loading the model st-metanet [epoch: 131] seq2seq_ ( Parameter seq2seq_encoder_c0_gru0_i2h_weight (shape=(192, 3), dtype=<class 'numpy.float32'>) Parameter seq2seq_encoder_c0_gru0_h2h_weight (shape=(192, 64), dtype=<class 'numpy.float32'>) Parameter seq2seq_encoder_c0_gru0_i2h_bias (shape=(192,), dtype=<class 'numpy.float32'>) Parameter seq2seq_encoder_c0_gru0_h2h_bias (shape=(192,), dtype=<class 'numpy.float32'>) Parameter seq2seq_encoder_c1_dense_z_w_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_w_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_w_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_w_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_w_dense2_weight (shape=(8192, 2), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_w_dense2_bias (shape=(8192,), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_b_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_b_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_b_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_b_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_b_dense2_weight (shape=(1, 2), dtype=float32) Parameter seq2seq_encoder_c1_dense_z_b_dense2_bias (shape=(1,), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_w_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_w_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_w_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_w_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_w_dense2_weight (shape=(8192, 2), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_w_dense2_bias (shape=(8192,), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_b_dense0_weight (shape=(16, 32), dtype=float32) 
Parameter seq2seq_encoder_c1_dense_r_b_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_b_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_b_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_b_dense2_weight (shape=(1, 2), dtype=float32) Parameter seq2seq_encoder_c1_dense_r_b_dense2_bias (shape=(1,), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_w_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_w_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_w_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_w_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_w_dense2_weight (shape=(4096, 2), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_w_dense2_bias (shape=(4096,), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_b_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_b_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_b_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_b_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_b_dense2_weight (shape=(1, 2), dtype=float32) Parameter seq2seq_encoder_c1_dense_i2h_b_dense2_bias (shape=(1,), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_w_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_w_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_w_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_w_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_w_dense2_weight (shape=(4096, 2), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_w_dense2_bias (shape=(4096,), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_b_dense0_weight (shape=(16, 32), 
dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_b_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_b_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_b_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_b_dense2_weight (shape=(1, 2), dtype=float32) Parameter seq2seq_encoder_c1_dense_h2h_b_dense2_bias (shape=(1,), dtype=float32) Parameter seq2seq_encoder_g0_graph_weight (shape=(1, 1), dtype=<class 'numpy.float32'>) Parameter seq2seq_encoder_g0_graph_mlp0_dense0_weight (shape=(16, 96), dtype=float32) Parameter seq2seq_encoder_g0_graph_mlp0_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_encoder_g0_graph_mlp0_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_encoder_g0_graph_mlp0_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_encoder_g0_graph_mlp0_dense2_weight (shape=(8192, 2), dtype=float32) Parameter seq2seq_encoder_g0_graph_mlp0_dense2_bias (shape=(8192,), dtype=float32) Parameter seq2seq_decoder_c0_gru0_i2h_weight (shape=(192, 3), dtype=<class 'numpy.float32'>) Parameter seq2seq_decoder_c0_gru0_h2h_weight (shape=(192, 64), dtype=<class 'numpy.float32'>) Parameter seq2seq_decoder_c0_gru0_i2h_bias (shape=(192,), dtype=<class 'numpy.float32'>) Parameter seq2seq_decoder_c0_gru0_h2h_bias (shape=(192,), dtype=<class 'numpy.float32'>) Parameter seq2seq_decoder_c1_dense_z_w_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_w_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_w_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_w_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_w_dense2_weight (shape=(8192, 2), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_w_dense2_bias (shape=(8192,), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_b_dense0_weight (shape=(16, 32), dtype=float32) Parameter 
seq2seq_decoder_c1_dense_z_b_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_b_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_b_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_b_dense2_weight (shape=(1, 2), dtype=float32) Parameter seq2seq_decoder_c1_dense_z_b_dense2_bias (shape=(1,), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_w_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_w_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_w_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_w_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_w_dense2_weight (shape=(8192, 2), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_w_dense2_bias (shape=(8192,), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_b_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_b_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_b_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_b_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_b_dense2_weight (shape=(1, 2), dtype=float32) Parameter seq2seq_decoder_c1_dense_r_b_dense2_bias (shape=(1,), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_w_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_w_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_w_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_w_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_w_dense2_weight (shape=(4096, 2), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_w_dense2_bias (shape=(4096,), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_b_dense0_weight (shape=(16, 32), dtype=float32) Parameter 
seq2seq_decoder_c1_dense_i2h_b_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_b_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_b_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_b_dense2_weight (shape=(1, 2), dtype=float32) Parameter seq2seq_decoder_c1_dense_i2h_b_dense2_bias (shape=(1,), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_w_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_w_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_w_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_w_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_w_dense2_weight (shape=(4096, 2), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_w_dense2_bias (shape=(4096,), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_b_dense0_weight (shape=(16, 32), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_b_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_b_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_b_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_b_dense2_weight (shape=(1, 2), dtype=float32) Parameter seq2seq_decoder_c1_dense_h2h_b_dense2_bias (shape=(1,), dtype=float32) Parameter seq2seq_decoder_g0_graph_weight (shape=(1, 1), dtype=<class 'numpy.float32'>) Parameter seq2seq_decoder_g0_graph_mlp0_dense0_weight (shape=(16, 96), dtype=float32) Parameter seq2seq_decoder_g0_graph_mlp0_dense0_bias (shape=(16,), dtype=float32) Parameter seq2seq_decoder_g0_graph_mlp0_dense1_weight (shape=(2, 16), dtype=float32) Parameter seq2seq_decoder_g0_graph_mlp0_dense1_bias (shape=(2,), dtype=float32) Parameter seq2seq_decoder_g0_graph_mlp0_dense2_weight (shape=(8192, 2), dtype=float32) Parameter seq2seq_decoder_g0_graph_mlp0_dense2_bias (shape=(8192,), 
dtype=float32) Parameter decoder0_proj_weight (shape=(2, 96), dtype=float32) Parameter decoder0_proj_bias (shape=(2,), dtype=float32) Parameter geo_encoder_dense0_weight (shape=(32, 989), dtype=float32) Parameter geo_encoder_dense0_bias (shape=(32,), dtype=float32) Parameter geo_encoder_dense1_weight (shape=(32, 32), dtype=float32) Parameter geo_encoder_dense1_bias (shape=(32,), dtype=float32) ) NUMBER OF PARAMS: 268224 INFO:root:Processing 1000 timestamps INFO:root:Processing 2000 timestamps tcmalloc: large alloc 9179439104 bytes == 0x55fd96a46000 @ 0x7f54362f31e7 0x7f5433000cf1 0x7f5433065768 0x7f5433065883 0x7f5433105b5e 0x7f54331063c4 0x7f5433106512 0x55fcee9530a4 0x55fcee952da0 0x55fcee9c7868 0x55fcee9c2235 0x55fcee95473a 0x55fcee9c6f40 0x55fcee9c1c35 0x55fcee95473a 0x55fcee9c2b0e 0x55fcee95465a 0x55fcee9c2b0e 0x55fcee95465a 0x55fcee9c2b0e 0x55fcee9c1c35 0x55fcee9c1933 0x55fceea8b402 0x55fceea8b77d 0x55fceea8b626 0x55fceea63313 0x55fceea62fbc 0x7f54350ddbf7 0x55fceea62e9a tcmalloc: large alloc 9179439104 bytes == 0x55ffba478000 @ 0x7f54362d5b6b 0x7f54362f5379 0x7f54064bde75 0x7f54064bdf0d 0x7f54064c622e 0x7f5405cdaa26 0x7f5405cdb298 0x7f5433825dae 0x7f543382571f 0x7f5433a395ac 0x7f5433a389e3 0x55fcee9537b2 0x55fcee9c76f2 0x55fcee9c1c35 0x55fcee95473a 0x55fcee9c2b0e 0x55fcee9c1c35 0x55fcee95473a 0x55fcee9c2b0e 0x55fcee9c1c35 0x55fcee95473a 0x55fcee9c393b 0x55fcee9c1c35 0x55fcee95473a 0x55fcee9c6f40 0x55fcee9c1c35 0x55fcee95473a 0x55fcee9c2b0e 0x55fcee95465a 0x55fcee9c2b0e 0x55fcee95465a ^C

date range

Hi, in your paper, the taxi flow dataset ranges from Feb. 1st, 2015 to Jun. 2nd, 2015. The number of days in this period is 122, which gives fewer than 3600 time slots.
But the shape of BJ_FLOW.h5 is (150, 24, 32, 32, 2), which covers 150 days. Why?
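
The arithmetic in this question can be checked with the standard library, assuming hourly time slots and counting both endpoint days:

```python
from datetime import date

start, end = date(2015, 2, 1), date(2015, 6, 2)
days = (end - start).days + 1   # inclusive of both Feb 1 and Jun 2
print(days)                     # 122
print(days * 24)                # 2928 hourly time slots
print(150 * 24)                 # 3600 time slots implied by BJ_FLOW.h5
```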

Running error

Traceback (most recent call last):
File "train.py", line 172, in <module>
main(args)
File "train.py", line 143, in main
train, eval, test, scaler = getattr(data.dataloader, dataset_setting['dataloader'])(settings)
File "/home/xinyilian/ST-MetaNet/flow-prediction/src/data/dataloader.py", line 109, in dataloader_all_sensors_seq2seq
return dataiter_all_sensors_seq2seq(train, scaler, setting),
File "/home/xinyilian/ST-MetaNet/flow-prediction/src/data/dataloader.py", line 90, in dataiter_all_sensors_seq2seq
feature = mx.nd.array(np.stack(feature)) # [B, N, D]
File "/home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/ndarray/utils.py", line 146, in array
return _array(source_array, ctx=ctx, dtype=dtype)
File "/home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 2505, in array
arr[:] = source_array
File "/home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 449, in __setitem__
self._set_nd_basic_indexing(key, value)
File "/home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 715, in _set_nd_basic_indexing
self._sync_copyfrom(value)
File "/home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 881, in _sync_copyfrom
ctypes.c_size_t(source_array.size)))
File "/home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/base.py", line 253, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [19:21:13] src/ndarray/ndarray_function.cc:51: Check failed: size == to->Size() (-2000107520 vs. 2294859776) : copying size mismatch, from: 18446744065709121536 bytes, to: 9179439104 bytes.
Stack trace:
[bt] (0) /home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x4b04cb) [0x7feb9aa234cb]
[bt] (1) /home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x281c85b) [0x7feb9cd8f85b]
[bt] (2) /home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/libmxnet.so(mxnet::NDArray::SyncCopyFromCPU(void const*, unsigned long) const+0x27c) [0x7feb9cd1b59c]
[bt] (3) /home/xinyilian/.conda/envs/xyl/lib/python3.6/site-packages/mxnet/libmxnet.so(MXNDArraySyncCopyFromCPU+0x2b) [0x7feb9ca9790b]
[bt] (4) /home/xinyilian/.conda/envs/xyl/lib/python3.6/lib-dynload/../../libffi.so.6(ffi_call_unix64+0x4c) [0x7fec1d551ec0]
[bt] (5) /home/xinyilian/.conda/envs/xyl/lib/python3.6/lib-dynload/../../libffi.so.6(ffi_call+0x22d) [0x7fec1d55187d]
[bt] (6) /home/xinyilian/.conda/envs/xyl/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x2ce) [0x7fec1d767ede]
[bt] (7) /home/xinyilian/.conda/envs/xyl/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(+0x13915) [0x7fec1d768915]
[bt] (8) python(_PyObject_FastCallDict+0x8b) [0x560bdc88ae3b]
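
The numbers in this message are consistent with the element count of the stacked feature array (2,294,859,776) exceeding the signed 32-bit limit of 2,147,483,647 and wrapping around. The check below reproduces the figures; it is an interpretation of the message, not a confirmed trace through MXNet's internals, and the helper names are hypothetical. The same 9,179,439,104-byte size also appears in the tcmalloc messages of the Colab report above.

```python
def as_int32(n):
    """Reinterpret the low 32 bits of n as a signed 32-bit integer."""
    n &= 0xFFFFFFFF
    return n - (1 << 32) if n >= (1 << 31) else n


def as_uint64(n):
    """Reinterpret n as an unsigned 64-bit integer (two's complement)."""
    return n & 0xFFFFFFFFFFFFFFFF


elements = 2294859776                  # "to->Size()" in the error message
assert elements == 2266 * 1024 * 989   # [B, N, D] with N = 1024 grids, D = 989 features
wrapped = as_int32(elements)
print(wrapped)                         # -2000107520, the "size" in the message
print(as_uint64(wrapped * 4))          # 18446744065709121536 bytes (float32 = 4 bytes)
print(elements * 4)                    # 9179439104 bytes
```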

Running error

Hi,

When I run this code, I get the following error:


File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/train.py", line 179, in <module>
main(args)
File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/train.py", line 157, in main
metrics = [MAE(scaler), RMSE(scaler), IndexMAE(scaler, [0,1,2]), IndexRMSE(scaler, [0,1,2])],
File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/train.py", line 79, in fit
self.process_data(epoch, train, metrics)
File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/train.py", line 60, in process_data
outputs = [self.net(*x, is_training) for x in zip(*inputs)]
File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/train.py", line 60, in <listcomp>
outputs = [self.net(*x, is_training) for x in zip(*inputs)]
File "/opt/mxnet/python/mxnet/gluon/block.py", line 360, in __call__
return self.forward(*args)
File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/model/seq2seq.py", line 233, in forward
states = self.encoder(feature, data)
File "/opt/mxnet/python/mxnet/gluon/block.py", line 360, in __call__
return self.forward(*args)
File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/model/seq2seq.py", line 52, in forward
_data = _data + g(data, feature)
File "/opt/mxnet/python/mxnet/gluon/block.py", line 360, in __call__
return self.forward(*args)
File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/model/graph.py", line 71, in forward
g = self.get_graph_on_ctx(state.context)
File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/model/graph.py", line 67, in get_graph_on_ctx
self.build_graph_on_ctx(ctx)
File "/user/cs.aau.dk/seany/SourceCode/traffic-prediction/src/model/graph.py", line 57, in build_graph_on_ctx
g = DGLGraph()
File "/usr/local/lib/python3.5/dist-packages/dgl/graph.py", line 907, in __init__
self._msg_index = utils.zero_index(size=self.number_of_edges())
File "/usr/local/lib/python3.5/dist-packages/dgl/utils.py", line 249, in zero_index
return Index(F.zeros((size,), dtype=F.int64, ctx=F.cpu()))
File "/usr/local/lib/python3.5/dist-packages/dgl/backend/mxnet/tensor.py", line 151, in zeros
return nd.zeros(shape, dtype=dtype, ctx=ctx)
File "/opt/mxnet/python/mxnet/ndarray/utils.py", line 67, in zeros
return _zeros_ndarray(shape, ctx, dtype, **kwargs)
File "/opt/mxnet/python/mxnet/ndarray/ndarray.py", line 3387, in zeros
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
File "<string>", line 34, in _zeros
File "/opt/mxnet/python/mxnet/_ctypes/ndarray.py", line 92, in _imperative_invoke
ctypes.byref(out_stypes)))
File "/opt/mxnet/python/mxnet/base.py", line 148, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: Invalid Input: 'int64', valid values are: {'float16', 'float32', 'float64', 'int32', 'uint8'}, in operator _zeros(name="", dtype="int64", ctx="cpu(0)", shape="(0,)")
srun: error: nv-ai-03.srv.aau.dk: task 0: Exited with exit code 1

So could you please give me some suggestions? Thanks in advance.
