
CorrSTN

(Figure: CorrSTN.png)

Requirements

  • Python 3.6
  • numpy == 1.19.5
  • minepy == 1.2.5
  • setuptools == 57.5.0
  • scikit-learn
  • PyTorch == 1.7.0
  • tensorboardX

Create the CorrSTN environment

conda create -n CorrSTN python=3.6
conda activate CorrSTN

Install PyTorch (built against CUDA 11.0)

conda install pytorch=1.7.0 torchvision torchaudio cudatoolkit=11.0 -c pytorch

Install the package for the maximal information coefficient (MIC)

pip install setuptools==57.5.0
pip install minepy==1.2.5

Install other packages

conda install scikit-learn
pip install tensorboardX
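To confirm that the environment matches the Requirements list, a small check script can help. The sketch below is not part of the repository. It assumes the import names torch (for PyTorch) and sklearn (for scikit-learn), and that each package exposes a __version__ attribute, which is common but not guaranteed; a package without one is reported as a mismatch so it can be checked by hand.

```python
import importlib

# Pinned versions from the Requirements list; None means "any version".
# Keys are import names, which differ from the pip/conda names for some
# packages (pytorch -> torch, scikit-learn -> sklearn).
PINNED = {
    "numpy": "1.19.5",
    "minepy": "1.2.5",
    "setuptools": "57.5.0",
    "sklearn": None,
    "torch": "1.7.0",
    "tensorboardX": None,
}


def check_requirements(pinned):
    """Return {import_name: (installed_version, expected_version, ok)}."""
    report = {}
    for name, expected in pinned.items():
        try:
            module = importlib.import_module(name)
        except Exception:  # ImportError, or a broken binary build
            report[name] = (None, expected, False)
            continue
        # Assumption: the package exposes __version__; None if it does not.
        installed = getattr(module, "__version__", None)
        if installed is not None:
            # Drop local build tags such as "+cu110" before comparing.
            installed = str(installed).split("+")[0]
        ok = expected is None or installed == expected
        report[name] = (installed, expected, ok)
    return report


if __name__ == "__main__":
    for name, (installed, expected, ok) in check_requirements(PINNED).items():
        status = "OK" if ok else "MISMATCH"
        print("{:<13} installed={} expected={} {}".format(
            name, installed, expected or "any", status))
    try:
        import torch
        # True only if PyTorch can see a usable CUDA device.
        print("CUDA available:", torch.cuda.is_available())
    except ImportError:
        pass
```

Run it inside the activated CorrSTN environment; the file name (for example check_env.py) is up to you.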

Train and Test

Step 1: Process the datasets. In the root folder, prepare the training, validation and test data:

python prepareData.py --config configurations/HZME_OUTFLOW_rdw.conf

Then, in the lib folder, compute SCorr. When the computation finishes, move the result into the corresponding dataset folder:

python corr.py
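The internals of corr.py are not shown in this README. As an illustration only, assuming SCorr is an N x N matrix of maximal information coefficient (MIC) scores between the traffic series of every pair of nodes, a pairwise computation could look like the sketch below. The function names are hypothetical; mic_score uses the minepy MINE class with its default parameters alpha=0.6 and c=15.

```python
def pairwise_score_matrix(series, score):
    """Build a symmetric N x N matrix with entry [i][j] = score(series[i], series[j]).

    series: a list of N one-dimensional sequences (one per node, by assumption).
    score:  any function of two sequences returning a float.
    """
    n = len(series)
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i, n):  # compute each unordered pair once
            value = score(series[i], series[j])
            matrix[i][j] = value
            matrix[j][i] = value
    return matrix


def mic_score(x, y, alpha=0.6, c=15):
    """Maximal information coefficient of two 1-D sequences, via minepy."""
    # Imported lazily so pairwise_score_matrix works without these packages.
    import numpy as np
    from minepy import MINE

    mine = MINE(alpha=alpha, c=c)
    mine.compute_score(np.asarray(x, dtype=float), np.asarray(y, dtype=float))
    return mine.mic()
```

Assuming the flow data is shaped (T, N, F), the input would be [data[:, i, 0] for i in range(N)] with score=mic_score. The number of MIC evaluations grows quadratically with the number of nodes, so this step can take a long time on large graphs.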

Step 2: Train or test the model:

python train_CorrSTN.py --config configurations/HZME_OUTFLOW_rdw.conf

Prepare each dataset:

python prepareData.py --config configurations/PEMS07.conf
python prepareData.py --config configurations/PEMS07_rdw.conf

python prepareData.py --config configurations/PEMS08.conf
python prepareData.py --config configurations/PEMS08_rdw.conf

python prepareData.py --config configurations/HZME_INFLOW.conf
python prepareData.py --config configurations/HZME_INFLOW_rdw.conf

python prepareData.py --config configurations/HZME_OUTFLOW.conf
python prepareData.py --config configurations/HZME_OUTFLOW_rdw.conf

Train on each dataset:

python train_CorrSTN.py --config configurations/PEMS07.conf
python train_CorrSTN.py --config configurations/PEMS07_rdw.conf

python train_CorrSTN.py --config configurations/PEMS08.conf
python train_CorrSTN.py --config configurations/PEMS08_rdw.conf

python train_CorrSTN.py --config configurations/HZME_INFLOW.conf
python train_CorrSTN.py --config configurations/HZME_INFLOW_rdw.conf

python train_CorrSTN.py --config configurations/HZME_OUTFLOW.conf
python train_CorrSTN.py --config configurations/HZME_OUTFLOW_rdw.conf
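Running the eight prepareData.py commands and the eight training commands by hand is repetitive. The hypothetical driver below builds the same commands from the dataset and variant names listed above. By default it only prints them; with --run it executes them in order and should be started from the repository root. It does not cover the lib/corr.py step, which still has to be run (and its output moved) before training.

```python
import subprocess
import sys

# Dataset and variant names taken from the command lists above.
DATASETS = ["PEMS07", "PEMS08", "HZME_INFLOW", "HZME_OUTFLOW"]
VARIANTS = ["", "_rdw"]


def build_commands(datasets=DATASETS, variants=VARIANTS, python=sys.executable):
    """Return argv lists: every prepareData.py run first, then every training run."""
    configs = ["configurations/{}{}.conf".format(d, v)
               for d in datasets for v in variants]
    prepare = [[python, "prepareData.py", "--config", c] for c in configs]
    train = [[python, "train_CorrSTN.py", "--config", c] for c in configs]
    return prepare + train


def run_all(dry_run=True):
    for argv in build_commands():
        print(" ".join(argv))
        if not dry_run:
            # check=True raises CalledProcessError at the first failing step.
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_all(dry_run="--run" not in sys.argv)
```

Preparing all datasets before any training keeps the data-processing failures separate from the training failures.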

Model

The trained model will be stored in the experiments/$DataSetName$ folder, under a name such as MAE_CorrSTN_h1d1w0_layer4_head8_dm64_channel1_dir2_drop0.00_1.00e-03_B16_K5_TcontextScaledSAtSE1TE.

We also provide our trained models in the experiments folder, under names such as MAE_CorrSTN_h1d1w0_layer4_head8_dm64_channel1_dir2_drop0.00_1.00e-03_B16_K5_TcontextScaledSAtSE1TE-14.70-25.60-48.39, where the last three numbers are the MAE, RMSE and MAPE metrics. The training log for each model is included as well.
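Because the metric suffix is appended with hyphens, it can be read back, or stripped as step 1 of HOW TO TEST requires, by splitting the folder name from the right. The helpers below are a hypothetical sketch, not part of the repository; splitting from the right keeps the hyphen inside 1.00e-03 intact.

```python
import os


def split_metrics(folder_name):
    """Split 'NAME-MAE-RMSE-MAPE' into (NAME, (mae, rmse, mape)).

    Returns (folder_name, None) when there is no numeric three-part suffix.
    rsplit works from the right, so hyphens inside NAME stay intact.
    """
    parts = folder_name.rsplit("-", 3)
    if len(parts) != 4:
        return folder_name, None
    try:
        metrics = tuple(float(p) for p in parts[1:])
    except ValueError:
        return folder_name, None
    return parts[0], metrics


def strip_metrics_suffix(path):
    """Rename a trained-model folder to its name without the metric suffix."""
    parent, name = os.path.split(os.path.normpath(path))
    base, metrics = split_metrics(name)
    if metrics is None:
        return path  # no suffix found, nothing to rename
    new_path = os.path.join(parent, base)
    os.rename(path, new_path)
    return new_path
```

For example, strip_metrics_suffix("experiments/<dataset>/<folder>") renames the folder in place and returns the new path; calling it again on the stripped name leaves it unchanged.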

HOW TO TEST:

  1. Remove the last three numbers from the name of the trained model folder, e.g. rename MAE_CorrSTN_h1d1w0_layer4_head8_dm64_channel1_dir2_drop0.00_1.00e-03_B16_K5_TcontextScaledSAtSE1TE-14.70-25.60-48.39 to MAE_CorrSTN_h1d1w0_layer4_head8_dm64_channel1_dir2_drop0.00_1.00e-03_B16_K5_TcontextScaledSAtSE1TE

  2. In train_CorrSTN.py, comment out train_main() and uncomment the last line, so that the end of the file reads:

# train_main()
predict_main(0, test_loader, test_target_tensor, _max, _min, 'test')

Then change 0 to the number of the epoch you want to test.
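Editing the file each time works, but the same switch could also be made from the command line. The sketch below is hypothetical: --mode and --epoch are not flags of train_CorrSTN.py, and it assumes the script's argument parsing (which already accepts --config) can be extended this way. dispatch receives train_main and predict_main as callables so that it can be tested without the model.

```python
import argparse


def parse_mode(argv=None):
    """Parse --config plus two hypothetical flags that select train or test."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--mode", choices=["train", "test"], default="train")
    parser.add_argument("--epoch", type=int, default=0,
                        help="epoch to load when --mode test")
    return parser.parse_args(argv)


def dispatch(args, train_main, predict_main, *predict_args):
    """Call train_main(), or predict_main(epoch, *predict_args) in test mode."""
    if args.mode == "train":
        train_main()
    else:
        predict_main(args.epoch, *predict_args)
```

In train_CorrSTN.py the final call would then become dispatch(args, train_main, predict_main, test_loader, test_target_tensor, _max, _min, 'test'), and testing epoch 45 would be python train_CorrSTN.py --config ... --mode test --epoch 45.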

Results

(Figure: results)

Some points to note

  1. The batch size differs between the training phase and the validation and test phases.

In lib/utils.py:

line 537: val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size * 32)
line 550: test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size * 32)

So the batch size in the validation and test phases is 32 times the training batch size, which speeds up evaluation.
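A worked example of the effect, assuming the B16 field in the folder names above means a training batch size of 16: the validation and test loaders then use batches of 16 x 32 = 512 samples. Since torch.utils.data.DataLoader keeps the last partial batch by default (drop_last=False), the number of evaluation steps is the sample count divided by the batch size, rounded up. The helpers are illustrative only.

```python
EVAL_MULTIPLIER = 32  # from lib/utils.py: batch_size=batch_size * 32


def eval_batch_size(train_batch_size, multiplier=EVAL_MULTIPLIER):
    """Batch size used by the validation and test loaders."""
    return train_batch_size * multiplier


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for num_samples items."""
    if drop_last:
        return num_samples // batch_size
    # Ceiling division: the final partial batch is kept.
    return (num_samples + batch_size - 1) // batch_size
```

With 1000 test samples, training-sized batches of 16 would need num_batches(1000, 16) = 63 steps, while batches of eval_batch_size(16) = 512 need only 2. A larger evaluation batch needs proportionally more GPU memory; if evaluation runs out of memory, lowering the multiplier in lib/utils.py is one option.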

Contributors

drownfish19
