
aaltd18's Introduction

Data augmentation using synthetic data for time series classification with deep residual networks

This is the companion repository for our paper "Data augmentation using synthetic data for time series classification with deep residual networks". The paper was accepted for an oral presentation at the Workshop on Advanced Analytics and Learning on Temporal Data (AALTD) 2018, held at the European Conference on Machine Learning and Principles and Practice of Knowledge Discovery in Databases (ECML/PKDD) 2018.

[Figure: the residual network architecture]

Data

The data used in this project comes from the UCR archive, which contains the 85 univariate time series datasets we used in our experiments.

Code

The code is divided as follows:

  • The distance folder contains the DTW distance implemented in Cython instead of pure Python in order to reduce the running time.
  • The dba.py file contains the DBA algorithm.
  • The utils folder contains the functions needed to read the datasets and visualize the plots.
  • The knn.py file contains the K-nearest-neighbor algorithm, which is mainly used when computing the weights for the data augmentation technique.
  • The resnet.py file contains the Keras and TensorFlow code that defines the architecture and trains the deep learning model.
  • The augment.py file contains the method that generates the random weights (Average Selected), along with a function that performs the actual augmentation for a given training set of time series.
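As a rough illustration of the weighting scheme described above, here is a hedged, simplified sketch of what "Average Selected" might look like. The 0.5 weight for the chosen center and the k/subk neighbor parameters mirror the paper's description, but the function name, signature, and details are illustrative and may differ from the repository's actual code in augment.py.

```python
import numpy as np

def average_selected_weights(dist_row, idx_center, k=5, subk=2, rng=None):
    """Sketch of 'Average Selected' (illustrative, not the repo's exact code):
    the center series gets weight 0.5, and `subk` series drawn at random from
    the center's `k` nearest neighbors share the remaining 0.5.

    dist_row   -- DTW distances from the center to every series in the class
    idx_center -- index of the series chosen as the averaging center
    """
    rng = np.random.default_rng(rng)
    weights = np.zeros(len(dist_row))
    weights[idx_center] = 0.5
    # k nearest neighbors of the center, excluding the center itself
    order = np.argsort(dist_row)
    topk_idx = [i for i in order if i != idx_center][:k]
    chosen = rng.choice(topk_idx, size=min(subk, len(topk_idx)), replace=False)
    weights[chosen] += 0.5 / len(chosen)
    return weights  # sums to 1; pass to a weighted DBA average
```

The resulting weight vector would then drive a weighted DBA average (dba.py) to produce one synthetic series.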

Prerequisites

All required Python packages are listed in the utils/pip-requirements.txt file and can be installed with pip for Python 3.6 (e.g. pip3 install -r utils/pip-requirements.txt).

Results

The main contribution of a data augmentation technique is to improve the performance (accuracy) of a deep learning model, especially on time series datasets with small training sets. For example, on DiatomSizeReduction (the smallest training set in the UCR archive), data augmentation increased the residual network's accuracy from 30% (without augmentation) to 96%.

[Figures: plots of the Meat and DiatomSizeReduction datasets]

Reference

If you re-use this work, please cite:

@InProceedings{IsmailFawaz2018,
  Title                    = {Data augmentation using synthetic data for time series classification with deep residual networks},
  Author                   = {Ismail Fawaz, Hassan and Forestier, Germain and Weber, Jonathan and Idoumghar, Lhassane and Muller, Pierre-Alain},
  Booktitle                = {International Workshop on Advanced Analytics and Learning on Temporal Data, {ECML} {PKDD}},
  Year                     = {2018}
}

Acknowledgement

We would like to thank NVIDIA Corporation for the Quadro P6000 grant and the Mésocentre of Strasbourg for providing access to the cluster.

aaltd18's People

Contributors

forestier, hfawaz


aaltd18's Issues

Why does the following error occur?

Traceback (most recent call last):
File "main.py", line 1, in
from utils.constants import UNIVARIATE_ARCHIVE_NAMES as ARCHIVE_NAMES
File "/opt/module/datas/tsc_data_augment/utils/constants.py", line 3, in
from distances.dtw.dtw import dynamic_time_warping as dtw
ModuleNotFoundError: No module named 'distances.dtw.dtw'

Error with "./utils/build-cython.sh"

Thanks a lot for providing the code for this data augmentation, and for the time series benchmarking :)

I could not get your Cython build to go through, probably because setup.py was missing the NumPy include path; I got this error:

dtw.c:579:10: fatal error: numpy/arrayobject.h: No such file or directory
 #include "numpy/arrayobject.h"
          ^~~~~~~~~~~~~~~~~~~~~
compilation terminated.

I changed your setup.py to this, and it seems to be working:

from distutils.core import setup
from Cython.Build import cythonize
import numpy

setup(
      ext_modules=cythonize("dtw.pyx"),
      include_dirs=[numpy.get_include()]
)

ImportError: cannot import name 'dynamic_time_warping'

Hi there,

I've successfully run distances/dtw/setup.py, but I still get the import error in the title when running main.py. Has anyone else had the same problem?
BTW, I'm in an Anaconda virtual env; not sure if that is causing the problem.

Thanks

OSError: Unable to open file (unable to open file...)

Hi there, I'm getting this error.
I already tried updating h5py to version 3.1.0 and numpy to version 1.19.5, but nothing works. Any suggestion will be appreciated.

Traceback (most recent call last):
File "main.py", line 134, in
classifier_ensemble = Classifier_ENSEMBLE(output_dir, x_train.shape[1:], nb_classes, False)
File "/home/k8s-master/Desktop/aaltd18-master/aaltd18-master/ensemble.py", line 11, in init
+'best_model.hdf5')
File "/usr/lib/python3/dist-packages/keras/models.py", line 234, in load_model
with h5py.File(filepath, mode='r') as f:
File "/usr/local/lib/python3.6/dist-packages/h5py/_hl/files.py", line 427, in init
swmr=swmr)
File "/usr/local/lib/python3.6/dist-packages/h5py/_hl/files.py", line 190, in make_fid
fid = h5f.open(name, flags, fapl=fapl)
File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "h5py/h5f.pyx", line 96, in h5py.h5f.open
OSError: Unable to open file (unable to open file: name = '/home/uha/hfawaz-datas/dl-tsc/results/resnet/UCR_TS_Archive_2015/Coffee/best_model.hdf5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

ValueError: Buffer has wrong number of dimensions (expected 2, got 1)

Thanks for sharing this work.

Is there any multivariate support? When I try to augment my data (with shape (samples, length, n_features)), I see the library passes (samples, n_features) to dist_fun, which results in this error:

aaltd18/dba.py in calculate_dist_matrix(tseries, dist_fun, dist_fun_params)
     10         for j in range(i+1,N):
     11             y = tseries[j]
---> 12             dist = dist_fun(x,y,**dist_fun_params)[0]
     13             # because dtw returns the sqrt
     14             dist = dist*dist
dtw.pyx in dtw.dynamic_time_warping()
ValueError: Buffer has wrong number of dimensions (expected 2, got 1)

What is the correct input type? (x and y both have shape (1,).)
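Regarding the question above, the "expected 2, got 1" message suggests the Cython routine declares two-dimensional typed buffers. The following is only a hypothetical illustration of what that would imply for the input shape; the repository's actual expected layout may differ.

```python
import numpy as np

# Hypothetical illustration (not the repository's documented API):
# if the Cython dtw routine expects 2-D buffers, a univariate series of
# shape (length,) would need to be reshaped to (length, 1) before the call.
x = np.random.rand(128)   # a univariate series of 128 time steps
x2d = x.reshape(-1, 1)    # shape (128, 1): time steps x dimensions
```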

problem in augment.py

Hi,
I found what may be a problem in the function get_weights_average_selected in augment.py.

In line 57, final_neighbors_idx = np.random.permutation(k)[:subk]. What if the random indices include idx_center? That could override idx_center, so the weights vector might not include the weight of 0.5 for the init_dba.

Therefore I added new code after that line:

while idx_center in topk_idx[final_neighbors_idx]:
    final_neighbors_idx = np.random.permutation(k)[:subk]

What do you think?

Parallel implementation?

Hi again,

If I understood your algorithm correctly, the loop over nb_prototypes_per_class could be made parallel, right?


Have you explored this with any of the options there for example:
https://stackoverflow.com/questions/9786102/how-do-i-parallelize-a-simple-python-loop or https://pypi.org/project/pp/

For example, I have ~2,000 time series in my own dataset (1,981 samples in each of them), and the computational time becomes quite unreasonable even with a 160-time-series subset:

913 seconds (15 minutes) for 40 time series
14,449 seconds (4 hours) for 160 time series
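One hedged sketch of the parallelization asked about above, using only the standard library: it assumes each prototype can be computed independently (as the question suggests), and make_prototype is a hypothetical stand-in for the real per-prototype work (e.g. one weighted DBA average), not a function from this repository.

```python
from multiprocessing import Pool

def make_prototype(seed):
    # Hypothetical stand-in for the expensive per-prototype computation
    # (e.g. one weighted DBA average); replace with the real work.
    return seed * seed

def generate_prototypes(nb_prototypes_per_class):
    # Each prototype is independent, so a process pool can compute them
    # concurrently across CPU cores.
    with Pool() as pool:
        return pool.map(make_prototype, range(nb_prototypes_per_class))

if __name__ == "__main__":
    print(generate_prototypes(8))
```

Processes (rather than threads) are used because DBA is CPU-bound; the per-prototype function must be picklable, i.e. defined at module top level.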

Definition of warping window [DTW_PARAMS]

Hi again @hfawaz et al.

I was wondering whether you actually used the following setting for your paper, i.e. you did not warp at all, if I understood the code and read the paper properly:

DTW_PARAMS = {'w': -1}  # warping window given as a percentage (negative means no warping window)

In the cited paper [6, Forestier et al. 2017, PDF], that was described as follows for the "old-school" Windows Warping (WW) method:

The Windows Warping (WW) method discussed in [11] involves warping a randomly selected slice
of a time series by speeding it up or down. We follow the recommendations given in [11]
and use warping ratios equal to either 1/2 or 2 on slices representing 10% of time series’s length


Did you test out the warping with DTW?
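The repository's actual DTW lives in the Cython file distances/dtw/dtw.pyx. As an illustration only, a pure-Python DTW with a Sakoe-Chiba band matching the assumed semantics of the w parameter discussed above (a fraction of the series length, negative meaning no window constraint) could look like this; it returns the square root of the accumulated cost, consistent with the "because dtw returns the sqrt" comment in dba.py.

```python
import math

def dtw(x, y, w=-1.0):
    """DTW distance with an optional Sakoe-Chiba band (illustrative sketch).
    `w` is the window as a fraction of the series length; w < 0 disables it."""
    n, m = len(x), len(y)
    # Band half-width in cells; it must be at least |n - m| for a path to exist.
    band = max(n, m) if w < 0 else max(int(w * max(n, m)), abs(n - m))
    INF = float("inf")
    cost = [[INF] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(max(1, i - band), min(m, i + band) + 1):
            d = (x[i - 1] - y[j - 1]) ** 2
            cost[i][j] = d + min(cost[i - 1][j],      # insertion
                                 cost[i][j - 1],      # deletion
                                 cost[i - 1][j - 1])  # match
    return math.sqrt(cost[n][m])
```

With w = -1 every alignment is allowed, i.e. no warping-window constraint, which matches the reading of the DTW_PARAMS setting quoted in this issue.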

Can it be used to generate a mixture of categorical and continuous data?

Can it be used to generate a mixture of categorical and continuous data?
For example, at time t1 we have the observation: red, 2.4, 5, 12.456; at t2: green, 3.5, 2, 45.78; at t3: black, 5.6, 7, 23.56; at t4: red, 2.1, 5, 12.6?

That is, when each dimension is a correlated time series and each one can be categorical or continuous?
