bert4rec's Introduction

BERT4Rec

Usage

Requirements

  • Python 2.7+
  • TensorFlow 1.12 (GPU version)
  • CUDA compatible with TF 1.12

Run

For simplicity, here we take ml-1m as an example:

./run_ml-1m.sh

The script runs two commands. First, it generates the masked training data:

python -u gen_data_fin.py \
    --dataset_name=${dataset_name} \
    --max_seq_length=${max_seq_length} \
    --max_predictions_per_seq=${max_predictions_per_seq} \
    --mask_prob=${mask_prob} \
    --dupe_factor=${dupe_factor} \
    --masked_lm_prob=${masked_lm_prob} \
    --prop_sliding_window=${prop_sliding_window} \
    --signature=${signature} \
    --pool_size=${pool_size}

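For intuition, here is a simplified sketch of the cloze-style masking these flags control (an illustration only, not the actual gen_data_fin.py code; the function and token names below are made up). Each user sequence is duplicated dupe_factor times, and in each copy a fraction masked_lm_prob of the items, capped at max_predictions_per_seq, is replaced by a mask token that the model must recover.

import random

MASK_TOKEN = "[MASK]"  # assumed mask symbol, for illustration

def mask_sequence(items, masked_lm_prob, max_predictions_per_seq, rng):
    """Replace a random subset of items with MASK_TOKEN and return the
    masked sequence plus (position, original_item) prediction labels."""
    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(items) * masked_lm_prob))))
    positions = rng.sample(range(len(items)), num_to_predict)
    masked = list(items)
    labels = []
    for pos in sorted(positions):
        labels.append((pos, masked[pos]))
        masked[pos] = MASK_TOKEN
    return masked, labels

rng = random.Random(12345)
print(mask_sequence(["i1", "i2", "i3", "i4", "i5"], 0.2, 40, rng))
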
Then it trains the model:

CUDA_VISIBLE_DEVICES=0 python -u run.py \
    --train_input_file=./data/${dataset_name}${signature}.train.tfrecord \
    --test_input_file=./data/${dataset_name}${signature}.test.tfrecord \
    --vocab_filename=./data/${dataset_name}${signature}.vocab \
    --user_history_filename=./data/${dataset_name}${signature}.his \
    --checkpointDir=${CKPT_DIR}/${dataset_name} \
    --signature=${signature}-${dim} \
    --do_train=True \
    --do_eval=True \
    --bert_config_file=./bert_train/bert_config_${dataset_name}_${dim}.json \
    --batch_size=${batch_size} \
    --max_seq_length=${max_seq_length} \
    --max_predictions_per_seq=${max_predictions_per_seq} \
    --num_train_steps=${num_train_steps} \
    --num_warmup_steps=100 \
    --learning_rate=1e-4

Hyper-parameter settings

The model hyper-parameters are specified by a JSON file under bert_train, e.g. bert_config_ml-1m_64.json:

{
  "attention_probs_dropout_prob": 0.2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.2,
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_position_embeddings": 200,
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "type_vocab_size": 2,
  "vocab_size": 3420
}
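
The fields follow the original BERT config format, with hidden_size matching the ${dim} in the file name and vocab_size covering all items plus special tokens. A minimal sketch of loading and sanity-checking such a config with the standard json module (the checks are illustrative assumptions, not validation the repo performs):

import json

# Load the example config shown above.
with open("bert_train/bert_config_ml-1m_64.json") as f:
    config = json.load(f)

# hidden_size must divide evenly across the attention heads, and
# max_position_embeddings must cover max_seq_length (200 for ml-1m).
assert config["hidden_size"] % config["num_attention_heads"] == 0
assert config["max_position_embeddings"] >= 200
print(config["vocab_size"])  # items plus special tokens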

Reference

@inproceedings{Sun:2019:BSR:3357384.3357895,
 author = {Sun, Fei and Liu, Jun and Wu, Jian and Pei, Changhua and Lin, Xiao and Ou, Wenwu and Jiang, Peng},
 title = {BERT4Rec: Sequential Recommendation with Bidirectional Encoder Representations from Transformer},
 booktitle = {Proceedings of the 28th ACM International Conference on Information and Knowledge Management},
 series = {CIKM '19},
 year = {2019},
 isbn = {978-1-4503-6976-3},
 location = {Beijing, China},
 pages = {1441--1450},
 numpages = {10},
 url = {http://doi.acm.org/10.1145/3357384.3357895},
 doi = {10.1145/3357384.3357895},
 acmid = {3357895},
 publisher = {ACM},
 address = {New York, NY, USA}
} 

bert4rec's Issues

Code isn't available

The paper referred me to this repo, but I don't see any code in it. Can you update the repo? Thank you

How to use user features (such as age, demographics) to construct the sequential recommendation system?

I want to train a next-procedure prediction model for a medical dataset. After checking the data-processing pipeline, what I have figured out is that the emphasis is on each user's sequence of item interactions, together with the timestamps in the loaded data; user features are not considered.
But as far as I know, user features and item features also play a role, and in my case user features are critical. How can I use such features to design the sequential recommendation system? Thank you!

TF 2.0

Can you upgrade your code to TensorFlow 2.0?

Learning rate

I changed the learning rate to 1e-3 (1e-4 in your paper) and found that convergence is much faster. Similar observations appear in other reproduction repositories, e.g., cr_1 and cr_2.

PyTorch Implementation

Hi, I'm wondering if you also have a PyTorch implementation of the BERT4REC available. Thank you very much!

Why is prediction so slow, and how can it be fixed?

When we use bert4rec in our scenario, prediction is very slow.
Is the cause that the softmax is taken over all items?
As a cloze task this seems unavoidable; is there any way around it?
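
One common mitigation, sketched below as a suggestion rather than anything this repo implements, is to score only a candidate subset of items at inference time instead of running the softmax over the whole vocabulary; when only the relative order of candidates matters, the dot products alone are enough.

import numpy as np

def score_candidates(hidden, item_emb, candidate_ids):
    """Rank a candidate subset by dot product with the final [MASK]
    hidden state, skipping the full-vocabulary softmax.
    hidden: (hidden_size,); item_emb: (vocab_size, hidden_size)."""
    cand_emb = item_emb[candidate_ids]   # gather candidate embeddings only
    logits = cand_emb @ hidden           # (num_candidates,)
    order = np.argsort(-logits)          # highest score first
    return [candidate_ids[i] for i in order]

# Toy example with random weights (shapes match the ml-1m config above).
hidden = np.random.randn(64)
item_emb = np.random.randn(3420, 64)
print(score_candidates(hidden, item_emb, [5, 17, 99, 2048]))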

Cannot reproduce the results

I used the original code and the ml-1m data listed in the repo, but the loss stopped dropping after about 10,000 iterations and plateaued at around 5.12. Is there any preprocessing I need to do before training? Thanks!

pytorch implementation

I'm wondering if you also have a PyTorch implementation of the BERT4REC available. Thank you very much!

Why are the NDCG results not stored in the eval_results.txt file?

After running run_beauty.sh, I got eval_results.txt, but it only includes masked_lm_accuracy and masked_lm_loss.
I found that NDCG and the other evaluation metrics are only output to the info log.
Could you please tell me how to store these evaluation metrics in the results txt file as well?
The relevant source code is as follows:
#tf.logging.info('special eval ops:', special_eval_ops)
result = estimator.evaluate(
    input_fn=eval_input_fn,
    steps=None,
    hooks=[EvalHooks()])

output_eval_file = os.path.join(FLAGS.checkpointDir,
                                "eval_results.txt")
with tf.gfile.GFile(output_eval_file, "w") as writer:
    tf.logging.info("***** Eval results *****")
    tf.logging.info(bert_config.to_json_string())
    writer.write(bert_config.to_json_string() + '\n')
    for key in sorted(result.keys()):
        tf.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))
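
The ranking metrics are computed inside EvalHooks and only logged, so they never appear in the dictionary returned by estimator.evaluate(). One possible workaround is to have the hook persist its own metrics in end(); the sketch below assumes hypothetical accumulator attributes (ndcg_10, hit_10, valid_user), and the real names inside EvalHooks may differ.

import os
import tensorflow as tf

class EvalHooksWithFile(tf.train.SessionRunHook):
    """Variant of EvalHooks that also writes its metrics to a file.
    The accumulator attributes here are assumptions for illustration."""

    def __init__(self, output_dir):
        self.output_dir = output_dir
        self.ndcg_10 = 0.0
        self.hit_10 = 0.0
        self.valid_user = 0.0

    def end(self, session):
        path = os.path.join(self.output_dir, "ranking_results.txt")
        with tf.gfile.GFile(path, "w") as writer:
            denom = max(self.valid_user, 1.0)
            writer.write("ndcg@10 = %s\n" % (self.ndcg_10 / denom))
            writer.write("hit@10 = %s\n" % (self.hit_10 / denom))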

Data preprocessing?

Hello. Thanks for uploading the code.

But can you also upload the code you used to preprocess the data?

I would like to reproduce your results from raw data.

Thank you.

Data generation not deterministic

masked_lm_prob, max_predictions_per_seq, vocab, random.Random(random.randint(1,10000)),

Your data generation code is not deterministic, which makes it difficult to reproduce your results.

As shown in the referenced code, create_instances_threading() receives random.Random(random.randint(1, 10000)) as rng, which makes the output non-deterministic.

Please reply.
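
A minimal fix sketch: seed the generator deterministically, for example from a constant or a new --random_seed flag (no such flag exists in the current code, so this is an assumption):

import random

RANDOM_SEED = 12345  # hypothetical fixed seed / new flag value

# Deterministic replacement for random.Random(random.randint(1, 10000)):
rng = random.Random(RANDOM_SEED)
# ...then pass `rng` into create_instances_threading() as before.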
