
Adversarial Training Attentive LSTM

This model can be used for Binary Classification tasks and will be used below as a stock movement classifier (UP/DOWN).

Table of contents

  • Description

  • How to use the Adv-ALSTM model

  • Documentation

  • Model description

Description

This project reproduces a working version of the Adversarial Attention-Based LSTM (Adv-ALSTM) for TensorFlow 2. The new version is available in AdvALSTM.py as the AdvLSTM class. More details on my version below.

I also updated the original author's code so that it runs with TensorFlow 2.x, allowing me to compare my results with the author's. The updated code is available in the folder original_code_updated.

│   AdvALSTM.py
│   preprocessing.py
│   replicate_result.py
├───data
│   └───stocknet-dataset
│       └───price
│           │   trading_dates.csv
│           ├───ourpped/
│           └───raw/
└───original_code_updated
        evaluator.py
        load.py
        pred_lstm.py
        __init__.py


Dependencies

  • TensorFlow : 2.9.2

How to use the Adv-ALSTM model

Installation

Download the AdvALSTM.py file and place it in your project folder, then import the model:

from AdvALSTM import AdvLSTM

Use

To create an AdvLSTM model, use:

model = AdvLSTM(
  units, 
  epsilon, 
  beta, 
  learning_rate = 1E-2, 
  dropout = None, 
  l2 = None, 
  attention = True, 
  hinge = True, 
  adversarial_training = True, 
  random_perturbations = False)

The AdvLSTM object is a subclass of tf.keras.Model, so you can train it as you would any TensorFlow 2 model:

model.fit(
  X_train, y_train, 
  validation_data = (X_validation, y_validation),
  epochs = 200, 
  batch_size = 1024
  )

The model only accepts:

  • y : binary class labels (0 or 1), even when using the Hinge loss, with shape

(nb_sequences, )

  • x : sequences of length T with n features, with shape

(nb_sequences, T, n)
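For instance, dummy inputs with valid shapes can be built like this (the sizes, including n = 11 features, are purely illustrative; the feature count depends on your preprocessing):

```python
import numpy as np

# Illustrative sizes only: 500 sequences, T = 10 time steps, n = 11 features
nb_sequences, T, n = 500, 10, 11

X = np.random.rand(nb_sequences, T, n).astype("float32")  # (nb_sequences, T, n)
y = np.random.randint(0, 2, size=(nb_sequences,))         # binary 0/1 labels

print(X.shape)  # (500, 10, 11)
print(y.shape)  # (500,)
```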

Documentation

class AdvALSTM.AdvLSTM(**params):
__init__(self, units, epsilon = 1E-3, beta = 5E-2, learning_rate = 1E-2, dropout = None, l2 = None, attention = True, hinge = True, adversarial_training = True, random_perturbations = False)
  • units : int (required)

    Specify the number of units of the layers (Dense, LSTM and Temporal Attention) contained in the model.

  • epsilon : float (optional, default : 1E-3)

  • beta : float (optional, default : 5E-2). If adversarial_training = True, epsilon and beta are used in the adversarial loss. Epsilon defines the l2 norm of the perturbation used to generate the adversarial example:

    $e_{adv}^{s} = e^{s} + r_{adv}^{s}, \quad r_{adv}^{s} = \epsilon \cdot \frac{g^{s}}{\|g^{s}\|_2}$

    Beta then weights the adversarial loss within the total loss:

    $\mathcal{L}_{total} = \mathcal{L}(y, \hat{y}) + \beta \, \mathcal{L}(y, \hat{y}_{adv})$
  • learning_rate : float (optional, default : 1E-2). Defines the learning rate used to initialize the Adam optimizer used for training.

  • dropout : float (optional, default : None). If set, enables dropout with the given rate.

  • l2 : float (optional, default : None). Defines the l2 regularization parameter; None disables l2 regularization.

  • attention : boolean (optional, default : True). If True, the model uses the TemporalAttention layer after the LSTM layer to generate the latent-space representation. If False, the model takes the last hidden state of the LSTM (with return_sequences = False).

  • hinge : boolean (optional, default : True). If True, the model uses the Hinge loss for training. If False, it uses the Binary Cross-Entropy loss.

  • adversarial_training : boolean (optional, default : True). If True, the model generates an adversarial loss from the adversarial example and adds it to the global loss. If False, the model trains without adversarial examples.

  • random_perturbations : boolean (optional, default : False). Defines how the perturbations are created. If False (default), the perturbations follow the paper's gradient direction: $r_{adv}^{s} = \epsilon \cdot \frac{g^{s}}{\|g^{s}\|_2}$ with $g^{s} = \nabla_{e^{s}} \mathcal{L}(y^{s}, \hat{y}^{s})$, computed with tape.gradient(loss(y, y_pred), e). If True, the perturbations are randomly generated instead of being gradient-oriented, with g drawn from tf.random.normal(...).

Model description

Attention Model schema

The Adversarial Attentive LSTM is built on an Attentive LSTM that generates a latent-space vector, used as a 1D representation of a 2D input sequence (here, the last T technical indicators of a given stock).

This Attentive LSTM uses a Temporal Attention layer that "summarizes" the hidden states of the LSTM according to the temporal importance learned by the network. The layer also keeps the last hidden state and appends it to the attentive output.
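A minimal numpy sketch of this summarization (the scoring vector w is a stand-in for the layer's learned weights; the real layer lives in AdvALSTM.py):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def temporal_attention(hidden_states, w):
    # hidden_states: (batch, T, units); w: (units,) illustrative scoring vector
    scores = hidden_states @ w                 # (batch, T) temporal scores
    alpha = softmax(scores, axis=1)            # temporal importance weights
    context = (alpha[..., None] * hidden_states).sum(axis=1)  # (batch, units)
    # Append the last hidden state to the attentive summary
    return np.concatenate([context, hidden_states[:, -1, :]], axis=-1)

h = np.random.rand(4, 10, 16)   # batch of 4, T = 10, 16 units
w = np.random.rand(16)
print(temporal_attention(h, w).shape)  # (4, 32)
```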

Adv-ASLTM schema

Following the Attentive LSTM, we obtain $e^{s}$, the latent-space representation of the input sequence.

We pass it through the classifier to get $\hat{y}^{s}$, which is then used to compute the first loss.

This first loss is differentiated with respect to $e^{s}$: the gradient gives the direction in which perturbations maximize the loss. We use this derivative to build $e_{adv}^{s}$, the adversarial example:

$e_{adv}^{s} = e^{s} + r_{adv}^{s}$

$r_{adv}^{s} = \epsilon \cdot \frac{g^{s}}{\|g^{s}\|_2}, \quad g^{s} = \nabla_{e^{s}} \mathcal{L}(y^{s}, \hat{y}^{s})$

This adversarial example is then passed through the classifier to produce a second loss (the adversarial loss), combined as below:

$\mathcal{L}_{total} = \mathcal{L}(y, \hat{y}) + \beta \, \mathcal{L}(y, \hat{y}_{adv})$

With $\beta$ used to weight the adversarial loss.
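The whole step can be sketched in plain numpy (a hand-derived gradient for an illustrative linear classifier w stands in for tape.gradient; the hinge loss and the ε, β defaults follow the model's):

```python
import numpy as np

def hinge_loss(y, score):
    # y in {0, 1}, mapped to {-1, +1} targets
    t = 2 * y - 1
    return np.maximum(0.0, 1.0 - t * score).mean()

def adversarial_step(e, y, w, epsilon=1e-3, beta=5e-2):
    # e: (batch, d) latent vectors; w: (d,) illustrative linear classifier
    t = 2 * y - 1
    score = e @ w
    clean_loss = hinge_loss(y, score)
    # Hand-derived gradient of the hinge loss w.r.t. e for this linear
    # classifier (the real model obtains it with tape.gradient)
    active = (1.0 - t * score) > 0
    g = -((t * active)[:, None] * w) / len(y)
    # r_adv = epsilon * g / ||g||  (per-example l2 normalization)
    norms = np.linalg.norm(g, axis=1, keepdims=True) + 1e-12
    e_adv = e + epsilon * g / norms
    adv_loss = hinge_loss(y, e_adv @ w)
    return clean_loss + beta * adv_loss  # total loss

rng = np.random.default_rng(0)
e = rng.standard_normal((8, 16))
y = rng.integers(0, 2, 8)
w = rng.standard_normal(16)
print(adversarial_step(e, y, w))
```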

adv-alstm's People

Contributors

clementperroud

