-rnn-covid19_case_predict's Introduction

Built with Python, Spyder, NumPy, Pandas, scikit-learn and TensorFlow.

[RNN] What's for TOMORROW? -COVID-19_Case_Predict-

Late as it is, I am finally picking up a dying trend (or maybe not): a COVID-19 case prediction model! Predictions are important, especially when it comes to anticipating a crisis, e.g. a food supply shortage, a disease outbreak, an earthquake or a volcanic eruption. The dataset is readily available.

[MODEL UPDATE]

Reduced the computation required by cutting the nodes in hidden layer 1 and hidden layer 2 from 64 and 32 to 8 and 4 respectively; model performance is similar:

model_lessnode eval_test_plot_lessnode

Model optimization: reduced MAPE to 0.09 by reducing the window size to 7.

eval_test_plot_win7

Model optimization: reduced MAPE to 0.11 by increasing the number of training epochs to 1000.

eval_test_plot_1000epo

Model Performance

The model achieves a Mean Absolute Percentage Error (MAPE) of 0.14. The evaluation also reports Mean Squared Error (MSE) and Mean Absolute Error (MAE):

eval_test
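For readers unfamiliar with the three metrics, here is a minimal sketch of how they can be computed with NumPy. It assumes MAPE is expressed as a fraction (so 0.14 would mean 14%), which this README does not state explicitly. The function name `regression_metrics` and the sample numbers are made up for illustration and are not taken from the project code.

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """Return MSE, MAE and MAPE (MAPE as a fraction, not a percentage)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_true - y_pred
    mse = np.mean(err ** 2)
    mae = np.mean(np.abs(err))
    # MAPE is undefined when a true value is zero; this sketch assumes none are.
    mape = np.mean(np.abs(err / y_true))
    return {"mse": mse, "mae": mae, "mape": mape}

# Made-up numbers, for illustration only
metrics = regression_metrics([100, 200, 400], [110, 180, 400])
```

Because MAPE divides each error by the true value, it stays comparable across periods with very different case counts, whereas MSE and MAE grow with the scale of the data.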

Actual vs. Predicted

eval_test_plot

Loss and MAPE Plot

tensorboard

Model Architecture

The RNN model is constructed from 2 LSTM layers with activation function 'tanh', having 64 and 32 nodes respectively:

model
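The architecture described above could look roughly like the sketch below. Only the two LSTM layers, their sizes and the 'tanh' activation come from this README. The input shape `(win_size, 1)`, the final `Dense(1)` layer and the helper name `build_model` are assumptions, and the actual model may contain layers (e.g. dropout) that are not shown here.

```python
import tensorflow as tf

def build_model(win_size=30, units_1=64, units_2=32):
    """Two stacked LSTM layers followed by a single-value regression head."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(win_size, 1)),  # one feature: scaled cases_new
        # return_sequences=True so the second LSTM receives the full sequence
        tf.keras.layers.LSTM(units_1, activation="tanh", return_sequences=True),
        tf.keras.layers.LSTM(units_2, activation="tanh"),
        tf.keras.layers.Dense(1),  # assumed output layer, not confirmed by the README
    ])

model = build_model()
```

Calling `build_model(win_size=7, units_1=8, units_2=4)` would give a smaller variant along the lines of the one described under [MODEL UPDATE].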

Other hyperparameters

The optimizer, loss function and metrics are summarised below:

model.compile(optimizer='adam',loss='mse',metrics=['mse','mae','mape'])

tb=TensorBoard(log_dir=LOG_PATH)

hist=model.fit(x_train,y_train,batch_size=64,epochs=100,callbacks=[tb],verbose=1)

Data Summary

The train and test datasets are provided in separate files, with shapes (680,31) and (100,31) respectively, and contain the following features:

'date', 'cases_new', 'cases_import', 'cases_recovered', 'cases_active','cases_cluster', 'cases_unvax', 'cases_pvax', 'cases_fvax','cases_boost', 'cases_child', 'cases_adolescent', 'cases_adult','cases_elderly', 'cases_0_4', 'cases_5_11', 'cases_12_17','cases_18_29', 'cases_30_39', 'cases_40_49','cases_50_59','cases_60_69', 'cases_70_79', 'cases_80', 'cluster_import','cluster_religious', 'cluster_community', 'cluster_highRisk','cluster_education', 'cluster_detentionCentre', 'cluster_workplace', whereby 'cases_new' is the target feature.

Data Inspection / Cleaning

Train Dataset

Upon inspection, 'cases_new' in the train dataset contains non-numeric values, as shown:

inspec_info

To convert non-numeric values to null values, the following code is applied:

df['cases_new']=pd.to_numeric(df['cases_new'],errors='coerce')

12 null values observed in 'cases_new' are represented as white stripes in figure below:

visual_msno_na

The following code is applied to impute null values in the time series data:

df['cases_new']=df['cases_new'].interpolate(method='linear')

Before After
visual_na visual_interpolated
Notice that the small gaps in the line plot (left figure) disappear after interpolation (right figure).
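The two cleaning steps can be tried end to end on a toy column. This is a minimal sketch with made-up values. The strings "?" and "unknown" are placeholders for whatever non-numeric entries the real file contains.

```python
import pandas as pd

# Toy column with made-up values; "?" and "unknown" stand in for non-numeric entries
df = pd.DataFrame({"cases_new": ["10", "?", "30", "unknown", "50"]})

# Anything that cannot be parsed as a number becomes NaN
df["cases_new"] = pd.to_numeric(df["cases_new"], errors="coerce")
n_missing = int(df["cases_new"].isna().sum())

# Linear interpolation fills each gap from its neighbouring values
df["cases_new"] = df["cases_new"].interpolate(method="linear")
```

Linear interpolation suits a daily time series like this one because a missing day is estimated from the days on either side of it. Filling with a global mean would ignore the trend.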

Test Dataset

One null value is identified and imputed using the pandas.DataFrame.interpolate() method.

Preprocessing

Train and test datasets are preprocessed in a similar fashion.

Normalize data

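To normalize the target feature:

# mms is presumably a scikit-learn MinMaxScaler instance, e.g. mms=MinMaxScaler()
df_scaled=mms.fit_transform(np.expand_dims(df['cases_new'],axis=-1))
# Reuse the scaler fitted on the train data; do not refit on the test data
df_test_scaled=mms.transform(df_test_ori)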
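The sketch below illustrates the fit-on-train, transform-on-test pattern. It assumes `mms` is a scikit-learn `MinMaxScaler` (the README does not show where `mms` is created) and uses made-up values.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

train = np.array([100.0, 200.0, 300.0])  # made-up training values
test = np.array([250.0, 400.0])          # made-up test values

mms = MinMaxScaler()
# The scaler expects a 2-D array, hence the extra axis
train_scaled = mms.fit_transform(np.expand_dims(train, axis=-1))
# Reuse the training min/max; test values above the training max map above 1
test_scaled = mms.transform(np.expand_dims(test, axis=-1))
# Undo the scaling, e.g. to report predictions as case counts
test_restored = mms.inverse_transform(test_scaled)
```

Fitting the scaler on the train data only keeps information from the test period out of preprocessing. As a result, scaled test values can fall outside the range 0 to 1.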

Create train and test datasets

A window size of 30 is used to create the train and test datasets: each sample is a sequence of 30 consecutive values, and its target is the value that immediately follows:

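win_size=30

x_train=[]
y_train=[]

# Each sample is the previous win_size values; the target is the value that follows
for i in range(win_size,np.shape(df_scaled)[0]):
    x_train.append(df_scaled[i-win_size:i])
    y_train.append(df_scaled[i])

# Convert to arrays; x_train has shape (samples, win_size, 1) as the LSTM expects
x_train=np.array(x_train)
y_train=np.array(y_train)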
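To make the resulting shapes concrete, here is a self-contained sketch of the same sliding-window idea on a toy series. The helper name `make_windows` and the toy data are illustrative and not part of the project code.

```python
import numpy as np

def make_windows(series, win_size):
    """Split an (n, 1) array into inputs of win_size steps and next-step targets."""
    x, y = [], []
    for i in range(win_size, len(series)):
        x.append(series[i - win_size:i])  # the win_size values before step i
        y.append(series[i])               # the value to predict
    return np.array(x), np.array(y)

toy = np.arange(10, dtype=float).reshape(-1, 1)  # made-up data: 0..9
x_toy, y_toy = make_windows(toy, win_size=3)
```

Here `x_toy` has shape (7, 3, 1) and `y_toy` has shape (7, 1): a series of n rows yields n - win_size samples. By the same logic, 680 train rows with a window of 30 would give 650 samples.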

For test dataset:

# Concatenate scaled train and test dataset
concat_df=np.concatenate([df_scaled,df_test_scaled],axis=0)

# Keep the test rows plus the win_size rows before them (100+30=130 rows here)
concat_test=concat_df[-(win_size+len(df_test_scaled)):]

x_test=[]

for i in range(win_size,np.shape(concat_test)[0]):
    x_test.append(concat_test[i-win_size:i])

x_test=np.array(x_test)
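The reason for concatenating is that the first test rows need history from the end of the train set. The sketch below shows this on toy data, with a hypothetical helper `make_test_windows`. It computes the slice length from the window size rather than hard-coding it, so it keeps working if the window size changes.

```python
import numpy as np

def make_test_windows(train_scaled, test_scaled, win_size):
    """Build one input window per test row, borrowing history from the train set."""
    n_test = len(test_scaled)
    concat = np.concatenate([train_scaled, test_scaled], axis=0)
    # Keep only the test rows plus the win_size rows that precede them
    tail = concat[-(win_size + n_test):]
    return np.array([tail[i - win_size:i] for i in range(win_size, len(tail))])

train_toy = np.arange(10, dtype=float).reshape(-1, 1)     # made-up data: 0..9
test_toy = np.arange(10, 14, dtype=float).reshape(-1, 1)  # made-up data: 10..13
x_test_toy = make_test_windows(train_toy, test_toy, win_size=3)
```

The result has one window per test row (shape (4, 3, 1) here), and the first window consists entirely of train values.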

Discussion

The prediction plot shows that the model follows the trend of real cases with a minor lag and tends to behave like a moving average when real cases fluctuate. After the spike in real cases, model performance drops further and the model fails to capture the rapid fluctuations. As such, the model still has room for improvement despite achieving an evaluation MAPE of 0.14. Some suggestions for improvement:

  • Increase the number of epochs.
  • Reduce the window size.
  • Implement a Bidirectional() LSTM layer.
  • Apply a moving average in the data preprocessing step to smooth the fluctuations.
  • Apply learning rate decay to the Adam optimizer.
  • Modify the model architecture to incorporate CNN pooling layers.
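As an illustration of the moving-average suggestion, a trailing rolling mean could be applied to the target column before scaling. This is only a sketch: the helper name `smooth` and the default window of 7 days are arbitrary, untested choices, and the values are made up.

```python
import pandas as pd

def smooth(series, window=7):
    """Trailing moving average; early rows average over the values available so far."""
    return series.rolling(window=window, min_periods=1).mean()

cases = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])  # made-up values
smoothed = smooth(cases, window=3)
```

Note that training on smoothed targets changes what the model predicts (a smoothed series rather than raw daily counts), so the evaluation metrics would no longer be directly comparable with the figures reported above.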
