leader-pytorch's Introduction

Large Language Model Distilling Medication Recommendation Model

This is the official implementation of the paper "Large Language Model Distilling Medication Recommendation Model".

Running

You can run our model by following these steps:

  1. Prepare the LLaMA-7B model: download all LLaMA-7B files and put them into resources/llama-7b/.

  2. Prepare the data: apply for access to the MIMIC-III and MIMIC-IV data on the official website, and put the unzipped raw data into data/mimic3/raw/ and data/mimic4/, respectively. Then run the construction.ipynb notebooks under data/mimic3/ and data/mimic4/ to preprocess the data. The preprocessed data will be saved under mimic3/handled/ and mimic4/handled/. In addition, the file for converting ATC codes to drug names, "WHO ATC-DDD 2021-12-03.csv", is available from this link. Other auxiliary files, such as "drug-DDI.csv", can be obtained from the GAMENet and SafeDrug repos.

  3. Install the necessary packages. Run the command:

    pip install -r requirements.txt
  4. First, train the large language model for medication recommendation via the command:

    bash experiments/llm_cls.bash
  5. Then, you can run the knowledge distillation via the following command:

    bash experiments/mimic3/online_distill.bash
    bash experiments/mimic4/online_distill.bash
  6. Because distillation takes a long time, you can save the LLM hidden states in advance. Run the test on the training file, and our llm_cls.bash will save the hidden states in the results automatically. Then put the results file into mimic3/handled/ or mimic4/handled/ and run the offline KD, which finishes within two hours:

    bash experiments/mimic3/offline_distill.bash
    bash experiments/mimic4/offline_distill.bash
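The steps above can be sketched as a small driver. This is a hypothetical helper, not part of the repository: `missing_paths`, `pipeline_commands`, and `REQUIRED_DIRS` are illustrative names, it only checks that the directories named in this README exist, and it leaves out the construction.ipynb preprocessing step because that is a notebook.

```python
import os
from typing import Dict, List

# Directories named in the steps above. The files expected inside them are not
# listed in this README, so this sketch only checks that each directory exists.
REQUIRED_DIRS: Dict[str, List[str]] = {
    "mimic3": ["resources/llama-7b", "data/mimic3/raw"],
    "mimic4": ["resources/llama-7b", "data/mimic4"],
}


def missing_paths(root: str, dataset: str) -> List[str]:
    """Return the required directories (relative to root) that do not exist yet."""
    return [p for p in REQUIRED_DIRS[dataset] if not os.path.isdir(os.path.join(root, p))]


def pipeline_commands(dataset: str, offline: bool = False) -> List[List[str]]:
    """List the documented commands in order: install, LLM training, distillation.

    Offline distillation additionally assumes the saved hidden-state results file
    has already been copied into the dataset's handled/ directory.
    """
    if dataset not in REQUIRED_DIRS:
        raise ValueError(f"unknown dataset: {dataset!r}")
    mode = "offline_distill" if offline else "online_distill"
    return [
        ["pip", "install", "-r", "requirements.txt"],
        ["bash", "experiments/llm_cls.bash"],
        ["bash", f"experiments/{dataset}/{mode}.bash"],
    ]
```

A caller would first check that `missing_paths(".", "mimic3")` is empty and then pass each command to `subprocess.run(cmd, check=True)`, so a failing stage stops the run before distillation starts.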

Citation

If the code and the paper are useful to you, please cite our paper:

@article{liu2024large,
  title={Large Language Model Distilling Medication Recommendation Model},
  author={Liu, Qidong and Wu, Xian and Zhao, Xiangyu and Zhu, Yuanshao and Zhang, Zijian and Tian, Feng and Zheng, Yefeng},
  journal={arXiv preprint arXiv:2402.02803},
  year={2024}
}

Thanks

The code builds on the MOELoRA-peft, GAMENet, and SafeDrug repos.

leader-pytorch's People

Contributors

liuqidong07

leader-pytorch's Issues

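Data problem

Hello, very nice work!
I would like to reproduce your experiments, but the data requires authorization to access. Is there any other way to obtain the data?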

Environment problem

Could you please tell me your Python version? I tried to reproduce your code as described in the paper using Python 3.6.5. However, when I run 'pip install -r requirements.txt', I encounter numerous version errors, such as: 'ERROR: Could not find a version that satisfies the requirement accelerate==0.18.0'.
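Since the problem is whether the installed environment matches the pins in requirements.txt, a hedged sketch for comparing simple `name==version` pins against installed packages may help. It uses only the standard library (`importlib.metadata`, available from Python 3.8), does not reveal which Python version the authors used, and ignores any specifier other than `==`; `parse_pins` and `check_pins` are illustrative names.

```python
import re
from importlib import metadata
from typing import Callable, Dict, List, Optional, Tuple

# Matches "name==version"; comment lines and >=, <= specifiers do not match.
PIN = re.compile(r"^\s*([A-Za-z0-9_.\-]+)\s*==\s*([^\s;#]+)")


def parse_pins(lines: List[str]) -> Dict[str, str]:
    """Collect exact pins from requirements lines, keyed by lowercase package name."""
    pins = {}
    for line in lines:
        m = PIN.match(line)
        if m:
            pins[m.group(1).lower()] = m.group(2)
    return pins


def installed_version(name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_pins(
    pins: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> List[Tuple[str, str, Optional[str]]]:
    """Return (package, pinned, installed) for every pin that is missing or differs."""
    mismatches = []
    for name, pinned in sorted(pins.items()):
        found = lookup(name)
        if found != pinned:
            mismatches.append((name, pinned, found))
    return mismatches
```

Typical use is `check_pins(parse_pins(open("requirements.txt").read().splitlines()))` after the install attempt; an empty list means every exact pin is satisfied.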

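Model testing problem

Hello, after reading your paper I ran into some problems while reproducing the experiments and would appreciate your help. Because of limited memory, we trained the model with the ZeRO-3 strategy, and then encountered the following problem at the test stage: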

 train()
 File "main_llm_cls.py", line 78, in train
   model = PeftModelForCLS.from_pretrained(model, model_args.peft_path, is_trainable=False)
 File "/root/autodl-tmp/LEADER/llm/lora_cls.py", line 94, in from_pretrained
   model.load_adapter(model_id, adapter_name, **kwargs)
 File "/root/autodl-tmp/LEADER/llm/lora_cls.py", line 130, in load_adapter
   set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)
 File "/root/autodl-tmp/LEADER/llm/lora_cls.py", line 282, in set_peft_model_state_dict
   model.load_state_dict(peft_model_state_dict, strict=False)
 File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
   raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PeftModelForCLS:
       size mismatch for base_model.model.model.layers.0.mlp.gate_proj.lora_B.default.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([11008, 8]).
       size mismatch for base_model.model.model.layers.0.mlp.up_proj.lora_B.default.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([11008, 8]).
       size mismatch for base_model.model.model.layers.0.mlp.down_proj.lora_A.default.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([8, 11008]).
       size mismatch for base_model.model.model.layers.1.mlp.gate_proj.lora_B.default.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([11008, 8]).
       size mismatch for base_model.model.model.layers.1.mlp.up_proj.lora_B.default.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([11008, 8]).
       size mismatch for base_model.model.model.layers.1.mlp.down_proj.lora_A.default.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([8, 11008]).
       size mismatch for base_model.model.model.layers.2.mlp.gate_proj.lora_B.default.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([11008, 8]).
       size mismatch for base_model.model.model.layers.2.mlp.up_proj.lora_B.default.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([11008, 8]).

The trained model files are shown in the figures:
[Figure: model files]
[Figure: check-point]
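The `torch.Size([0])` entries suggest the adapter checkpoint contains empty tensors. One likely explanation, not confirmed for this run, is that parameters partitioned by DeepSpeed ZeRO-3 were saved without first being gathered. A framework-free sketch for spotting such entries by comparing checkpoint shapes with the shapes the model expects; `shape_mismatches` is a hypothetical helper, and with PyTorch the inputs could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def shape_mismatches(checkpoint: Dict[str, Shape], expected: Dict[str, Shape]) -> List[str]:
    """Describe every expected parameter whose checkpoint entry is absent, empty, or mis-shaped.

    Pass only the adapter (LoRA) parameters as `expected`, since an adapter
    checkpoint is not supposed to contain the base model weights. Shapes with a
    zero dimension hold no elements and are reported separately, because they may
    indicate a tensor saved while still partitioned.
    """
    problems = []
    for name, want in sorted(expected.items()):
        got = checkpoint.get(name)
        if got is None:
            problems.append(f"{name}: missing from checkpoint")
        elif 0 in got:
            problems.append(f"{name}: empty tensor, expected {want}")
        elif got != want:
            problems.append(f"{name}: shape {got}, expected {want}")
    return problems
```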

Top-k filtering for diagnoses or procedures

Hello, during data processing you keep only the top-k diagnoses and procedures, but the drugs corresponding to the filtered-out diagnoses are not removed. Is this a problem for the processed data?
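To make the question concrete, here is a hedged sketch of top-k code filtering on visit records. It is not the code in construction.ipynb: the record layout (per-visit `diag` and `med` lists) and the choice to drop visits left with no codes are assumptions. It illustrates the behavior being asked about, where filtering diagnosis codes leaves each visit's medication list untouched.

```python
from collections import Counter
from typing import Dict, List, Set

Visit = Dict[str, List[str]]


def top_k_codes(visits: List[Visit], field: str, k: int) -> Set[str]:
    """Return the k most frequent codes in `field` across all visits."""
    counts = Counter(code for visit in visits for code in visit[field])
    return {code for code, _ in counts.most_common(k)}


def filter_visits(visits: List[Visit], field: str, k: int) -> List[Visit]:
    """Keep only top-k codes in `field`; other fields such as `med` are copied unchanged.

    Visits whose `field` list becomes empty are dropped (an assumption of this sketch).
    """
    keep = top_k_codes(visits, field, k)
    out = []
    for visit in visits:
        codes = [c for c in visit[field] if c in keep]
        if codes:
            out.append({**visit, field: codes})
    return out
```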

Dataset processing in construction.ipynb

Hi Liu, nice work.
May I ask how you obtained these files:
./auxiliary/RXCUI2atc4.csv
./auxiliary/drug-atc.csv
./auxiliary/ndc2RXCUI.txt
./auxiliary/drugbank_drugs_info.csv
./auxiliary/drug-DDI.csv
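The README says such auxiliary files can be obtained from the GAMENet and SafeDrug repos. For context, files of this kind are commonly used to map MIMIC prescription NDC codes to ATC codes through RXCUI. The sketch below shows only that chaining logic with plain dictionaries; `ndc_to_atc` is a hypothetical helper, and truncating to a four-character ATC prefix is an assumed convention, not something confirmed from construction.ipynb.

```python
from typing import Dict, List


def ndc_to_atc(
    ndc_codes: List[str],
    ndc2rxcui: Dict[str, str],
    rxcui2atc: Dict[str, str],
    prefix_len: int = 4,
) -> List[str]:
    """Map NDC codes to ATC prefixes via RXCUI, dropping codes with no mapping.

    Duplicate ATC prefixes are removed while preserving first-seen order.
    """
    result: List[str] = []
    seen = set()
    for ndc in ndc_codes:
        rxcui = ndc2rxcui.get(ndc)
        atc = rxcui2atc.get(rxcui) if rxcui else None
        if not atc:
            continue  # no RXCUI or no ATC code for this NDC
        code = atc[:prefix_len]
        if code not in seen:
            seen.add(code)
            result.append(code)
    return result
```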

Training problem

Hello, when reproducing your code, I got the following error:
RuntimeError: CUDA error: device-side assert triggered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
The error occurs at: pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
Here, batch_size is 1, logits.shape is [121,1,131], and sequence_lengths is tensor[120].
What is going on?
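The reported shapes may explain the assert. In a Hugging Face-style sequence-classification head, logits are usually laid out as `[batch_size, seq_len, num_labels]`, and this expression picks, for each batch row `i`, the logit at position `sequence_lengths[i]`. With a shape of `[121, 1, 131]`, dimension 1 has size 1, so position 120 is out of range, which on GPU surfaces as a device-side assert; a shape of `[1, 121, 131]` would be valid. A framework-free sketch of that bounds check (`pooling_index_errors` is a hypothetical helper, not part of the repository):

```python
from typing import List, Sequence


def pooling_index_errors(logits_shape: Sequence[int], sequence_lengths: Sequence[int]) -> List[str]:
    """Check logits[arange(batch_size), sequence_lengths] against a [batch, seq_len, labels] shape.

    batch_size is taken as len(sequence_lengths), one index per batch row.
    Negative positions are allowed, as in Python/PyTorch indexing.
    """
    dim0, dim1 = logits_shape[0], logits_shape[1]
    errors = []
    for row, pos in enumerate(sequence_lengths):
        if row >= dim0:
            errors.append(f"row {row} out of range for dim 0 of size {dim0}")
        if not -dim1 <= pos < dim1:
            errors.append(f"position {pos} out of range for dim 1 of size {dim1}")
    return errors
```

Running the same indexing on CPU instead of GPU would typically raise an ordinary `IndexError` with a readable message, which is often the quickest way to confirm this kind of problem.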

Runtime problem: RuntimeError: ProcessGroupNCCL is only supported with GPUs, no GPUs found!

Hello, when trying to reproduce your code, I got the error below. I am also using PyTorch 1.12.0+cu102, and torch reports that CUDA is available:

import torch
from accelerate import Accelerator
if __name__ == "__main__":
    print("Cuda support:", torch.cuda.is_available(),":", torch.cuda.device_count(), "devices")
    accelerator = Accelerator()
    print(accelerator.state)

Output:
Cuda support: True : 1 devices
Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

The full error is shown below. Could you help me solve this problem?
Also, in requirements.txt, trl==0.7.6 requires transformers>=4.31.0, but the file pins transformers 4.28.1. Will this cause problems?

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/heyichen/LEADER-pytorch/main_llm_cls.py:216 in <module> │
│ │
│ 213 │
│ 214 if __name__ == "__main__": │
│ 215 │ │
│ ❱ 216 │ train() │
│ 217 │
│ 218 │
│ 219 │
│ │
│ /home/heyichen/LEADER-pytorch/main_llm_cls.py:60 in train │
│ │
│ 57 def train(): │
│ 58 │ │
│ 59 │ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArg │
│ ❱ 60 │ model_args, data_args, training_args = parser.parse_args_into_dataclasses() │
│ 61 │ device_map = "auto" │
│ 62 │ │
│ 63 │ # load diag, proc, med word2id tokenizer │
│ │
│ /home/heyichen/.conda/envs/LEADER/lib/python3.9/site-packages/transformers/hf_argparser.py:332 │
│ in parse_args_into_dataclasses │
│ │
│ 329 │ │ │ inputs = {k: v for k, v in vars(namespace).items() if k in keys} │
│ 330 │ │ │ for k in keys: │
│ 331 │ │ │ │ delattr(namespace, k) │
│ ❱ 332 │ │ │ obj = dtype(**inputs) │
│ 333 │ │ │ outputs.append(obj) │
│ 334 │ │ if len(namespace.__dict__) > 0: │
│ 335 │ │ │ # additional namespace. │
│ in __init__:115 │
│ │
│ /home/heyichen/.conda/envs/LEADER/lib/python3.9/site-packages/transformers/training_args.py:1259 │
│ in __post_init__ │
│ │
│ 1256 │ │ if ( │
│ 1257 │ │ │ self.framework == "pt" │
│ 1258 │ │ │ and is_torch_available() │
│ ❱ 1259 │ │ │ and (self.device.type != "cuda") │
│ 1260 │ │ │ and (get_xla_device_type(self.device) != "GPU") │
│ 1261 │ │ │ and (self.fp16 or self.fp16_full_eval) │
│ 1262 │ │ ): │
│ │
│ /home/heyichen/.conda/envs/LEADER/lib/python3.9/site-packages/transformers/training_args.py:1694 │
│ in device │
│ │
│ 1691 │ │ The device used by this process. │
│ 1692 │ │ """ │
│ 1693 │ │ requires_backends(self, ["torch"]) │
│ ❱ 1694 │ │ return self._setup_devices │
│ 1695 │ │
│ 1696 │ @property │
│ 1697 │ def n_gpu(self): │
│ │
│ /home/heyichen/.conda/envs/LEADER/lib/python3.9/site-packages/transformers/utils/generic.py:54 │
│ in __get__ │
│ │
│ 51 │ │ attr = "__cached_" + self.fget.__name__ │
│ 52 │ │ cached = getattr(obj, attr, None) │
│ 53 │ │ if cached is None: │
│ ❱ 54 │ │ │ cached = self.fget(obj) │
│ 55 │ │ │ setattr(obj, attr, cached) │
│ 56 │ │ return cached │
│ 57 │
│ │
│ /home/heyichen/.conda/envs/LEADER/lib/python3.9/site-packages/transformers/training_args.py:1679 │
│ in _setup_devices │
│ │
│ 1676 │ │ │ │ if self.xpu_backend and self.xpu_backend in ("mpi", "gloo"): │
│ 1677 │ │ │ │ │ torch.distributed.init_process_group(backend=self.xpu_backend, timeo │
│ 1678 │ │ │ │ else: │
│ ❱ 1679 │ │ │ │ │ torch.distributed.init_process_group(backend="nccl", timeout=self.dd │
│ 1680 │ │ │ device = torch.device("cuda", self.local_rank) │
│ 1681 │ │ │ self._n_gpu = 1 │
│ 1682 │
│ │
│ /home/heyichen/.conda/envs/LEADER/lib/python3.9/site-packages/torch/distributed/distributed_c10d │
│ .py:602 in init_process_group │
│ │
│ 599 │ │ │ # different systems (e.g. RPC) in case the store is multi-tenant. │
│ 600 │ │ │ store = PrefixStore("default_pg", store) │
│ 601 │ │ │
│ ❱ 602 │ │ default_pg = _new_process_group_helper( │
│ 603 │ │ │ world_size, │
│ 604 │ │ │ rank, │
│ 605 │ │ │ [], │
│ │
│ /home/heyichen/.conda/envs/LEADER/lib/python3.9/site-packages/torch/distributed/distributed_c10d │
│ .py:738 in _new_process_group_helper │
│ │
│ 735 │ │ │ │ pg_options.is_high_priority_stream = False │
│ 736 │ │ │ │ pg_options._timeout = timeout │
│ 737 │ │ │ │
│ ❱ 738 │ │ │ pg = ProcessGroupNCCL(prefix_store, rank, world_size, pg_options) │
│ 739 │ │ │ # In debug mode and if GLOO is available, wrap in a wrapper PG that │
│ 740 │ │ │ # enables enhanced collective checking for debugability. │
│ 741 │ │ │ if get_debug_level() == DebugLevel.DETAIL: │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: ProcessGroupNCCL is only supported with GPUs, no GPUs found!
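In the traceback, transformers reaches `torch.distributed.init_process_group(backend="nccl")`, i.e. the run was started in distributed mode (a local rank was set), yet that process sees no CUDA device, even though the standalone check above found one. One possible cause, not confirmed for this report, is a launch script that sets `CUDA_VISIBLE_DEVICES` or `WORLD_SIZE` for a machine with more GPUs. A hedged diagnostic sketch over environment variables; `visible_gpu_count` and `launch_warnings` are hypothetical helpers, and the enumeration rule follows CUDA's documented behavior of stopping at the first invalid index.

```python
from typing import Dict, List


def visible_gpu_count(env: Dict[str, str], physical_gpus: int) -> int:
    """Estimate how many GPUs CUDA exposes given CUDA_VISIBLE_DEVICES.

    Enumeration stops at the first invalid index. Non-integer entries (for
    example GPU UUIDs) are treated as invalid by this sketch.
    """
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return physical_gpus
    count = 0
    for entry in value.split(","):
        entry = entry.strip()
        if not entry.isdigit() or int(entry) >= physical_gpus:
            break
        count += 1
    return count


def launch_warnings(env: Dict[str, str], physical_gpus: int) -> List[str]:
    """Flag settings that can lead to 'ProcessGroupNCCL ... no GPUs found'."""
    warnings = []
    visible = visible_gpu_count(env, physical_gpus)
    if "LOCAL_RANK" in env and visible == 0:
        warnings.append("distributed launch (LOCAL_RANK set) but no GPU is visible to CUDA")
    if "WORLD_SIZE" in env and int(env["WORLD_SIZE"]) > visible:
        warnings.append(f"WORLD_SIZE={env['WORLD_SIZE']} exceeds the {visible} visible GPU(s)")
    return warnings
```

The physical GPU count can be read from `nvidia-smi -L`; `torch.cuda.device_count()` is not suitable here because it already applies `CUDA_VISIBLE_DEVICES`.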
