

pyGRU4REC


Notice on the potential bug (06/12/2018)


Update (05/20/2018)

  • PyTorch 0.4.0 support
    • The code is now compatible with PyTorch >= 0.4.0
  • Code cleanup
    • Removed redundant code thanks to the simpler API of PyTorch 0.4.0
    • Improved the readability of the previously confusing RNN update routine
    • Improved the readability of the training/testing routine
  • Optimization
    • Testing code is now much faster than before

Environment

  • PyTorch 0.4.0
  • Python 3.6.4
  • pandas 0.22.0
  • numpy 1.14.0

Usage

Training / Test Set Specifications

  • Filenames
    • The training set should be named train.tsv
    • The test set should be named test.tsv
  • File Paths
    • train.tsv and test.tsv should be located under the data directory, i.e. data/train.tsv and data/test.tsv
  • Contents
    • train.tsv and test.tsv should be TSV files (without headers) storing pandas DataFrames that satisfy the following requirements:
      • The 1st column of the tsv file should be the integer Session IDs
      • The 2nd column of the tsv file should be the integer Item IDs
      • The 3rd column of the tsv file should be the Timestamps
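As an illustration of the expected file layout, here is a minimal pandas sketch that writes and reads a split in the required header-less, three-column TSV format. The column names SessionId, ItemId and Timestamp and the helper names write_split/read_split are hypothetical (the files themselves carry no header), and the toy log is made-up data.

```python
import os

import pandas as pd

# Hypothetical column names; the TSV files themselves have no header row.
COLUMNS = ["SessionId", "ItemId", "Timestamp"]


def write_split(df: pd.DataFrame, path: str) -> None:
    """Write a session log as a header-less TSV: SessionId, ItemId, Timestamp."""
    # Keep only the three required columns, in the required order.
    df[COLUMNS].to_csv(path, sep="\t", header=False, index=False)


def read_split(path: str) -> pd.DataFrame:
    """Read a header-less TSV back into a DataFrame with named columns."""
    return pd.read_csv(path, sep="\t", header=None, names=COLUMNS,
                       dtype={"SessionId": "int64", "ItemId": "int64"})


if __name__ == "__main__":
    os.makedirs("data", exist_ok=True)
    # Toy log (made-up data): two sessions, integer IDs, Unix timestamps.
    log = pd.DataFrame({
        "SessionId": [1, 1, 1, 2, 2],
        "ItemId": [10, 11, 12, 10, 13],
        "Timestamp": [1396321232.0, 1396321275.0, 1396321301.0,
                      1396414300.0, 1396414388.0],
    })
    write_split(log, os.path.join("data", "train.tsv"))
    print(read_split(os.path.join("data", "train.tsv")).shape)  # (5, 3)
```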

Training/Testing using Jupyter Notebook

See example.ipynb for a complete Jupyter notebook that

  1. Loads the data
  2. Trains & tests a GRU4REC model
  3. Loads & tests a pretrained GRU4REC model
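This README does not state which metrics the test step reports. The GRU4REC paper evaluates next-item prediction with Recall@20 and MRR@20, so, purely as an illustration and not as this repository's evaluation code, here is a NumPy sketch of those two metrics for one batch of scores. The function name recall_mrr_at_k is hypothetical, and ties are resolved in the target's favor because only strictly higher scores push its rank down.

```python
import numpy as np


def recall_mrr_at_k(scores: np.ndarray, targets: np.ndarray, k: int = 20):
    """Recall@k and MRR@k for a batch of next-item predictions.

    scores:  (batch, n_items) predicted scores, higher is better.
    targets: (batch,) index of the true next item for each row.
    """
    # Rank of the target = 1 + number of items scored strictly higher.
    target_scores = scores[np.arange(len(targets)), targets]
    ranks = 1 + (scores > target_scores[:, None]).sum(axis=1)
    hit = ranks <= k
    recall = hit.mean()
    # The reciprocal rank only counts when the target lands in the top k.
    mrr = np.where(hit, 1.0 / ranks, 0.0).mean()
    return float(recall), float(mrr)


scores = np.array([[0.1, 0.9, 0.3],
                   [0.8, 0.2, 0.5]])
targets = np.array([1, 2])
print(recall_mrr_at_k(scores, targets, k=1))  # (0.5, 0.5)
```

Averaging these per-batch values over the whole test set (weighted by batch size) gives dataset-level numbers.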

Training & Testing using run_train.py

  • Before using run_train.py, I highly recommend that you take a look at example.ipynb to see how the implementation works in general.
  • The default parameters are the same as those used for the TOP1 loss in the GRU4REC paper.
  • The intermediate model created at each training epoch will be stored in models/ unless specified otherwise.
  • The log is written to logs/train.out by redirecting standard output:
$ python run_train.py > logs/train.out
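Since TOP1 is the default loss, a NumPy sketch of the TOP1 and BPR losses as defined in the GRU4REC paper may help. It assumes the paper's mini-batch negative sampling: for a batch of B sessions, the B×B score matrix holds each session's positive score on the diagonal and the other sessions' target items serve as negatives. The function names are hypothetical and this is not the repository's PyTorch code; the mini-batch must contain at least 2 sessions.

```python
import numpy as np


def _sigmoid(x):
    """Logistic function 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + np.exp(-x))


def _positives_and_mask(logits):
    """Return the positive scores and a mask selecting the in-batch negatives.

    logits[a, b] is the score session a gives to the target item of
    session b, so the diagonal holds each session's positive score r_i
    and every off-diagonal entry is a negative score r_j.
    """
    pos = np.diag(logits)[:, None]                   # shape (B, 1)
    neg_mask = ~np.eye(logits.shape[0], dtype=bool)  # True off the diagonal
    return pos, neg_mask


def bpr_loss(logits):
    """BPR: mean over negatives of -log sigmoid(r_i - r_j)."""
    pos, neg_mask = _positives_and_mask(logits)
    # log(1 + exp(-(r_i - r_j))) equals -log sigmoid(r_i - r_j), computed stably.
    per_pair = np.logaddexp(0.0, -(pos - logits))
    per_session = (per_pair * neg_mask).sum(axis=1) / neg_mask.sum(axis=1)
    return float(per_session.mean())


def top1_loss(logits):
    """TOP1: mean over negatives of sigmoid(r_j - r_i) + sigmoid(r_j ** 2)."""
    pos, neg_mask = _positives_and_mask(logits)
    # The first term approximates the positive's relative rank; the second
    # regularizes negative scores toward zero.
    per_pair = _sigmoid(logits - pos) + _sigmoid(logits ** 2)
    per_session = (per_pair * neg_mask).sum(axis=1) / neg_mask.sum(axis=1)
    return float(per_session.mean())


if __name__ == "__main__":
    logits = np.array([[2.0, 0.5, -1.0],
                       [0.0, 1.5, 0.3],
                       [0.2, -0.4, 1.0]])
    print(bpr_loss(logits), top1_loss(logits))
```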

Args:
    --loss_type: Loss function type. One of 'TOP1', 'BPR', or 'CrossEntropy'. (Default: 'TOP1')
    --model_name: Prefix for the intermediate models saved during training. (Default: 'GRU4REC')
    --hidden_size: Dimension of the GRU hidden layer. (Default: 100)
    --num_layers: Number of GRU layers. (Default: 1)
    --batch_size: Training batch size. (Default: 50)
    --dropout_input: Dropout probability for the GRU input layer. (Default: 0)
    --dropout_hidden: Dropout probability for the GRU hidden layer. (Default: .5)
    --optimizer_type: Optimizer type. One of 'Adagrad', 'RMSProp', 'Adadelta', 'Adam', or 'SGD'. (Default: 'Adagrad')
    --lr: Learning rate for the optimizer. (Default: 0.01)
    --weight_decay: Weight decay for the optimizer. (Default: 0)
    --momentum: Momentum for the optimizer. (Default: 0)
    --eps: eps parameter for the optimizer. (Default: 1e-6)
    --n_epochs: Number of training epochs to run. (Default: 10)
    --time_sort: Whether to sort the sessions in the dataset in time order. (Default: False)
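For reference, a sketch of an argparse parser exposing the flags and defaults listed above. It mirrors the list but is not copied from run_train.py; build_parser is a hypothetical name, and treating --time_sort as a store_true switch is an assumption about how the real script parses it.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser mirroring the documented run_train.py flags (sketch only)."""
    p = argparse.ArgumentParser(description="Train a GRU4REC model (sketch).")
    p.add_argument("--loss_type", default="TOP1",
                   choices=["TOP1", "BPR", "CrossEntropy"])
    p.add_argument("--model_name", default="GRU4REC")
    p.add_argument("--hidden_size", type=int, default=100)
    p.add_argument("--num_layers", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=50)
    p.add_argument("--dropout_input", type=float, default=0.0)
    p.add_argument("--dropout_hidden", type=float, default=0.5)
    p.add_argument("--optimizer_type", default="Adagrad",
                   choices=["Adagrad", "RMSProp", "Adadelta", "Adam", "SGD"])
    p.add_argument("--lr", type=float, default=0.01)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--momentum", type=float, default=0.0)
    p.add_argument("--eps", type=float, default=1e-6)
    p.add_argument("--n_epochs", type=int, default=10)
    # Assumption: an on/off switch; the real script may parse this differently.
    p.add_argument("--time_sort", action="store_true")
    return p


if __name__ == "__main__":
    args = build_parser().parse_args(["--loss_type", "BPR", "--n_epochs", "3"])
    print(args.loss_type, args.n_epochs, args.lr)  # BPR 3 0.01
```

Using `choices` makes a typo such as `--loss_type top1` fail immediately with a usage message instead of surfacing later during training.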

Reproducing the results of the original paper

  • This PyTorch implementation gives slightly better results than the original code, which was written in Theano. I suspect this comes from differences between Theano and PyTorch, and from the fact that dropout has no effect in my single-layer PyTorch GRU (PyTorch's nn.GRU applies its dropout only between stacked layers).
  • The results were reproduced within only 2 or 3 epochs, whereas the original Theano implementation runs for 10 epochs by default. To reproduce them, run:
$ bash run_train.sh
