AttentionXML

AttentionXML: Label Tree-based Attention-Aware Deep Model for High-Performance Extreme Multi-Label Text Classification

Requirements

  • python==3.7.4
  • click==7.0
  • ruamel.yaml==0.16.5
  • numpy==1.16.2
  • scipy==1.3.1
  • scikit-learn==0.21.2
  • gensim==3.4.0
  • torch==1.0.1
  • nltk==3.4
  • tqdm==4.31.1
  • joblib==0.13.2
  • logzero==1.5.0

Datasets

Download the GloVe embeddings (840B tokens, 300d) and convert them to gensim format, which can be loaded with gensim.models.KeyedVectors.load.

We also provide a converted GloVe embedding here.
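The 840B GloVe release is a plain-text file with one token per line followed by 300 space-separated floats. A few of its tokens contain spaces themselves, so a reader that treats the first field as the token and the rest as the vector misreads those lines. The sketch below assumes the standard GloVe text layout. parse_glove is a hypothetical helper that is not part of this repo or of gensim. The words and vectors it returns still need to be stored as a gensim KeyedVectors object, and the exact call for that differs between gensim versions.

```python
import numpy as np


def parse_glove(lines, dim=300):
    """Parse GloVe text lines into a list of tokens and an (n, dim) float32 matrix."""
    words, vectors = [], []
    for line in lines:
        parts = line.rstrip().split(" ")
        if len(parts) < dim + 1:
            continue  # blank or malformed line
        # The last `dim` fields are the vector; everything before them is the token,
        # re-joined with spaces because some tokens contain spaces.
        words.append(" ".join(parts[:-dim]))
        vectors.append(np.asarray(parts[-dim:], dtype=np.float32))
    return words, np.stack(vectors)
```

For example, `with open("glove.840B.300d.txt", encoding="utf-8") as f: words, vectors = parse_glove(f)` reads the whole file into memory. The file name here is an assumption.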

XML Experiments

The XML experiments in the paper can be run directly, for example:

./scripts/run_eurlex.sh

Preprocess

For datasets whose texts are already tokenized, run preprocess.py on the train and test sets as follows:

python preprocess.py \
--text-path data/EUR-Lex/train_texts.txt \
--label-path data/EUR-Lex/train_labels.txt \
--vocab-path data/EUR-Lex/vocab.npy \
--emb-path data/EUR-Lex/emb_init.npy \
--w2v-model data/glove.840B.300d.gensim

python preprocess.py \
--text-path data/EUR-Lex/test_texts.txt \
--label-path data/EUR-Lex/test_labels.txt \
--vocab-path data/EUR-Lex/vocab.npy 
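The test-set command reads vocab.npy but takes no --emb-path or --w2v-model. This suggests that vocab.npy and emb_init.npy are outputs of the training-set run, not files shipped with the dataset. The sketch below shows what such a step could look like. It is not the code in preprocess.py, and it rests on several assumptions: special tokens <PAD> (id 0) and <UNK> (id 1), a plain dict standing in for the gensim word vectors, random initialisation for words without a pretrained vector, and the hypothetical helper names build_vocab and to_ids.

```python
from collections import Counter

import numpy as np

PAD, UNK = "<PAD>", "<UNK>"  # hypothetical special tokens with ids 0 and 1


def build_vocab(train_texts, word_vectors, max_size=500000, seed=0):
    """Build a vocabulary and an embedding-initialisation matrix from training texts.

    train_texts: tokenized texts (tokens separated by spaces).
    word_vectors: dict mapping token -> pretrained vector (assumed non-empty).
    """
    rng = np.random.default_rng(seed)
    counts = Counter(tok for text in train_texts for tok in text.split())
    dim = len(next(iter(word_vectors.values())))
    vocab = [PAD, UNK]
    rows = [np.zeros(dim), rng.uniform(-1.0, 1.0, dim)]
    # Most frequent training tokens first; the vocabulary is capped at max_size entries.
    for word, _ in counts.most_common(max_size - len(vocab)):
        vocab.append(word)
        if word in word_vectors:
            rows.append(np.asarray(word_vectors[word], dtype=float))
        else:
            rows.append(rng.uniform(-1.0, 1.0, dim))  # no pretrained vector: random init
    return np.array(vocab), np.stack(rows).astype(np.float32)


def to_ids(text, vocab, max_len=500):
    """Map a tokenized text to max_len ids: UNK for unseen tokens, PAD at the end."""
    index = {word: i for i, word in enumerate(vocab)}
    ids = [index.get(tok, 1) for tok in text.split()][:max_len]
    return ids + [0] * (max_len - len(ids))
```

Saving the results with np.save would produce files like vocab.npy and emb_init.npy. The test texts would then be mapped with the saved vocabulary rather than a rebuilt one, so the ids stay consistent.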

Or run preprocess.py so that it also tokenizes the raw texts with NLTK:

python preprocess.py \
--text-path data/Wiki10-31K/train_raw_texts.txt \
--tokenized-path data/Wiki10-31K/train_texts.txt \
--label-path data/Wiki10-31K/train_labels.txt \
--vocab-path data/Wiki10-31K/vocab.npy \
--emb-path data/Wiki10-31K/emb_init.npy \
--w2v-model data/glove.840B.300d.gensim

python preprocess.py \
--text-path data/Wiki10-31K/test_raw_texts.txt \
--tokenized-path data/Wiki10-31K/test_texts.txt \
--label-path data/Wiki10-31K/test_labels.txt \
--vocab-path data/Wiki10-31K/vocab.npy 

Train and Predict

Train and predict as follows:

python main.py --data-cnf configure/datasets/EUR-Lex.yaml --model-cnf configure/models/AttentionXML-EUR-Lex.yaml 

To run prediction only, add the option "--mode eval".

Ensemble

Train and predict with an ensemble:

python main.py --data-cnf configure/datasets/Wiki-500K.yaml --model-cnf configure/models/FastAttentionXML-Wiki-500K.yaml -t 0
python main.py --data-cnf configure/datasets/Wiki-500K.yaml --model-cnf configure/models/FastAttentionXML-Wiki-500K.yaml -t 1
python main.py --data-cnf configure/datasets/Wiki-500K.yaml --model-cnf configure/models/FastAttentionXML-Wiki-500K.yaml -t 2
python ensemble.py -p results/FastAttentionXML-Wiki-500K -t 3
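ensemble.py presumably merges the predictions of the three models trained with -t 0, -t 1 and -t 2. One simple merging rule is to sum each label's score across models, divide by the number of models, and keep the top k. This rule is an assumption, not taken from ensemble.py. ensemble_top_k is a hypothetical helper that applies it to a single instance.

```python
from collections import defaultdict


def ensemble_top_k(per_model_labels, per_model_scores, k):
    """Merge ranked predictions of several models for one instance.

    per_model_labels[m] and per_model_scores[m] are the ranked label ids and scores
    of model m. A label missing from a model's list contributes 0 for that model.
    """
    total = defaultdict(float)
    for labels, scores in zip(per_model_labels, per_model_scores):
        for label, score in zip(labels, scores):
            total[label] += score
    n_models = len(per_model_labels)
    ranked = sorted(total.items(), key=lambda item: item[1], reverse=True)[:k]
    return [label for label, _ in ranked], [score / n_models for _, score in ranked]
```

Averaging rewards labels that several models rank highly. A label predicted with a high score by only one model is diluted by the other models' implicit zeros.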

Evaluation

python evaluation.py --results results/AttentionXML-EUR-Lex-labels.npy --targets data/EUR-Lex/test_labels.npy
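The training log calls get_p_5 and get_n_5 from deepxml/evaluation.py, which points to the usual XML ranking metrics, precision@k and nDCG@k. The numpy sketch below shows both, using binary relevance and the 1/log2(rank+1) discount. It is not the code in evaluation.py, and it makes these assumptions: predictions are an integer array of label ids ranked best first with at least k columns, targets are a dense 0/1 matrix, and the helper names are hypothetical.

```python
import numpy as np


def _hits(top_labels, test_y, k):
    """1 where the i-th ranked prediction of an instance is a true label, else 0."""
    rows = np.arange(test_y.shape[0])[:, None]
    return test_y[rows, top_labels[:, :k]]


def precision_at_k(top_labels, test_y, k=5):
    """Mean over instances of (number of true labels among the top k) / k."""
    return _hits(top_labels, test_y, k).sum(axis=1).mean() / k


def ndcg_at_k(top_labels, test_y, k=5):
    """Mean nDCG@k with binary relevance and the 1 / log2(rank + 1) discount."""
    discounts = 1.0 / np.log2(np.arange(2, k + 2))  # ranks 1..k
    dcg = (_hits(top_labels, test_y, k) * discounts).sum(axis=1)
    # Ideal DCG: all min(k, number of true labels) relevant labels ranked first.
    ideal_table = np.concatenate([[0.0], np.cumsum(discounts)])
    idcg = ideal_table[np.minimum(test_y.sum(axis=1), k)]
    return (dcg / np.maximum(idcg, 1e-12)).mean()
```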

To compute propensity-scored metrics as well:

python evaluation.py \
--results results/FastAttentionXML-Amazon-670K-labels.npy \
--targets data/Amazon-670K/test_labels.npy \
--train-labels data/Amazon-670K/train_labels.npy \
-a 0.6 \
-b 2.6
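The -a and -b options presumably set the parameters A and B of the propensity model commonly used in XML evaluation (Jain et al., 2016). In that model, a label l that appears in N_l of the N training instances has inverse propensity 1/p_l = 1 + C (N_l + B)^(-A), with C = (ln N - 1)(B + 1)^A. PSP@k weights each correct label in the top k by 1/p_l and is often normalised by the best score achievable on each instance. The numpy sketch below follows those definitions. It is not the repo's evaluation.py: the defaults A = 0.55 and B = 1.5, the dense 0/1 label matrices, and the helper names are all assumptions.

```python
import numpy as np


def inv_propensity(train_y, a=0.55, b=1.5):
    """Inverse propensity 1/p_l for every label, from a dense 0/1 training label matrix."""
    n = train_y.shape[0]
    freq = train_y.sum(axis=0)  # N_l: training instances tagged with label l
    c = (np.log(n) - 1.0) * (b + 1.0) ** a
    return 1.0 + c * (freq + b) ** (-a)


def psp_at_k(top_labels, test_y, inv_w, k=5):
    """Propensity-scored precision@k, normalised by the best achievable value."""
    top = top_labels[:, :k]
    rows = np.arange(test_y.shape[0])[:, None]
    num = (test_y[rows, top] * inv_w[top]).sum()  # sum of y_l / p_l over predicted labels
    # Best case per instance: its k true labels with the largest inverse propensity.
    den = np.sort(test_y * inv_w, axis=1)[:, ::-1][:, :k].sum()
    return num / den
```

Rare labels get larger weights, so PSP@k rewards models that retrieve tail labels rather than only the frequent ones.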

Reference

You et al., AttentionXML: Label Tree-based Attention-Aware Deep Model for High-Performance Extreme Multi-Label Text Classification, NeurIPS 2019

Declaration

It is free for non-commercial use. For commercial use, please contact Mr. Ronghui You and Prof. Shanfeng Zhu ([email protected]).

attentionxml's Issues

What's the difference between AttentionXML and fastAttentionXML?

Hi yourh!

I found two models in this repo, AttentionXML and fastAttentionXML, and the latter does not seem to appear in your paper. I also can't find where the PLT module is used in AttentionXML. Could you tell me the difference between these two models, and where the PLT module is in AttentionXML?

Thank you very much!

AttentionXML in production

Hi,

I have a question about using the AttentionXML model in production. Do we have to use the same tokenizer that was used for training in the POC, or can we create another tokenizer and embedding matrix in production?

Thank you in advance

How to do clustering on my own dataset

Hi!
I want to do clustering on my own dataset,
but I got
ValueError: could not convert string to float: b'vckbee'
This may be because my dataset only has raw_text.
How can I generate an svmlight_file from raw text?
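The error suggests that the clustering step received words where numeric feature:value pairs were expected, which is consistent with svmlight-style input being parsed. One way to turn raw texts into such a file uses scikit-learn's TfidfVectorizer and dump_svmlight_file, as sketched below. This sketch rests on several assumptions: TF-IDF bag-of-words features are acceptable (the repo's own features in train_v1.txt may be built differently), the labels are already integer ids, and write_svmlight is a hypothetical helper.

```python
from sklearn.datasets import dump_svmlight_file
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer


def write_svmlight(raw_texts, label_ids, n_labels, f, vectorizer=None):
    """Write TF-IDF bag-of-words features plus label ids to f in svmlight format.

    raw_texts: list of strings; label_ids: list of lists of int ids in [0, n_labels).
    f: path or binary file object. For the test set, pass the vectorizer returned
    for the training set so both files share the same feature indices.
    """
    if vectorizer is None:
        vectorizer = TfidfVectorizer()
        features = vectorizer.fit_transform(raw_texts)
    else:
        features = vectorizer.transform(raw_texts)
    # Sparse 0/1 indicator matrix of shape (n_texts, n_labels).
    labels = MultiLabelBinarizer(classes=range(n_labels), sparse_output=True).fit_transform(label_ids)
    # Each line becomes "label,label,... index:value index:value ...".
    dump_svmlight_file(features, labels, f, multilabel=True)
    return vectorizer
```

Whether the repo's clustering code expects exactly this layout, with no header line for example, is an assumption to check against its data files.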

Training Error on Amazon-670k

Hello,

After hours of training on Amazon-670K, I am getting the following error:

[I 230302 08:33:14 tree:145] Finish Training Level-1
[I 230302 08:33:14 tree:149] Generating Candidates for Level-2, Number of Labels: 16384, Top: 160
Candidates:   1%|          | 5013/459301 [00:00>
Parents: 12587it [00:00, 31>
  File "main.py", line 98, in <module>
    main()
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/click/core.py", line 764, in __call__
    return self.main(*args, **kwargs)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "main.py", line 70, in main
    model.train(train_x, train_y, valid_x, valid_y, mlb)
  File "/home/celso/projects/AttentionXML/deepxml/tree.py", line 200, in train
    self.train_level(self.level - 1, train_x, train_y, valid_x, valid_y)
  File "/home/celso/projects/AttentionXML/deepxml/tree.py", line 86, in train_level
    train_group_y, train_group, valid_group = self.train_level(level - 1, train_x, train_y, valid_x, valid_y)
  File "/home/celso/projects/AttentionXML/deepxml/tree.py", line 132, in train_level
    model = XMLModel(network=FastAttentionRNN, labels_num=labels_num, emb_init=self.emb_init,
  File "/home/celso/projects/AttentionXML/deepxml/models.py", line 145, in __init__
    self.attn_weights = AttentionWeights(labels_num, hidden_size*2, attn_device_ids)
  File "/home/celso/projects/AttentionXML/deepxml/modules.py", line 88, in __init__
    group_size, plus_num = labels_num // len(device_ids), labels_num % len(device_ids)
ZeroDivisionError: integer division or modulo by zero
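The failing line computes `labels_num // len(device_ids)` and `labels_num % len(device_ids)`. This presumably splits the label attention weights across the GPUs listed in attn_device_ids, so a ZeroDivisionError means that list was empty, for example when fewer GPUs are visible than the configuration expects. The sketch below reproduces that kind of contiguous split with an explicit check, so the cause surfaces as a clear message. It assumes that the first labels_num % len(device_ids) devices receive one extra label. split_labels_across_devices is a hypothetical helper, not the repo's AttentionWeights code.

```python
def split_labels_across_devices(labels_num, device_ids):
    """Return one contiguous (start, end) label range per device."""
    if not device_ids:
        # An empty list is exactly what makes `labels_num // len(device_ids)` fail.
        raise ValueError("device_ids is empty: at least one GPU id is needed for the label attention")
    group_size, plus_num = divmod(labels_num, len(device_ids))
    ranges, start = [], 0
    for i in range(len(device_ids)):
        # The first `plus_num` devices take one extra label each.
        end = start + group_size + (1 if i < plus_num else 0)
        ranges.append((start, end))
        start = end
    return ranges
```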

An error when training FastAttentionXML

args=(self.data_cnf['train']['sparse'], self.data_cnf['train']['labels'], mlb),
KeyError: 'sparse'

There is no information about the train sparse features or train labels in the /FastAttentionXML-Wiki-500K.yaml file.
Thanks for any help.

Questions about training details

Hi, thanks for your works.

I have a question: what batch size and number of epochs did you use for Amazon-670K and Wiki-500K with the AttentionXML network without PLT?

Thanks for your answer :)

Troubleshooting with Amazon-670K

I find that running the scripts directly on EUR-Lex or Amazon-670K raises segmentation faults. I suspect the multiprocessing part is related to this issue, so I slightly changed run_xml.sh as follows:

[Screenshot 2020-05-25, 2:43:17 PM]

With this change it finally works on EUR-Lex, but I still get errors on Amazon-670K when clustering labels. The errors were as follows:

[Screenshot 2020-05-25, 2:44:54 PM]

My server is equipped with 2080 Ti GPUs. Is this error simply caused by GPU memory limits, or can I adapt the code further to fix it?

Any advice?

Regarding execution time on Amazon-670k dataset

I ran the code on the Amazon-670K dataset without making any changes to the code or configuration files, but training has taken more than 10 hours and is still not complete.

[Screenshot]

My GPU details are given below:

[Screenshot: GPU details]

Can you confirm whether this training time is expected?

Preprocess not working

Hello,

I am trying to run preprocess.py on the provided Wiki10-31K dataset. However, I am facing the following error:

Traceback (most recent call last):
  File "/content/AttentionXML/preprocess.py", line 17, in <module>
    from deepxml.data_utils import *
  File "/content/AttentionXML/deepxml/data_utils.py", line 15, in <module>
    from gensim.models import KeyedVectors
  File "/usr/local/lib/python3.7/dist-packages/gensim/__init__.py", line 5, in <module>
    from gensim import parsing, corpora, matutils, interfaces, models, similarities, summarization, utils  # noqa:F401
  File "/usr/local/lib/python3.7/dist-packages/gensim/corpora/__init__.py", line 6, in <module>
    from .indexedcorpus import IndexedCorpus  # noqa:F401 must appear before the other classes
  File "/usr/local/lib/python3.7/dist-packages/gensim/corpora/indexedcorpus.py", line 15, in <module>
    from gensim import interfaces, utils
  File "/usr/local/lib/python3.7/dist-packages/gensim/interfaces.py", line 19, in <module>
    from gensim import utils, matutils
  File "/usr/local/lib/python3.7/dist-packages/gensim/matutils.py", line 1054, in <module>
    from gensim._matutils import logsumexp, mean_absolute_difference, dirichlet_expectation
  File "__init__.pxd", line 198, in init gensim._matutils
ValueError: numpy.ndarray has the wrong size, try recompiling. Expected 80, got 88

You can reproduce the error in this Colab Notebook.

.npy file

Sir,
During preprocessing, where do the files given by
--vocab-path data/EUR-Lex/vocab.npy
--emb-path data/EUR-Lex/emb_init.npy
come from? The downloaded dataset does not contain these two files.
Please advise.

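About the PSP@K evaluation metric in the paper

Hello!
Could you provide the code for the PSP@K metric? I would like to run experiments with this metric, but I don't know how to reproduce that part of the code.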

License for code

Hello!
We are planning to reuse the code as baselines for our research. Would it be possible for you to add a license (e.g. MIT or BSD) to the codebase?
Thanks!

AttentionXML on the Amazon-670k

After the "Finish Clustering" log message, the process seems to be doing nothing. Only RAM is allocated, the processor's cores usage is near zero, and GPU is not allocated yet.
What am I missing?

Time Complexity

Hello,

Could you provide (at a high level) the time complexity of AttentionXML for training and prediction?

You can use $C_L$ for the cost of the BiLSTM forward pass and abstract away other complex terms.
Ideally, the final formula would depend mostly on $N$ (the number of text instances) and $L$ (the number of labels).

I appreciate any help you can provide.

CUDA ERROR

Hi sir,
[Screenshot from 2020-09-18 02-16-37]
When I ran the code in Colab the first time, it worked,
but when I ran it again without any changes, it started showing an error.
Please advise.

Some questions about dataset wiki10-31k and wiki-500k

Excellent work!

  1. I know that the XMLRepository has these two datasets, but they cannot be downloaded now. Can I get the same datasets through the dataset link in your README.md?
  2. Can I use your preprocessing script and GloVe embedding for these two datasets?
  3. I have emailed you, thanks for replying!
    Thank you!!!

Cuda Error : RuntimeError: CUDNN_STATUS_EXECUTION_FAILED

Hi,
I got the error RuntimeError: CUDNN_STATUS_EXECUTION_FAILED when trying to train level 1. It is raised at:
loss = self.train_step(train_x, train_y.cuda()) (line 70 of models.py)
When I change this line to loss = self.train_step(train_x.cuda(), train_y.cuda()), I still get other issues!

RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

Even with the required environment set up, I am facing multiple errors, such as:

[I 230221 09:56:33 main:37] Model Name: AttentionXML
[I 230221 09:56:33 main:40] Loading Training and Validation Set
[I 230221 09:56:33 main:52] Number of Labels: 29801
[I 230221 09:56:33 main:53] Size of Training Set: 14748
[I 230221 09:56:33 main:54] Size of Validation Set: 200
[I 230221 09:56:33 main:56] Training
Traceback (most recent call last):
  File "main.py", line 95, in <module>
    main()
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/click/core.py", line 764, in __call__
    return self.main(*args, **kwargs)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "main.py", line 64, in main
    model.train(train_loader, valid_loader, **model_cnf['train'])
  File "/home/celso/projects/AttentionXML/deepxml/models.py", line 67, in train
    loss = self.train_step(train_x, train_y.cuda())
  File "/home/celso/projects/AttentionXML/deepxml/models.py", line 42, in train_step
    scores = self.model(train_x)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/celso/projects/AttentionXML/deepxml/networks.py", line 42, in forward
    rnn_out = self.lstm(emb_out, lengths)   # N, L, hidden_size * 2
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/celso/projects/AttentionXML/deepxml/modules.py", line 60, in forward
    self.lstm(packed_inputs, (hidden_init, cell_init))[0], batch_first=True)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/celso/projects/venvs/AttentionXML/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 561, in forward
    result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

Are you sure that all tensor operations happen on the same device?

where is PLT compression performed?

I've taken a close look in the code when running experiments on Amazon-670k.

When FastAttentionXML is being trained, it seems that:

  1. first, labels are partitioned using 2-means clustering in deepxml/cluster.py;
  2. second, level-wise training is conducted recursively, mainly in deepxml/tree.py and in models.py and networks.py;

I see nowhere the PLT compression operation is involved in the pipeline. Maybe I missed something?

Regarding max_leaf hyperparameter for Amazon-670K dataset

Is there an assumption that the max_leaf hyperparameter in the configuration file will never be set to 1? When I run with max_leaf = 1, an assert statement in cluster.py fails:

assert sum(len(labels) for labels in labels_list) == labels_f.shape[0]

It is inside the build_tree_by_level method. Could you also explain the assert statement above and why it is necessary?
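The quoted assert appears to check that, at one level of the label tree, the clusters together contain every label exactly once. In other words, the cluster sizes must add up to the number of rows of labels_f, presumably one feature vector per label. The numpy sketch below shows that invariant inside a recursive balanced 2-means partition, similar to the 2-means label partitioning mentioned in the previous issue. It is a simplified stand-in, not the repo's cluster.py. It assumes one dense feature vector per label and a stopping rule of "split until no cluster exceeds max_leaf", and the names balanced_2means and build_levels are hypothetical.

```python
import numpy as np


def balanced_2means(features, label_ids, iters=10, seed=0):
    """Split label_ids (at least 2) into two halves using spherical balanced 2-means."""
    rng = np.random.default_rng(seed)
    x = features[label_ids]
    x = x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), 1e-12)
    centers = x[rng.choice(len(label_ids), 2, replace=False)]
    half = len(label_ids) // 2
    for _ in range(iters):
        sim = x @ centers.T                        # cosine similarity to each center
        order = np.argsort(sim[:, 1] - sim[:, 0])  # labels closest to center 0 come first
        left, right = order[:half], order[half:]   # equal-size halves keep the tree balanced
        centers = np.stack([x[left].mean(axis=0), x[right].mean(axis=0)])
        centers /= np.maximum(np.linalg.norm(centers, axis=1, keepdims=True), 1e-12)
    return label_ids[left], label_ids[right]


def build_levels(features, max_leaf, seed=0):
    """Recursively halve label clusters until none has more than max_leaf (>= 1) labels."""
    n_labels = features.shape[0]
    groups = [np.arange(n_labels)]
    levels = [groups]
    while any(len(g) > max_leaf for g in groups):
        next_groups = []
        for g in groups:
            if len(g) > max_leaf:
                next_groups.extend(balanced_2means(features, g, seed=seed))
            else:
                next_groups.append(g)
        # Analogue of the quoted assert: the clusters of a level cover every label.
        assert sum(len(g) for g in next_groups) == n_labels
        groups = next_groups
        levels.append(groups)
    return levels
```

The splits are disjoint by construction, so the size check is what guards against labels being dropped or duplicated. In this sketch, max_leaf = 1 simply produces one label per leaf.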

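Cannot load the 'glove.840B.300d.gensim' model

I'm glad to see this work. When I tried to run preprocess.py from your code, it failed immediately with "Aborted". After debugging, I found that this is probably because the glove.840B.300d.gensim model cannot be loaded. Is this still being maintained?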

What is the content of the train_v1.txt file?

What is the content of the train_v1.txt file, and how can I create a train_v1.txt file for my own dataset?
It seems that train_v1.txt contains sparse matrices X and Y, where X holds the bag-of-words features of each instance. But why is its vocabulary size different from that of vocab.npy?

Error when MultiLabelBinarizer get CSR matrix

I am using scipy==1.11.2 and trying to train on the Amazon-670K dataset in the Colab notebook.
Can you please help me find a way around this error?

[I 230818 11:55:13 main:37] Model Name: AttentionXML
[I 230818 11:55:13 main:40] Loading Training and Validation Set
[I 230818 11:55:13 main:52] Number of Labels: 34399
[I 230818 11:55:13 main:53] Size of Training Set: 880
[I 230818 11:55:13 main:54] Size of Validation Set: 120
[I 230818 11:55:13 main:56] Training
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
/content/drive/MyDrive/xmlc_research/attention_xml/deepxml/optimizers.py:108: UserWarning: This overload of add_ is deprecated:
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1485.)
  exp_avg.mul_(beta1).add_(1 - beta1, grad)
[I 230818 11:58:03 models:114] SWA Initializing
Traceback (most recent call last):
  File "/content/drive/MyDrive/xmlc_research/attention_xml/main.py", line 95, in <module>
    main()
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 764, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/content/drive/MyDrive/xmlc_research/attention_xml/main.py", line 64, in main
    model.train(train_loader, valid_loader, **model_cnf['train'])
  File "/content/drive/MyDrive/xmlc_research/attention_xml/deepxml/models.py", line 76, in train
    p5, n5 = get_p_5(labels, targets), get_n_5(labels, targets)
  File "/content/drive/MyDrive/xmlc_research/attention_xml/deepxml/evaluation.py", line 42, in get_precision
    mlb = get_mlb(classes, mlb, targets)
  File "/content/drive/MyDrive/xmlc_research/attention_xml/deepxml/evaluation.py", line 33, in get_mlb
    mlb = MultiLabelBinarizer(range(targets.shape[1]), sparse_output=True)
TypeError: MultiLabelBinarizer.__init__() takes 1 positional argument but 2 positional arguments (and 1 keyword-only argument) were given
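The TypeError comes from deepxml/evaluation.py, line 33, which passes range(targets.shape[1]) to MultiLabelBinarizer positionally. Newer scikit-learn releases accept the constructor arguments by keyword only, while the pinned scikit-learn==0.21.2 still accepted them positionally. The sketch below shows the keyword form. make_label_binarizer is a hypothetical simplified stand-in that takes only the number of labels; it is not the repo's get_mlb.

```python
from sklearn.preprocessing import MultiLabelBinarizer


def make_label_binarizer(n_labels):
    """Binarizer over label ids 0..n_labels-1 that outputs a sparse indicator matrix."""
    # `classes` must be passed by name in newer scikit-learn releases.
    return MultiLabelBinarizer(classes=range(n_labels), sparse_output=True)
```

In evaluation.py the equivalent one-line change is `MultiLabelBinarizer(classes=range(targets.shape[1]), sparse_output=True)`. Alternatively, installing scikit-learn==0.21.2 as listed under Requirements should keep the original call working.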

Error when running the command

Hello! When I run "python main.py --data-cnf configure/datasets/EUR-Lex.yaml --model-cnf configure/models/AttentionXML-EUR-Lex.yaml", I get the following error:
/home/jupyter-chanchiuhung/AttentionXML/deepxml/optimizers.py:108: UserWarning: This overload of add_ is deprecated:
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1050.)
  exp_avg.mul_(beta1).add_(1 - beta1, grad)
Traceback (most recent call last):
  File "/home/jupyter-chanchiuhung/AttentionXML/main.py", line 95, in <module>
    main()
  File "/home/jupyter-chanchiuhung/.local/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/home/jupyter-chanchiuhung/.local/lib/python3.9/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/home/jupyter-chanchiuhung/.local/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/jupyter-chanchiuhung/.local/lib/python3.9/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/home/jupyter-chanchiuhung/AttentionXML/main.py", line 64, in main
    model.train(train_loader, valid_loader, **model_cnf['train'])
  File "/home/jupyter-chanchiuhung/AttentionXML/deepxml/models.py", line 73, in train
    p5, n5 = get_p_5(labels, targets), get_n_5(labels, targets)
  File "/home/jupyter-chanchiuhung/AttentionXML/deepxml/evaluation.py", line 42, in get_precision
    mlb = get_mlb(classes, mlb, targets)
  File "/home/jupyter-chanchiuhung/AttentionXML/deepxml/evaluation.py", line 33, in get_mlb
    mlb = MultiLabelBinarizer(range(targets.shape[1]), sparse_output=True)
TypeError: __init__() takes 1 positional argument but 2 positional arguments (and 1 keyword-only argument) were given
Why does this happen?

Transformers as encoder

Hi, will you release the code where BERT is the replacement encoder? It was mentioned in a previous issue. Thanks!
