
header

Improving Transformer-XL for Music Generation

logo

This project was carried out by YAI 11th, in cooperation with POZAlabs.


Gmail NOTION REPORT



Improving Transformer-XL for Music Generation 🎼

YAI x POZAlabs industry-academia collaboration, Team 1
NLP model for music generation

Members 👋

조정빈 : YAI 9th / [email protected]
김민서 : YAI 8th / [email protected]
김산 : YAI 9th / [email protected]
김성준 : YAI 10th / [email protected]
박민수 : YAI 9th / [email protected]
박수빈 : YAI 9th / [email protected]



Getting Started 🔥

As the project contains different models and metrics, we recommend using a separate virtual environment for each. Each directory contains its own "Getting Started"; for clear instructions, please follow the links shown in each section.

Improving-TrXL-for-ComMU/
├─ CAS/
├─ Group_Encoding/
├─ Soft_Labeling/
├─ TransformerCVAE/

For the baseline, Transformer-XL trained on the ComMU dataset, refer to the ComMU-code by POZAlabs.

Building on Transformer-XL 🏗️

0. Baseline (Transformer-XL) - Link

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

Evaluation

Classification Accuracy Score

| Label (Meta) | Real Model | Fake Model | Error Rate |
|---|---|---|---|
| BPM | 0.6291 | 0.6159 | 0.0210 |
| KEY | 0.8781 | 0.8781 | 0 |
| TIME SIGNATURE | 0.9082 | 0.8925 | |
| PITCH RANGE | 0.7483 | 0.7090 | 0.0525 |
| NUMBER OF MEASURE | 1.0 | 1.0 | |
| INSTRUMENT | 0.5858 | 0.5923 | |
| GENRE | 0.8532 | 0.8427 | 0.0123 |
| MIN VELOCITY | 0.4718 | 0.4482 | |
| MAX VELOCITY | 0.4718 | 0.4495 | |
| TRACK ROLE | 0.6500 | 0.5753 | 0.1149 |
| RHYTHM | 0.9934 | 0.9934 | |

Normalized Mean CADS: 0.0401

1. Group Encoding - Link

A vanilla Transformer-XL model consumes tokens as a 1-D sequence and adds a positional encoding to give the model information about the positions of tokens. In this setting, the model must learn both the semantics of the data and the structure of the MIDI encoding. However, since encoding MIDI data into a token sequence follows an explicit pattern, we propose a Group Encoding method that injects an inductive bias about this explicit structure into the model. This not only keeps the model from predicting invalid tokens at invalid positions, it also allows the model to generate 4 tokens in a single forward pass, boosting both the training speed and the inference speed of the model.
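The grouping idea can be sketched as follows. This is a minimal illustration, assuming a ComMU-style note encoding in which each note is a fixed group of four tokens; the token values and helper names here are hypothetical and are not the repository's API:

```python
GROUP_SIZE = 4  # assumed: one note = 4 tokens (e.g. position, velocity, pitch, duration)

def group_tokens(seq, group_size=GROUP_SIZE):
    """Reshape a flat token sequence into fixed-size note groups."""
    assert len(seq) % group_size == 0, "sequence must be a whole number of groups"
    return [seq[i:i + group_size] for i in range(0, len(seq), group_size)]

def group_position_ids(seq, group_size=GROUP_SIZE):
    """Intra-group slot id (0..3) for every token; feeding this as an
    extra embedding index is one way to inject the structural bias."""
    return [i % group_size for i in range(len(seq))]

flat = [10, 64, 60, 8, 12, 64, 62, 8]  # two hypothetical notes
print(group_tokens(flat))        # [[10, 64, 60, 8], [12, 64, 62, 8]]
print(group_position_ids(flat))  # [0, 1, 2, 3, 0, 1, 2, 3]
```

Because every slot in a group has a known role, a model that predicts one whole group per step emits 4 tokens per forward pass instead of 1, which is where the speed-up below comes from.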

GE

GE_1

Evaluation

Controllability and Diversity

| | CP | CV (MIDI) | CV (Note) | CH | Diversity |
|---|---|---|---|---|---|
| Transformer-XL w/o GE | 0.8585 | 0.8060 | 0.9847 | 0.9891 | 0.4100 |
| Transformer-XL w/ GE | 0.8493 | 0.7391 | 0.9821 | 0.9839 | 0.4113 |

Classification Accuracy Score

| Label (Meta) | Real Model | Fake Model | Error Rate |
|---|---|---|---|
| BPM | 0.6291 | 0.5910 | 0.0606 |
| KEY | 0.8781 | 0.8532 | 0.0284 |
| TIME SIGNATURE | 0.9082 | 0.8951 | |
| PITCH RANGE | 0.7483 | 0.7195 | 0.0385 |
| NUMBER OF MEASURE | 1.0 | 1.0 | |
| INSTRUMENT | 0.5858 | 0.5884 | |
| GENRE | 0.8532 | 0.8532 | 0 |
| MIN VELOCITY | 0.4718 | 0.4364 | |
| MAX VELOCITY | 0.4718 | 0.4560 | |
| TRACK ROLE | 0.6500 | 0.5360 | 0.1754 |
| RHYTHM | 0.9934 | 0.9934 | |

Normalized Mean CADS: 0.0605

Inference Speed

| | Inference time for val set | Inference speed per sample | Relative speed-up |
|---|---|---|---|
| Transformer-XL w/o GE | 1189.4 s | 1.558 s per sample | x1 |
| Transformer-XL w/ GE | 692.2 s | 0.907 s per sample | x1.718 |

Sampled Audio

Five note sequences with shared and differing metadata were sampled under the following conditions and mixed together.

  • Shared metadata across the 5 samples

    • audio_key : aminor
    • chord_progressions : [['Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'Dm', 'Dm', 'Dm', 'Dm', 'Dm', 'Dm', 'Dm', 'Dm', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'D', 'D', 'D', 'D', 'D', 'D', 'D', 'D']]
    • time_signature : 4/4
    • genre : cinematic
    • bpm : 120
    • rhythm : standard
  • Different metadata for each instrument

    • riff string_violin mid_high standard
    • main_melody string_ensemble mid_high
    • sub_melody string_cello very_low
    • pad acoustic_piano mid_low
    • sub_melody brass_ensemble mid
sample.mov

2. Soft Labeling - Link

To prevent overfitting, techniques such as soft labeling are often used. We apply soft labeling to the velocity, duration, and position tokens so that they can be predicted more flexibly. For example, if the target token value is 300, the target distribution is reconstructed by also assigning mass to the 298/299/301/302 tokens. Using soft labeling, we confirm that the generated tokens vary more flexibly than the baseline's.
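The reconstruction can be sketched at the target-distribution level. This is a hypothetical helper, not the repository's implementation; the neighbor weights shown are one configuration from the test-set NLL experiments (0.1/0.2/0.4/0.2/0.1 over n-2..n+2):

```python
def soft_label(target, vocab_size, weights=(0.1, 0.2, 0.4, 0.2, 0.1)):
    """Build a soft target distribution that spreads probability mass
    over the neighboring token values target-2 .. target+2."""
    dist = [0.0] * vocab_size
    half = len(weights) // 2
    for offset, w in zip(range(-half, half + 1), weights):
        idx = target + offset
        if 0 <= idx < vocab_size:
            dist[idx] += w
    total = sum(dist)  # renormalize in case mass fell off the vocab edges
    return [p / total for p in dist]

dist = soft_label(300, vocab_size=600)
print(dist[298:303])  # mass spread around the hard target 300
```

Training against such a distribution (e.g. with a soft-target cross-entropy) penalizes near-miss predictions of velocity, duration, or position less than distant ones.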

softlabeling

Evaluation

Test set NLL

| n-2 | n-1 | n | n+1 | n+2 | Test NLL |
|---|---|---|---|---|---|
| 0 | 0 | 1 | 0 | 0 | 0.96 |
| 0.1 | 0.1 | 0.6 | 0.1 | 0.1 | 1.01 |
| 0 | 0.15 | 0.7 | 0.15 | 0 | 1.05 |
| 0.1 | 0.2 | 0.4 | 0.2 | 0.1 | 1.26 |

Classification Accuracy Score

| Label (Meta) | Real Model | Fake Model | Error Rate |
|---|---|---|---|
| BPM | 0.6291 | 0.6133 | 0.0251 |
| KEY | 0.8781 | 0.8741 | 0.0046 |
| TIME SIGNATURE | 0.9082 | 0.8990 | |
| PITCH RANGE | 0.7483 | 0.7195 | 0.0385 |
| NUMBER OF MEASURE | 1.0 | 1.0 | |
| INSTRUMENT | 0.5858 | 0.5740 | |
| GENRE | 0.8532 | 0.8440 | 0.0108 |
| MIN VELOCITY | 0.4718 | 0.4429 | |
| MAX VELOCITY | 0.4718 | 0.4429 | |
| TRACK ROLE | 0.6500 | 0.5661 | 0.1291 |
| RHYTHM | 0.9934 | 0.9934 | |

Normalized Mean CADS: 0.0416

Controllability and Diversity

| | CP | CV (MIDI) | CV (Note) | CH | Diversity |
|---|---|---|---|---|---|
| Transformer-XL w/o SL | 0.8585 | 0.8060 | 0.9847 | 0.9891 | 0.4100 |
| Transformer-XL w/ SL | 0.8807 | 0.8007 | 0.9861 | 0.9891 | 0.4134 |

3. Gated Transformer-XL - Link


Dataset 🎶

ComMU (POZAlabs)

ComMU-code has clear instructions on how to download and postprocess the ComMU dataset, but we also provide a postprocessed dataset for simplicity. To download the preprocessed and postprocessed data, run

cd ./dataset && ./download.sh && cd ..

Metrics 📋

To evaluate generation models, we must generate data with the trained models, and the generation process differs depending on the metric. Please refer to the explanations below to generate the samples needed for each evaluation.

Generating Samples for Evaluation

For CAS we generate samples based on the training metadata, and for Diversity & Controllability we generate samples based on the validation metadata.

For Transformer-XL with GE, use

Improving-TrXL-for-ComMU/
├─ Group_Encoding/
    ├─ generate_GE.py

For the Transformer-XL baseline and SL, use

Improving-TrXL-for-ComMU/
├─ generate_SL.py

Note that generate_SL.py should be placed inside ComMU-code, since SL does not change the model structure or inference mechanism.

Classification Accuracy Score - Link

Evaluating generative models is an open problem, and for music generation it has not been well defined. Inspired by 'Classification Accuracy Score for Conditional Generative Models', we use CAS as an evaluation metric for our music generation models. The procedure of our CAS is the following:

  1. Train a music generation model on the ComMU train set, which we call the 'Real Dataset'
  2. Generate samples with it to form a dataset, which we call the 'Fake Dataset'
  3. Train a classification model on the 'Real Dataset', which we call the 'Real Model'
  4. Train a classification model on the 'Fake Dataset', which we call the 'Fake Model'
  5. For each label (metadata), compare the performance of the 'Fake Model' and the 'Real Model' on the ComMU validation set

From the above procedure we obtain a CAS for each label (metadata) we want to evaluate. If the difference between the accuracy of the 'Fake Model' and the 'Real Model' is small, our generation model has captured the data distribution w.r.t. that label well. For our experiments on the vanilla Transformer-XL, Transformer-XL with GE, and Transformer-XL with SL, we calculate CAS on all 11 labels. However, some labels, such as Number of Measure, Time Signature, and Rhythm, are unsuited for evaluation. Therefore we select BPM, KEY, PITCH RANGE, GENRE, and TRACK ROLE and calculate the Normalized Mean Classification Accuracy Difference Score, denoted CADS. We obtain CADS as follows.

$$\mathrm{CADS} = \frac{1}{N}\sum_{i=1}^{N}\frac{R_i - F_i}{R_i}$$

where N is the number of labels (metadata) we consider relevant, in this case 5, and $R_i$ and $F_i$ denote the Real Model accuracy and the Fake Model accuracy for label i, respectively.
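As a sanity check, the score can be reproduced from the baseline CAS table; the function below is just the formula applied to the five selected labels, not project code:

```python
def cads(real, fake):
    """Normalized mean Classification Accuracy Difference Score:
    mean over labels of (R_i - F_i) / R_i."""
    return sum((r - f) / r for r, f in zip(real, fake)) / len(real)

# Baseline Transformer-XL accuracies for the selected labels:
# BPM, KEY, PITCH RANGE, GENRE, TRACK ROLE (from the CAS table above)
real = [0.6291, 0.8781, 0.7483, 0.8532, 0.6500]
fake = [0.6159, 0.8781, 0.7090, 0.8427, 0.5753]
print(round(cads(real, fake), 4))  # 0.0401, matching the reported value
```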

The following figure shows the overall pipeline of CAS.

  • To compute the Classification Accuracy Score of generated music conditioned on certain metadata

To generate samples for SL and the baseline, run

$ python generate_SL.py --checkpoint_dir {./model_checkpoint} --meta_data {./train_meta_data.csv} --eval_diversity {False} --out_dir {./train_out}

For GE, run

$ python generate_GE.py --checkpoint_dir {./checkpoint_best.pt} --meta_data {./train_meta_data.csv} --eval_diversity {False} --out_dir {./train_out}

To compute CAS for certain metadata as the label, run

$ python evaluate_resnet.py --midi_dir {./data.npy} --meta_dir {./meta.npy} --meta_num {meta_num}

To compute CAS for all metadata as labels, run

$ python evaluate_resnet_all.py --midi_dir {./data.npy} --meta_dir {./meta.npy}

Diversity & Controllability

  • To compute the Diversity of generated music conditioned on certain metadata

To generate samples for SL and the baseline, run

$ python generate_SL.py --checkpoint_dir {./model_checkpoint} --meta_data {./val_meta_data.csv} --eval_diversity {True} --out_dir {./val_out}

For GE, run

$ python generate_GE.py --checkpoint_dir {./checkpoint_best.pt} --meta_data {./val_meta_data.csv} --eval_diversity {True} --out_dir {./val_out}

First, modify eval_config.py. Then:

To compute Diversity, run

$ python ./commu_eval/commu_eval_diversity.py

To compute Controllability, run

$ python ./commu_eval/commu_eval_controllability.py

Skills

Frameworks

Citations

@misc{dai2019transformerxl,
      title={Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context}, 
      author={Zihang Dai and Zhilin Yang and Yiming Yang and Jaime Carbonell and Quoc V. Le and Ruslan Salakhutdinov},
      year={2019},
      eprint={1901.02860},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{https://doi.org/10.48550/arxiv.1905.10887,
  doi = {10.48550/ARXIV.1905.10887},
  url = {https://arxiv.org/abs/1905.10887},
  author = {Ravuri, Suman and Vinyals, Oriol},
  keywords = {Machine Learning (cs.LG), Machine Learning (stat.ML), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {Classification Accuracy Score for Conditional Generative Models},
  publisher = {arXiv},
  year = {2019},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
@inproceedings{hyun2022commu,
  title={Com{MU}: Dataset for Combinatorial Music Generation},
  author={Lee Hyun and Taehyun Kim and Hyolim Kang and Minjoo Ki and Hyeonchan Hwang and Kwanho Park and Sharang Han and Seon Joo Kim},
  booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
  year={2022},
}
