
contrastive-htc's Introduction

Incorporating Hierarchy into Text Encoder: a Contrastive Learning Approach for Hierarchical Text Classification

This repository implements a contrastive learning model for hierarchical text classification. This work has been accepted as the long paper "Incorporating Hierarchy into Text Encoder: a Contrastive Learning Approach for Hierarchical Text Classification" at ACL 2022.

Requirements

  • Python >= 3.6
  • torch >= 1.6.0
  • transformers == 4.2.1
  • fairseq == 0.10.0
  • torch-geometric == 1.7.2
  • torch-scatter == 2.0.8
  • torch-sparse == 0.6.12

Preprocess

Please download the original datasets and then use the scripts below.

WebOfScience

The original dataset can be acquired from the repository of HDLTex. The preprocessing code is adapted from the repository of HiAGM, and we provide a copy of it here. Please save the Excel data file Data.xlsx in WebOfScience/Meta-data as Data.txt.
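
If you prefer to do the conversion programmatically, the following sketch can be run from data/WebOfScience. It assumes pandas and openpyxl are installed and that preprocess_wos.py expects a tab-separated plain-text export of the spreadsheet; check the script for the exact columns it parses.

import pandas as pd

# Hypothetical helper: export the Excel sheet to a tab-separated Data.txt
# so that preprocess_wos.py can read it as plain text.
df = pd.read_excel('Meta-data/Data.xlsx')  # reading .xlsx requires openpyxl
df.to_csv('Meta-data/Data.txt', sep='\t', index=False, encoding='utf-8')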

cd ./data/WebOfScience
python preprocess_wos.py
python data_wos.py

NYT

The original dataset can be acquired here.

cd ./data/nyt
python data_nyt.py

RCV1-V2

The preprocessing code is adapted from the repository of reuters_loader, and we provide a copy here. The original dataset can be acquired here by signing an agreement.

cd ./data/rcv1
python preprocess_rcv1.py
python data_rcv1.py

Train

usage: train.py [-h] [--lr LR] [--data {WebOfScience,nyt,rcv1}] [--batch BATCH] [--early-stop EARLY_STOP] [--device DEVICE] --name NAME [--update UPDATE] [--warmup WARMUP] [--contrast CONTRAST] [--graph GRAPH] [--layer LAYER]
                [--multi] [--lamb LAMB] [--thre THRE] [--tau TAU] [--seed SEED] [--wandb]

optional arguments:
  -h, --help            show this help message and exit
  --lr LR               Learning rate.
  --data {WebOfScience,nyt,rcv1}
                        Dataset.
  --batch BATCH         Batch size.
  --early-stop EARLY_STOP
                        Epoch before early stop.
  --device DEVICE       cuda or cpu. Default: cuda
  --name NAME           A name for different runs.
  --update UPDATE       Gradient accumulation steps.
  --warmup WARMUP       Warmup steps.
  --contrast CONTRAST   Whether to use the contrastive model. Default: True
  --graph GRAPH         Whether to use the graph encoder. Default: True
  --layer LAYER         Number of Graphormer layers.
  --multi               Whether the task is multi-label classification. Should be kept at its
                        default since all datasets are multi-label. Default: True
  --lamb LAMB           lambda
  --thre THRE           Threshold for keeping tokens. Denote as gamma in the paper.
  --tau TAU             Temperature for contrastive model.
  --seed SEED           Random seed.
  --wandb               Use wandb for logging.

Checkpoints are in ./checkpoints/DATA-NAME. Two checkpoints are kept based on macro-F1 and micro-F1 respectively (checkpoint_best_macro.pt, checkpoint_best_micro.pt).

e.g. Train on WebOfScience with batch=12, lambda=0.05, gamma=0.02. Checkpoints will be in checkpoints/WebOfScience-test/.

python train.py --name test --batch 12 --data WebOfScience --lamb 0.05 --thre 0.02

Reproducibility

Contrastive learning is sensitive to hyper-parameters. We report results with a fixed random seed, but we observe higher results when the seed is not fixed.

  • The results reported in the main table can be reproduced with the following settings under seed=3 (example commands are given below).
WOS: lambda 0.05 thre 0.02
NYT: lambda 0.3 thre 0.002
RCV1: lambda 0.3 thre 0.001
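
For reference, these settings correspond to commands of the following form (a sketch: the run name is arbitrary and all other flags are left at their defaults):

python train.py --name reproduce --data WebOfScience --lamb 0.05 --thre 0.02 --seed 3
python train.py --name reproduce --data nyt --lamb 0.3 --thre 0.002 --seed 3
python train.py --name reproduce --data rcv1 --lamb 0.3 --thre 0.001 --seed 3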

We experiment on a GeForce RTX 3090 (24 GB) with CUDA 11.2.

  • The following settings can achieve higher results with an unfixed seed (these are the results reported in the paper).
WOS: lambda 0.1 thre 0.02
NYT: lambda 0.3 thre 0.005
RCV1: lambda 0.3 thre 0.005
  • We also find that a higher tau (e.g. tau=2) is beneficial, but we keep it at 1 for simplicity.

Test

usage: test.py [-h] [--device DEVICE] [--batch BATCH] --name NAME [--extra {_macro,_micro}]

optional arguments:
  -h, --help            show this help message and exit
  --device DEVICE
  --batch BATCH         Batch size.
  --name NAME           Name of the checkpoint directory, commonly DATA-NAME.
  --extra {_macro,_micro}
                        An extra string in the name of checkpoint. Default: _macro

Use --extra _macro or --extra _micro to choose between checkpoint_best_macro.pt and checkpoint_best_micro.pt, respectively.

e.g. Test the model trained in the previous example.

python test.py --name WebOfScience-test
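
To evaluate the micro-F1 checkpoint instead:

python test.py --name WebOfScience-test --extra _micro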

Citation

@inproceedings{wang-etal-2022-incorporating,
    title = "Incorporating Hierarchy into Text Encoder: a Contrastive Learning Approach for Hierarchical Text Classification",
    author = "Wang, Zihan  and
      Wang, Peiyi  and
      Huang, Lianzhe  and
      Sun, Xin  and
      Wang, Houfeng",
    booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = may,
    year = "2022",
    address = "Dublin, Ireland",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2022.acl-long.491",
    pages = "7109--7119",
}

contrastive-htc's People

Contributors

wzh9969


contrastive-htc's Issues

Why does the contrastive learning module add a ReLU layer?

As the title says: after taking BERT's [CLS] output as the text representation, why does the contrastive learning module add an extra non-linear ReLU layer? What is its purpose, and what would be the effect of computing the contrastive loss directly on BERT's [CLS] output?

Python dependencies

Hey, I would like to reproduce your results; could you provide a proper requirements.txt? I have spent a day trying to set up a working conda environment, but I always run into conflicts.

Error when running data_wos.py

Hello, I set up the environment according to requirements.txt, but running data_wos.py raises an error (see the screenshot below). Searching online did not help; do you know how to fix it?
[screenshot]

Problem in running train.py

A TypeError: 'NoneType' object is not subscriptable appears when I try to run train.py. Can you suggest possible reasons for this issue? The error occurs in the first line of __getitem__(self, item) of the class BertDataset(Dataset).

[screenshot]

Loss decreases but the evaluation metrics stay at 0

As the title says. I installed the dependencies inside Docker and trained on the WOS dataset with
python train.py --name test --batch 12 --data WebOfScience --lamb 0.1 --thre 0.02
During training, the loss gradually drops from 1.8 to 0.8, but macro-F1 and micro-F1 hover around 0. Has anyone encountered this?

Segmentation fault

On Ubuntu 10.04 with Python 3.7.13 and the other Python packages at the versions in requirements.txt, I keep getting a segmentation fault (core dumped). What could be the cause? This is urgent!

How to generate slot.pt?

Thanks for providing the code, but I encountered the following error when training:

[screenshot]

How to train on WOS-5736

It looks like the code trains on the whole Excel file (WOS-46985). Is there a utility or standard way I should approach this? I am thinking of writing some logic to filter entries from the Excel file using the mapping for WOS-5736 provided in the dataset; although it doesn't list indexes, the data points are split into multiple files.

Question about the RCV1 data processing

Hello, the RCV1 processing code refers to a file that cannot be found:

with open('lyrl2004_tokens_train.dat', 'r') as f:
    for line in f.readlines():
        if line.startswith('.I'):
            train_ids.append(int(line[3:-1]))

How can the lyrl2004_tokens_train.dat file used in the code above be obtained?

Are these errors expected during training?

  • Some weights of the model checkpoint at bert-base-uncased were not used when initializing ContrastModel:
  • Some weights of ContrastModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized:
(venv) ubuntu@ip-172-31-46-151:~/contrastive-htc((HEAD detached at b5bbe1a))$ python train.py --name test4 --batch 8 --data WebOfScience --lamb 0.05 --thre 0.02 --early-stop 4 --tau=1.5

Namespace(lr=3e-05, data='WebOfScience', batch=8, early_stop=4, device='cuda', name='test4', update=1, warmup=2000, contrast=1, graph=1, layer=1, multi=True, lamb=0.05, thre=0.02, tau=1.5, seed=3, wandb=False)

2023-03-04 21:54:01 | INFO | fairseq.data.data_utils | loaded 46985 examples from: data/WebOfScience/tok
2023-03-04 21:54:01 | INFO | fairseq.data.data_utils | loaded 46985 examples from: data/WebOfScience/Y

Some weights of the model checkpoint at bert-base-uncased were not used when initializing ContrastModel: ['bert.pooler.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'bert.pooler.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing ContrastModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ContrastModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

Some weights of ContrastModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['graph_encoder.edge_list', 'graph_encoder.hir_layers.0.cross_attn.v_proj.bias', 'contrastive_lossfct.transform.1.weight', 'graph_encoder.label_name', 'graph_encoder.distance_mat', 'graph_encoder.distance_embedding.weight', 'token_classifier.out_proj.weight', 'graph_encoder.hir_layers.0.hir_attn.self.out_proj.bias', 'graph_encoder.hir_layers.0.output_layer.0.bias', 'graph_encoder.hir_layers.0.hir_attn.self.v_proj.weight', 'graph_encoder.hir_layers.0.hir_attn.self.q_proj.bias', 'graph_encoder.hir_layers.0.output_layer.2.bias', 'graph_encoder.hir_layers.0.classifier.weight', 'graph_encoder.hir_layers.0.output_layer_norm.bias', 'graph_encoder.hir_layers.0.hir_attn.self.q_proj.weight', 'contrastive_lossfct.transform.4.bias', 'classifier.bias', 'token_classifier.out_proj.bias', 'graph_encoder.hir_layers.0.cross_attn.q_proj.bias', 'graph_encoder.hir_layers.0.cross_attn.v_proj.weight', 'contrastive_lossfct.transform.1.bias', 'graph_encoder.hir_layers.0.cross_layer_norm.bias', 'graph_encoder.edge_embedding.weight', 'graph_encoder.hir_layers.0.hir_attn.self.out_proj.weight', 'graph_encoder.hir_layers.0.hir_attn.self.k_proj.weight', 'graph_encoder.hir_layers.0.cross_attn.out_proj.weight', 'graph_encoder.hir_layers.0.hir_attn.self.k_proj.bias', 'graph_encoder.hir_layers.0.cross_attn.k_proj.bias', 'graph_encoder.hir_layers.0.hir_attn.layer_norm.weight', 'graph_encoder.hir_layers.0.output_layer.2.weight', 'graph_encoder.label_id', 'graph_encoder.hir_layers.0.classifier.bias', 'graph_encoder.id_embedding.weight', 'graph_encoder.hir_layers.0.hir_attn.layer_norm.bias', 'graph_encoder.hir_layers.0.cross_attn.out_proj.bias', 'graph_encoder.hir_layers.0.hir_attn.self.v_proj.bias', 'graph_encoder.hir_layers.0.cross_attn.q_proj.weight', 'graph_encoder.hir_layers.0.cross_attn.k_proj.weight', 'token_classifier.dense.bias', 'classifier.weight', 'contrastive_lossfct.transform.4.weight', 'graph_encoder.edge_mat', 'graph_encoder.hir_layers.0.output_layer_norm.weight', 'token_classifier.dense.weight', 'graph_encoder.hir_layers.0.output_layer.0.weight', 'graph_encoder.hir_layers.0.cross_layer_norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

datas

Hello, could you please give me the NYT and RCV1 datasets? I can't get them from the website. I promise that I will only use these datasets for research.

Reproduced RCV1 results differ noticeably

Hi, are your hyper-parameter settings the same as in the paper? When reproducing the RCV1 results I am about 1.5 points off; is there any setting that differs?

Error when running preprocess_wos.py

Running preprocess_wos.py fails with FileNotFoundError: [Errno 2] No such file or directory: 'Meta-data/Data.txt'. What is the reason?

The attention_mask in BartAttention does not seem to be converted

When BartAttention computes attention, the attention_mask is not converted to -inf, which effectively just adds 1 at every position. The BERT model in the transformers library converts the 0/1 mask matrix into torch's minimum value, but this step does not seem to appear in this code; when I actually run it, all token_probs come out identical.
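
For reference, the conversion the issue refers to, as done for BERT in the transformers library, looks roughly like the following self-contained sketch (not code from this repository):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])        # 1 = real token, 0 = padding
dtype = torch.float32
extended = attention_mask[:, None, None, :].to(dtype)   # shape (batch, 1, 1, seq_len)
extended = (1.0 - extended) * torch.finfo(dtype).min    # 1 -> 0, 0 -> very large negative
# `extended` is added to the raw attention scores before softmax,
# so padded positions receive (near) zero attention weight.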

evaluate() function

Thank you for producing a great paper and making the codebase public; it is very inspiring.

I had a question regarding the evaluate() function in eval.py. The paper focuses on datasets with a hierarchy, but during evaluation the hierarchy does not seem to be taken into account: in the end a flat classification is performed, a probability vector (dimension = number of unique classes) is produced, and the evaluate() function then builds right_count_list by thresholding these probabilities without considering the hierarchy. Also, why are the final scores computed purely from list counts rather than by comparing predictions and ground truth element-wise?

I am also worried about scenarios where the threshold is chosen to be very small (~0), in which case right_count_list would exactly match predicted_count_list.

I believe one way to handle this would be to evaluate level-wise: first pick the class with the highest probability at level 0, then, among the children of that class, pick the class with the highest probability at level 1, and so on.
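
A minimal sketch of that level-wise idea (hypothetical names: probs is the flat per-label probability vector and children maps each label id to its child ids; neither comes from this repository's code):

def top_down_decode(probs, root_ids, children):
    # Greedy top-down decoding: at each level keep only the most probable
    # label among the children of the previously chosen label.
    path = []
    candidates = list(root_ids)
    while candidates:
        best = max(candidates, key=lambda i: probs[i])
        path.append(best)
        candidates = list(children.get(best, []))
    return path

# Example: probs indexed by label id, roots [0, 1], label 1 has children [2, 3].
print(top_down_decode([0.2, 0.8, 0.1, 0.7], root_ids=[0, 1], children={1: [2, 3]}))  # -> [1, 3]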

Please share your thoughts and feedback.

Training time

Hello, the network structure seems quite complex. For the WOS dataset, roughly how many epochs are needed to reach the performance reported in the paper? On a V100 one epoch currently takes me about 33 minutes, which feels very long.

thank you for your help

thank you for your help

---- Replied Message ----
From: Wang
Date: 05/25/2023 17:55
Subject: Re: [wzh9969/contrastive-htc] datas (Issue #24)

Links in readme are what you need. You can see agreements for RCV1 and the download link for NYT at the bottom. Please read those pages carefully.

Originally posted by @zwweilai in #24 (comment)

WOS - UnicodeDecodeError: 'utf-8' codec can't decode byte 0x92 in position 15: invalid start byte

Steps

  1. Downloaded data from https://data.mendeley.com/datasets/9rw3vkcfy4/6 as mentioned in https://github.com/kk7nc/HDLTex#datasets-for-hdltex.
  2. Copied 'Data.xlsx' to 'Meta-Data/Data.txt' (just renamed the extension)
  3. Ran py preprocess_wos.py after installing necessary libraries, after which I got the following traceback
Traceback (most recent call last):
  File "/Users/ayush/workdir/personal/masters-research/contrastive-htc/data/WebOfScience/preprocess_wos.py", line 174, in <module>
    get_data_from_meta()
  File "/Users/ayush/workdir/personal/masters-research/contrastive-htc/data/WebOfScience/preprocess_wos.py", line 69, in get_data_from_meta
    origin_txt = f.readlines()
  File "/opt/homebrew/Cellar/python@3.10/3.10.6_2/Frameworks/Python.framework/Versions/3.10/lib/python3.10/codecs.py", line 322, in decode
    (result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x92 in position 15: invalid start byte

The question about label information passed to GraphEncoder

On line 369 of graph.py, why is the output of the graph encoder, after applying the gumbel_softmax function, multiplied by the label tensor? Couldn't this leak label information into the contrast_mask tensor and thereby aid the downstream classification task?

Two weird numbers in the graph model

edge_mat = torch.zeros(len(inverse_label_list), len(inverse_label_list), 15, dtype=torch.long)
self.distance_embedding = nn.Embedding(20, 1, 0)

Question: why do you set 15 as the third dimension of the edge matrix and 20 as the first dimension of the distance embedding? Would other values make a difference?

Do the original sample and the positive sample share exactly the same BERT parameters?

As the title says. If the parameters are fully shared, then while the hierarchy is being injected into the BERT encoder, does contrastive learning fine-tune the BERT parameters through back-propagation? If so, which BERT is updated? And does the order in which the positive sample and the original sample are encoded affect the result?
