wenjiedu / saits

The official PyTorch implementation of the paper "SAITS: Self-Attention-based Imputation for Time Series". A fast and state-of-the-art (SOTA) deep-learning neural network model for efficient time-series imputation (impute multivariate incomplete time series containing NaN missing data/values with machine learning). https://arxiv.org/abs/2202.08516

Home Page: https://doi.org/10.1016/j.eswa.2023.119619

License: MIT License

Python 96.96% Shell 3.04%
time-series imputation-model missing-values self-attention partially-observed-data partially-observed-time-series partially-observed interpolation time-series-imputation incomplete-data

saits's Introduction

powered by PyTorch

Tip

[Updates in Jun 2024] 😎 The first comprehensive time-series imputation benchmark paper, TSI-Bench: Benchmarking Time Series Imputation, is now publicly available. The code is open source in the repo Awesome_Imputation. With nearly 35,000 experiments, we provide a comprehensive benchmarking study on 28 imputation methods, 3 missing patterns (points, sequences, blocks), various missing rates, and 8 real-world datasets.

[Updates in May 2024] 🔥 We applied SAITS embedding and training strategies to iTransformer, FiLM, FreTS, Crossformer, PatchTST, DLinear, ETSformer, FEDformer, Informer, Autoformer, Non-stationary Transformer, Pyraformer, Reformer, SCINet, RevIN, Koopa, MICN, TiDE, and StemGNN in PyPOTS to make them applicable to the time-series imputation task.

[Updates in Feb 2024] 🎉 Our survey paper Deep Learning for Multivariate Time Series Imputation: A Survey has been released on arXiv. We comprehensively review the literature of the state-of-the-art deep-learning imputation methods for time series, provide a taxonomy for them, and discuss the challenges and future directions in this field.

‼️Kind reminder: This document can help you solve many common questions, please read it before you run the code.

This is the official code repository for the paper SAITS: Self-Attention-based Imputation for Time Series (preprint on arXiv is here), which has been accepted by the journal Expert Systems with Applications (ESWA) [2022 IF 8.665, CiteScore 12.2, JCR-Q1, CAS-Q1, CCF-C]. You may never have heard of ESWA, but it was ranked 1st in Google Scholar under the top publications of Artificial Intelligence in 2016 (info source), and is still the top 1 AI journal according to Google Scholar metrics (here is the current ranking list FYI).

SAITS is the first work applying pure self-attention without any recursive design in the algorithm for general time series imputation. You can treat it as a validated framework for time series imputation: for example, we've integrated 2️⃣0️⃣❗️ forecasting models into PyPOTS by adapting the SAITS framework. More generally, you can use it for sequence imputation. The code here is open source under the MIT license, so you're welcome to modify the SAITS code for your own research purposes and domain applications. Of course, it will probably need some modification of the model structure or loss functions for specific scenarios or data input. Here is an incomplete list of scientific research referencing SAITS.

🤗 Please cite SAITS in your publications if it helps with your work. Please star🌟 this repo to help others notice SAITS if you think it is useful. It really means a lot to our open-source research. Thank you! BTW, you may also like PyPOTS for easily modeling your partially-observed time-series datasets.

Important

📣 Attention please:

SAITS is now available in PyPOTS, a Python toolbox for data mining on POTS (Partially-Observed Time Series). An example of training SAITS to impute the PhysioNet-2012 dataset is shown below. With PyPOTS, easy peasy! 😉

# Data preprocessing. Tedious, but PyPOTS can help.
import numpy as np
from sklearn.preprocessing import StandardScaler
from pygrinder import mcar
from pypots.data import load_specific_dataset
data = load_specific_dataset('physionet_2012')  # PyPOTS will automatically download and extract it.
X = data['X']
num_samples = len(X['RecordID'].unique())
X = X.drop(['RecordID', 'Time'], axis = 1)
X = StandardScaler().fit_transform(X.to_numpy())
X = X.reshape(num_samples, 48, -1)
X_ori = X  # keep X_ori for validation
X = mcar(X, 0.1)  # randomly hold out 10% observed values as ground truth
dataset = {"X": X}  # X for model input
print(X.shape)  # (11988, 48, 37), 11988 samples and each sample has 48 time steps, 37 features

# Model training. This is PyPOTS showtime.
from pypots.imputation import SAITS
from pypots.utils.metrics import calc_mae
saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_ffn=128, n_heads=4, d_k=64, d_v=64, dropout=0.1, epochs=10)
# Here I use the whole dataset as the training set because ground truth is not visible to the model, you can also split it into train/val/test sets
saits.fit(dataset)  # train the model on the dataset
imputation = saits.impute(dataset)  # impute the originally-missing values and artificially-missing values
indicating_mask = np.isnan(X) ^ np.isnan(X_ori)  # indicating mask for imputation error calculation
mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask)  # calculate mean absolute error on the ground truth (artificially-missing values)
saits.save("save_it_here/saits_physionet2012.pypots")  # save the model for future use
saits.load("save_it_here/saits_physionet2012.pypots")  # reload the serialized model file for following imputation or training
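
The indicating mask in the example above can be checked on a toy array. The sketch below hides two values by hand instead of calling `mcar` (a stand-in chosen only to keep the result deterministic) and uses an all-zero "imputation" as a hypothetical model output:

```python
import numpy as np

# Toy ground truth with one originally-missing value (NaN).
X_ori = np.array([[1.0, np.nan, 3.0],
                  [4.0, 5.0, 6.0]])

# Hypothetical stand-in for mcar(): hide two observed values by hand.
X = X_ori.copy()
X[0, 0] = np.nan
X[1, 2] = np.nan

# XOR is True only where exactly one of the two arrays is NaN,
# i.e. values that were observed in X_ori but hidden in X.
indicating_mask = np.isnan(X) ^ np.isnan(X_ori)

# Hypothetical imputation: every missing entry filled with 0.0.
imputation = np.nan_to_num(X)
target = np.nan_to_num(X_ori)

# Masked MAE: only the artificially hidden entries contribute.
mae = np.sum(np.abs(imputation - target) * indicating_mask) / (np.sum(indicating_mask) + 1e-9)
```

The originally-missing entry at position (0, 1) is excluded from the error because no ground truth exists for it.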


☕️ Welcome to the universe of PyPOTS. Enjoy it and have fun!

❖ Motivation and Performance

⦿ Motivation: SAITS is developed primarily to help overcome the drawbacks (slow speed, memory constraints, and compounding error) of RNN-based imputation models and to obtain the state-of-the-art (SOTA) imputation accuracy on partially-observed time series.

⦿ Performance: SAITS outperforms BRITS by 12% ∼ 38% in MAE (mean absolute error) and achieves 2.0 ∼ 2.6 times faster training speed. Furthermore, SAITS outperforms Transformer (trained by our joint-optimization approach) by 2% ∼ 19% in MAE with a more efficient model structure (to obtain comparable performance, SAITS needs only 15% ∼ 30% of the parameters of Transformer). Compared to another SOTA self-attention imputation model, NRTSI, SAITS achieves 7% ∼ 39% smaller mean squared error (above 20% in nine out of sixteen cases) while needing far fewer parameters and less imputation time in practice. Please refer to our full paper for more details about SAITS' performance.

❖ Brief Graphical Illustration of Our Methodology

Here we only show the two main components of our method: the joint-optimization training approach and SAITS structure. For the detailed description and explanation, please read our full paper Paper_SAITS.pdf in this repo or on arXiv.

Fig. 1: Training approach

Fig. 2: SAITS structure

❖ Citing SAITS

If you find SAITS helpful to your work, please cite our paper as below, ⭐️star this repository, and recommend it to others who you think may need it. 🤗 Thank you!

@article{du2023saits,
title = {{SAITS: Self-Attention-based Imputation for Time Series}},
journal = {Expert Systems with Applications},
volume = {219},
pages = {119619},
year = {2023},
issn = {0957-4174},
doi = {10.1016/j.eswa.2023.119619},
url = {https://arxiv.org/abs/2202.08516},
author = {Wenjie Du and David Cote and Yan Liu},
}

or

Wenjie Du, David Cote, and Yan Liu. SAITS: Self-Attention-based Imputation for Time Series. Expert Systems with Applications, 219:119619, 2023.

😎 Our latest survey and benchmarking research on time-series imputation may also be related to your work:

@article{du2024tsibench,
title={TSI-Bench: Benchmarking Time Series Imputation},
author={Wenjie Du and Jun Wang and Linglong Qian and Yiyuan Yang and Fanxing Liu and Zepu Wang and Zina Ibrahim and Haoxin Liu and Zhiyuan Zhao and Yingjie Zhou and Wenjia Wang and Kaize Ding and Yuxuan Liang and B. Aditya Prakash and Qingsong Wen},
journal={arXiv preprint arXiv:2406.12747},
year={2024}
}
@article{wang2024deep,
title={Deep Learning for Multivariate Time Series Imputation: A Survey},
author={Jun Wang and Wenjie Du and Wei Cao and Keli Zhang and Wenjia Wang and Yuxuan Liang and Qingsong Wen},
journal={arXiv preprint arXiv:2402.04059},
year={2024}
}

🔥 In case you use PyPOTS for your research, please also cite the following paper:

@article{du2023pypots,
title={{PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series}},
author={Wenjie Du},
journal={arXiv preprint arXiv:2305.18811},
year={2023},
}

or

Wenjie Du. PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series. arXiv, abs/2305.18811, 2023.

❖ Repository Structure

The implementation of SAITS is in dir modeling. We give configurations of our models in dir configs, provide the dataset links and preprocessing scripts in dir dataset_generating_scripts. Dir NNI_tuning contains the hyper-parameter searching configurations.

❖ Development Environment

All dependencies of our development environment are listed in file conda_env_dependencies.yml. You can quickly create a usable python environment with an anaconda command conda env create -f conda_env_dependencies.yml.

❖ Datasets

For datasets downloading and generating, please check out the scripts in dir dataset_generating_scripts.

❖ Quick Run

Generate the dataset you need first. To do so, please check out the generating scripts in dir dataset_generating_scripts.

After data generation, train and test your model, for example,

# create a dir to save logs and results
mkdir NIPS_results

# train a model
nohup python run_models.py \
    --config_path configs/PhysioNet2012_SAITS_best.ini \
    > NIPS_results/PhysioNet2012_SAITS_best.out &

# during training, you can run the below command to read the training log
less NIPS_results/PhysioNet2012_SAITS_best.out

# after training, pick the best model and modify the path of the model for testing in the config file, then run the below command to test the model
python run_models.py \
    --config_path configs/PhysioNet2012_SAITS_best.ini \
    --test_mode

❗️Note that paths of datasets and saving dirs may be different on personal computers, please check them in the configuration files.

❖ Acknowledgments

Thanks to Ciena, Mitacs, and NSERC (Natural Sciences and Engineering Research Council of Canada) for funding support. Thanks to Ciena for providing computing resources. Thanks to all our reviewers for helping improve the quality of this paper. And thank you all for your attention to this work.

✨Stars/forks/issues/PRs are all welcome!


❖ Last but Not Least

If you have any additional questions or are interested in collaborating, please take a look at my GitHub profile and feel free to contact me 😃.

saits's Issues

Training stage of an attention based model!

Greetings Wenjie,
I was very much impressed by your work "SAITS". I am trying to create an attention-based model on my own as part of my Bachelor's project and I have a few questions to ask:
I wanted to know how many trials you ran for SAITS in the training stage. I noticed you've already answered a similar issue from the user "Rajesh90123", where you said "just let the experiment run till I thought it was good to stop". Just for reference, could you please tell me how many trials you ran during hyperparameter tuning before you decided it was good enough to stop? Was it in the range of 100s, 1,000s, 10,000s, or more? Please let me know at your earliest convenience. Thank you!

pd.concat

FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
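
The warning above is raised by `DataFrame.append`, which was removed in pandas 2.0. A minimal sketch of the replacement follows; the column names and values are made up for illustration and are not taken from the repository's scripts:

```python
import pandas as pd

df = pd.DataFrame({"RecordID": [1], "value": [0.5]})
new_rows = pd.DataFrame({"RecordID": [2], "value": [0.7]})

# Deprecated form that triggers the FutureWarning:
#   df = df.append(new_rows, ignore_index=True)
# Equivalent call with pandas.concat:
df = pd.concat([df, new_rows], ignore_index=True)
```

When rows are appended inside a loop, collecting the pieces in a list and calling `pd.concat` once at the end avoids copying the frame on every iteration.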

Few more questions on SAITS working method

I am really impressed with your work: "SAITS: SELF-ATTENTION-BASED IMPUTATION FOR TIME SERIES".
I was studying your code and paper and I have a few questions:

  1. Is there any particular reason for not using learning rate scheduler?
  2. Also, when I tried replicating the code, I found out that my program was running indefinitely during random search for hyperparameter tuning especially with the use of loguniform in learning rate. I didn't find any lines regarding maximum number of trials for hyperparameter tuning in your code or in your paper "SAITS: SELF-ATTENTION-BASED IMPUTATION FOR TIME SERIES". Could you please provide information regarding this?
  3. Do you change the value of artificial missing rate for MIT in training section based on missing rate at test? What I mean here is that, for example, if you have to test your model on test dataset with 80% missing rate, will the MIT missing rate be fixed to 20% as you have done in the code provided or do you train the entire model again by changing MIT missing rate manually to 80% in the training code?
  4. Do you perform hyperparameter tuning on different missing rate in validation data or do you perform hyperparameter tuning on a particular missing rate at validation dataset and save it to use it for test dataset with any missing rate? What I mean here is, running hyperparameter tuning on a validation dataset of 20% missing rate and saving the best model, and using the same best model to impute missing data for test dataset with 80% missing rate?
    Thank you. If my questions are not clear, please comment and I will try my best to describe them better. English is my second language, and I am not fluent in it.
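
On question 2, a random search only terminates if something bounds it. The sketch below is a plain-Python illustration, not NNI's API and not the authors' tuning setup: `sample_loguniform`, `random_search`, and `toy_objective` are hypothetical names, and the search ranges are placeholders. It shows log-uniform sampling of the learning rate together with an explicit trial budget:

```python
import math
import random


def sample_loguniform(low, high, rng):
    """Draw a value whose logarithm is uniform in [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def random_search(objective, max_trials, seed=0):
    """Run exactly max_trials random trials and return (best_loss, best_params)."""
    rng = random.Random(seed)
    best_loss, best_params = float("inf"), None
    for _ in range(max_trials):  # the explicit budget is what stops the search
        params = {
            "lr": sample_loguniform(1e-5, 1e-2, rng),
            "n_layers": rng.choice([1, 2, 4]),
        }
        loss = objective(params)
        if loss < best_loss:
            best_loss, best_params = loss, params
    return best_loss, best_params


def toy_objective(params):
    """Hypothetical stand-in for 'train the model and return validation MAE'."""
    return abs(math.log10(params["lr"]) + 3) + 0.1 * params["n_layers"]


best_loss, best_params = random_search(toy_objective, max_trials=50)
```

Tuning frameworks generally expose their own budget settings (maximum trial count or maximum duration); if none is set, a random tuner has no natural stopping point.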

Custom dataset

Thank you for your wonderful work. I would like to know whether I can use this model, or train a model from scratch, to impute my own time series.

Some questions about multivariate time series Imputation.

Thank you for your work. I recently read your paper SAITS: Self-Attention-based Imputation for Time Series. I am also working on multivariate time series imputation. I have some questions, and I hope to discuss them with you.
1. I recently ran your method on my own dataset. My data processing first splits the data into a training set and a test set and then builds the time sequences; I train on the training set and then evaluate on the test set. However, I know that data imputation is an unsupervised task that does not use the true values of the missing data. Some people split the data into training and test sets, while others do not, and the results of your algorithm differ between these two partitioning approaches. May I ask how you view the partitioning of datasets?
2. May I ask whether your algorithm can overfit? The loss that is back-propagated is the MAE of the non-missing items, not the MAE of the whole dataset. I feel that as the number of training iterations increases, it will gradually tend to overfit.
3. Currently the stopping condition of the algorithm is reaching the specified number of epochs, and the right number of epochs has to be determined for each dataset. If we split the data into a test set and a training set, can we stop training once the MAE on the missing items of the training set reaches its minimum?
Thank you very much
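
Question 3 describes early stopping. A minimal sketch of a patience-based rule is given below, assuming the monitored quantity is the MAE on artificially held-out values of a validation set; `early_stopping_epoch` and the MAE values are hypothetical and do not come from the repository:

```python
def early_stopping_epoch(val_maes, patience=3):
    """Return (best_epoch, best_mae) for a stream of per-epoch validation MAEs.

    Training stops once the MAE has failed to improve for `patience`
    consecutive epochs; later epochs are never evaluated.
    """
    best_mae = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, mae in enumerate(val_maes):
        if mae < best_mae:
            best_mae, best_epoch = mae, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # stop training here
    return best_epoch, best_mae


# Hypothetical validation MAEs, one per epoch.
val_maes = [0.50, 0.42, 0.40, 0.41, 0.43, 0.44, 0.30]
best_epoch, best_mae = early_stopping_epoch(val_maes, patience=3)
```

With `patience=3` the loop stops after epoch 5, so the lower value at epoch 6 is never seen; this is the usual trade-off between training cost and the risk of stopping too early.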

Calculation of the loss function

Thank you for excellent work!

In the code, the imputation loss of MIT does not use the complement feature vector.

Secondly, the paper talks about taking the raw data X without artificial masking as input to the MIT formula, but I found in the corresponding code that you used the artificially masked X^.

Is there something I don't understand? I look forward to your clarification!

Question about output of the first DMSA

Hello, I want to ask about saits.py in the modeling directory of your code. I used only the first DMSA module and fed in X and the missing mask the way you do, but after passing through the encoder layer the data becomes all NaN. What could cause this?
Looking forward to your reply
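
One possible cause, offered as an assumption because the issue does not include enough detail to confirm it, is that the input still contains NaN at the missing positions. Any arithmetic involving NaN returns NaN, so a single missing value contaminates the whole embedded time step, and attention then spreads it to the other steps. The sketch below uses NumPy and a hypothetical weight matrix `W` to show the effect and the usual remedy of filling missing entries with zeros before the model sees them:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))  # 4 time steps, 3 features
X[1, 2] = np.nan             # one missing value
missing_mask = (~np.isnan(X)).astype(np.float32)  # 1 = observed, 0 = missing

W = rng.normal(size=(6, 8))  # hypothetical embedding weights for [X, mask]

# Feeding NaN directly: every element of the embedded row for step 1 is NaN.
bad = np.concatenate([X, missing_mask], axis=1) @ W

# Filling missing entries with 0 keeps everything finite; the mask still
# tells the model which entries were missing.
X_filled = np.nan_to_num(X, nan=0.0)
good = np.concatenate([X_filled, missing_mask], axis=1) @ W
```

Checking `np.isnan(X).any()` on the tensor right before it enters the model is a quick way to rule this cause in or out.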

Loss_MIT wrong?

I saw that the MIT loss computation in core.py was
'MIT_loss = self.customized_loss_func(
X_tilde_3, inputs["X_ori"], inputs["indicating_mask"]
)'
, which computes the MAE between M~3 and X_ori and differs from the paper.
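
To make the quoted call concrete, here is a simplified NumPy sketch of the two masked losses. It assumes `customized_loss_func` is a masked MAE like the `masked_mae_cal` function quoted in another issue, uses a single reconstruction term, and omits any loss weights; the tensor values are invented:

```python
import numpy as np


def masked_mae(pred, target, mask):
    """Mean absolute error over the entries where mask == 1."""
    return np.sum(np.abs(pred - target) * mask) / (np.sum(mask) + 1e-9)


# Toy tensors: 1 sample, 2 time steps, 2 features.
X_ori = np.array([[[1.0, 2.0], [3.0, 4.0]]])            # values before artificial masking
indicating_mask = np.array([[[0.0, 1.0], [0.0, 0.0]]])  # 1 = artificially masked
missing_mask = 1.0 - indicating_mask                    # 1 = still observed in the input
X = X_ori * missing_mask                                # model input, masked entries zeroed

X_tilde_3 = np.array([[[1.5, 2.5], [3.0, 3.0]]])        # hypothetical model output

ORT_loss = masked_mae(X_tilde_3, X, missing_mask)         # reconstruction of observed entries
MIT_loss = masked_mae(X_tilde_3, X_ori, indicating_mask)  # imputation of held-out entries
total_loss = ORT_loss + MIT_loss
```

Because `indicating_mask` selects only artificially masked positions, and a complement vector built as "observed values kept, missing values replaced by the model output" equals `X_tilde_3` at exactly those positions, evaluating `X_tilde_3` against `X_ori` under this mask gives the same number as evaluating the complement vector there.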

Question about temporal dependencies and feature correlations captured by DMSA

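Hello, I have a question about the self-attention in the paper. The N×N self-attention matrix Q·Kᵀ represents attention relationships along a single dimension of length N, whereas your paper states that “Such a mechanism makes DMSA able to capture the temporal dependencies and feature correlations between time steps in the high dimensional space with only one attention operation”. That is, a single DMSA attention matrix captures attention along two dimensions at once. How does one attention operation capture both types of attention?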
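
One reading of the quoted sentence, given here as an interpretation rather than the authors' answer: the attention matrix is indeed N×N over time steps, but each row of the input is a d_model-dimensional projection of all features (and the mask) at that step, so the dot products compare time steps in a space where features are already mixed. The single-head NumPy sketch below illustrates this with a masked diagonal; weight names and sizes are hypothetical:

```python
import numpy as np


def diagonally_masked_self_attention(E, W_q, W_k, W_v):
    """Single-head attention over time steps with the diagonal masked out.

    E: [n_steps, d_model] embedding in which each row already mixes all
       features (and the missing mask) of one time step.
    """
    Q, K, V = E @ W_q, E @ W_k, E @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[1])       # [n_steps, n_steps]
    np.fill_diagonal(scores, -np.inf)            # a step may not attend to itself
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V, weights


rng = np.random.default_rng(0)
n_steps, d_model, d_k = 5, 8, 4
E = rng.normal(size=(n_steps, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out, attn = diagonally_masked_self_attention(E, W_q, W_k, W_v)
```

Each output row is a weighted sum of the other time steps only, and the weights depend on all features of both steps through the projections.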

Question about training, validation, testing and missing rates?

Greetings, sir.
In your research paper, "SAITS: Self-Attention-based Imputation for Time Series", you have presented a table with different missing rates (20% to 90%). So, do you prepare best models from each val_dataset with different missing % rates so as to provide result for corresponding missing % of test or do you save model based on specific missing percentage in validation data (say 20% loss in val data) and use it to prepare result for test for other missing percentage(say 30% artificial missing percentage in test data)?

window truncate function

def window_truncate(feature_vectors, seq_len):
    """ Generate time series samples, truncating windows from time-series data with a given sequence length.
    Parameters
    ----------
    feature_vectors: time series data, len(shape)=2, [total_length, feature_num]
    seq_len: sequence length
    """
    start_indices = np.asarray(range(feature_vectors.shape[0] // seq_len)) * seq_len
    sample_collector = []
    for idx in start_indices:
        sample_collector.append(feature_vectors[idx: idx + seq_len])

    return np.asarray(sample_collector).astype('float32')

Wenjie,

I have some questions, if you do not mind clarifying:

  1. In the implementation, is the training data generated by dividing the time series into non-overlapping segments of the sequence length?
  2. What is the advantage of this training data configuration over a sliding window approach, e.g., generating the training set with a one-time-step lag: [t-n, t-n+1, ... t], [t-n+1, t-n+2, ... t+1], [t-n+2, t-n+3, ... t+2]? Wouldn't the sliding window approach generate more samples for training?
  3. I am not very familiar with the transformer architecture. In a typical RNN-based imputation method, there are the concepts of sequence length (i.e., the length of historical or future data used as input) and prediction horizon (i.e., how far into the future or the past the model tries to impute). For SAITS, what would be the equivalent concepts, or does the concept of a prediction horizon exist at all?
  4. I understand from your paper that the sequence length is fixed across the different models for comparison purposes. How does the sequence length affect the accuracy of the imputation? What would you recommend for determining the appropriate sequence length for the problem at hand?
  5. An unrelated question: is your PyPOTS currently working with the Air Quality dataset?

Thanks in advance,
Haochen
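
The difference between the two schemes in question 2 can be sketched with one helper. `sliding_window` is a hypothetical function, not part of the repository; with `stride=seq_len` it reproduces the non-overlapping split of `window_truncate` above:

```python
import numpy as np


def sliding_window(feature_vectors, seq_len, stride=1):
    """Cut [total_length, feature_num] data into windows of seq_len steps.

    stride=seq_len gives non-overlapping windows; stride=1 gives the
    one-step-lag scheme described in the question.
    """
    starts = range(0, feature_vectors.shape[0] - seq_len + 1, stride)
    return np.asarray([feature_vectors[s:s + seq_len] for s in starts]).astype("float32")


series = np.arange(20, dtype="float32").reshape(10, 2)  # 10 time steps, 2 features
non_overlapping = sliding_window(series, seq_len=4, stride=4)  # 2 windows
overlapping = sliding_window(series, seq_len=4, stride=1)      # 7 windows
```

Overlapping windows do yield more samples, but neighbouring samples share most of their values, so the series should be split into train/validation/test segments before windowing to avoid leakage between the sets.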

Can the number of time steps (n_steps) be adjusted dynamically?

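Hello, according to saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_inner=128, n_heads=4, d_k=64, d_v=64, dropout=0.1, epochs=10) in the provided example.py, n_steps is set to 48 because every RecordID in the example dataset has 48 samples.
In my dataset, however, the number of samples per RecordID is not fixed: it can be 1, 7, or even 216. In that case, would it be feasible to set n_steps to the largest number of samples of any RecordID, e.g. 216? Or is there another solution? Thank you very much!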
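
One way to try the idea in the question, offered as an assumption and not as the authors' recommendation, is to pad every record to a common length with NaN so that the padded steps are treated as missing values. `pad_to_n_steps` is a hypothetical helper:

```python
import numpy as np


def pad_to_n_steps(sequences, n_steps):
    """Stack variable-length [length, n_features] arrays into one
    [n_samples, n_steps, n_features] array, filling unused steps with NaN.
    Sequences longer than n_steps are truncated."""
    n_features = sequences[0].shape[1]
    padded = np.full((len(sequences), n_steps, n_features), np.nan, dtype="float32")
    for i, seq in enumerate(sequences):
        length = min(len(seq), n_steps)
        padded[i, :length] = seq[:length]
    return padded


# Three hypothetical records with 1, 7, and 5 time steps and 3 features each.
sequences = [np.ones((1, 3)), np.ones((7, 3)), np.ones((5, 3))]
X = pad_to_n_steps(sequences, n_steps=7)
```

When lengths range from 1 to 216, most of the padded tensor would be NaN, which raises the effective missing rate considerably; cutting long records into fixed-length windows and dropping very short ones is an alternative worth comparing.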

Test data

Hi,

After some modifications and the inclusion of a few code snippets, I was able to train, validate, and get the MAE for the test data.
I want to obtain the de-normalized values after imputation on the test data, both predicted and actual. Can you help?
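
De-normalization is the inverse of whatever scaling was applied before training. A minimal sketch, assuming per-feature standard scaling (mean and standard deviation computed on the training data) and a hypothetical imputation output shaped [samples, steps, features]:

```python
import numpy as np

rng = np.random.default_rng(0)
raw = rng.normal(loc=50.0, scale=10.0, size=(100, 3))  # [time, features], original units

# Per-feature statistics; these are the values StandardScaler stores
# as mean_ and scale_ after fit().
mean, std = raw.mean(axis=0), raw.std(axis=0)
scaled = (raw - mean) / std

# Hypothetical imputation output in the standardized space.
imputation = scaled.reshape(10, 10, 3)

# Undo the scaling; broadcasting applies each feature's mean and std
# along the last axis.
denormalized = imputation * std + mean
```

If a fitted scikit-learn `StandardScaler` object is still available, `scaler.inverse_transform(imputation.reshape(-1, n_features)).reshape(imputation.shape)` gives the same result; the same transform applies to the ground-truth values.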

How to comprehend the NNI fine-tuning?

First, the file SAITS_basic_config.ini under the NNI_tuning folder is missing two args, "MIT" and "ORT", which affects running the script python ../../run_models.py --config_path SAITS_basic_config.ini --param_searching_mode. You may want to add these two args to the .ini file and also check the other .ini files if you have time.
Second, I am wondering how to check what NNI does when tuning the parameters. To be more specific, which parameters did NNI change? When, and by how much, did the parameters change? Are only the parameters listed in the SAITS_searching_space.json file changed?

Thanks for your attention.

Question about hyperparameter optimization

Hello, I would like to ask you about the hyperparameter optimization for the model. In your file NNI_tuning/SAITS/SAITS_searching_config.yml, you described the settings for the hyperparameters and the training command, which also includes a JSON file. However, when I tried to run the command for hyperparameter optimization on SAITS, I encountered an error: "No option 'mit' in section: 'training'". I supplemented the missing parameters and ran it again, but I only obtained the parameters set in the SAITS_basic_config.ini file. Could you please advise me on how to iterate through the parameters in the JSON file to obtain the optimal parameters?
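
The quoted error text matches the message of Python's `configparser.NoOptionError`; the option appears in lowercase because `ConfigParser` lowercases option names by default. The sketch below reproduces it on an in-memory config. Whether run_models.py reads these options with `getboolean`, and which values they should take, are assumptions made for illustration:

```python
import configparser

ini_text = """
[training]
epochs = 100
"""

cfg = configparser.ConfigParser()
cfg.read_string(ini_text)

# Reading an option that is absent raises NoOptionError.
try:
    cfg.getboolean("training", "MIT")
except configparser.NoOptionError as err:
    message = str(err)

# Adding the two options resolves it (values are placeholders,
# not the authors' settings).
cfg.read_string("[training]\nMIT = True\nORT = True\n")
mit = cfg.getboolean("training", "MIT")
ort = cfg.getboolean("training", "ORT")
```

Supplying the missing keys only fixes the parsing error; which hyperparameters the tuner then varies is controlled separately by the search-space file referenced in the tuning configuration.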

Question about MAE

Hi, Wenjie

def masked_mae_cal(inputs, target, mask):
    """ calculate Mean Absolute Error"""
    return torch.sum(torch.abs(inputs - target) * mask) / (torch.sum(mask) + 1e-9)

I have a small question about the calculation of MAE.
I found that you normalize the dataset with standard scaling, which means the target and input are both standardized. So why not calculate the MAE after inverting the scaling on them?

Using CSV files versus h5 data

Hello Wenjie,

Thank you for releasing the code. I have a couple of questions. I am trying to run the code with the Air Quality dataset in Google Colab. These are my questions:

  1. !CUDA_VISIBLE_DEVICES=2 python run_models.py --config_path configs/AirQuality_SAITS_best.ini
    Running this gives me the following error message.
    OSError: Unable to open file (unable to open file: name = 'dataset_generating_scripts/RawData/AirQuality/PRSA_Data_20130301-20170228/datasets.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

The whole dataset is in .csv format.
1a) Is there an option to use the default .csv data?
1b) How do I convert .csv to h5 format?

  2. Where should we change the file path of the dataset for training purposes?
    In the file configs/AirQuality_SAITS_best.ini?

Please do let me know, thanks.

Niharika

Some questions about the generated data

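Hello, I ran into some problems while reproducing the code and hope you can help.
1. The imputed data is identical to the original data. In the ETT test set, missing_mask shows that the point (0, 3) should be artificially missing, yet the final generated value is exactly the same as the original value. The Excel file is an export of the standardized data, and the h5 file is the generated file.
[Screenshot: Snipaste_2024-06-03_14-56-24]

2. If I use my own dataset, do I need to define the config file for training myself? Also, were the "best" configs in the configs folder obtained through your own repeated training?

3. I would like to use the final imputed h5 file for downstream forecasting work. Can it be converted and exported as CSV?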

Configs of ETTm1

Hello,
Could you share the configuration settings on the ETTm1 dataset?
Thanks!

Question about loss

Thank you for your work, and please understand that this is not a direct question about the code. Is there any reason the loss function does not include a classification error term? Some models that perform reconstruction and imputation include the classification error in the loss function. Have you ever trained models in this way? If so, please let me know what the results were like.

Final error calculation

Hello Wenjie,

I have a question regarding the calculation of the final error metrics on the test data.

Suppose my sample data looks like this:

date          A       B
timestamp1    3       5
timestamp2    4       7
timestamp3    6       8
timestamp4    8       10

After introducing 50% missingness:

date          A       B
timestamp1    NaN     5
timestamp2    NaN     7
timestamp3    6       8
timestamp4    NaN     NaN

After imputation:

date          A       B
timestamp1    2       5
timestamp2    4       7
timestamp3    6       8
timestamp4    6       5
  1. Are the MAE, RMSE, and MRE calculated only on the imputed values or on the whole dataset?
  2. Can you explain the MAE, RMSE, and MRE formulas/equations used?

Thank you. Please let me know.

Regards
Niharika Joshi
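
Using the sample tables from this issue, the masked versions of the three metrics can be worked through by hand. The MAE form follows the `masked_mae_cal` function quoted in another issue; the RMSE and MRE forms below are the usual masked analogues and are an assumption, not a quote from the repository:

```python
import numpy as np

# Columns A and B from the tables above, rows timestamp1..timestamp4.
truth = np.array([[3, 5], [4, 7], [6, 8], [8, 10]], dtype=float)
imputed = np.array([[2, 5], [4, 7], [6, 8], [6, 5]], dtype=float)
# 1 where a value was artificially removed, 0 elsewhere.
indicating_mask = np.array([[1, 0], [1, 0], [0, 0], [1, 1]], dtype=float)

abs_err = np.abs(imputed - truth) * indicating_mask  # errors at held-out entries only
n = indicating_mask.sum()                            # number of held-out entries (4)

mae = abs_err.sum() / n                              # (1 + 0 + 2 + 5) / 4
rmse = np.sqrt((abs_err ** 2).sum() / n)             # sqrt((1 + 0 + 4 + 25) / 4)
mre = abs_err.sum() / np.abs(truth * indicating_mask).sum()  # 8 / (3 + 4 + 8 + 10)
```

Under these definitions only the four artificially removed entries contribute; the values that stayed observed are ignored, so perfect copies of observed data do not lower the reported error.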

Inquiry About Using SAITS Model with 'physionet_2012' Dataset

Hi everyone:

While using the 'physionet_2012' dataset, I found that the 'y' values cannot be output as expected. I noticed that the article's 'Downstream classification task' section addresses relevant content. If the 'In-hospital_death' values are not 0 and 1 but some other numbers, and the 'X' section content influences the mortality rate, can the SAITS model be used to predict this? If so, which function of SAITS should I use?

Thank you for your assistance.

Best regards,
