Git Product home page Git Product logo

disc-medllm's Introduction

DISC-MedLLM

Generic badge license

Demo | 技术报告
中文 | EN

DISC-MedLLM 是一个专门针对医疗健康对话式场景而设计的医疗领域大模型,由复旦大学数据智能与社会计算实验室 (Fudan-DISC) 开发并开源。

该项目包含下列开源资源:

您可以通过访问这个链接来试用我们的模型。

概述

DISC-MedLLM 是一个专为医疗健康对话场景而打造的领域大模型,它可以满足您的各种医疗保健需求,包括疾病问诊和治疗方案咨询等,为您提供高质量的健康支持服务。

DISC-MedLLM 有效地对齐了医疗场景下的人类偏好,弥合了通用语言模型输出与真实世界医疗对话之间的差距,这一点在实验结果中有所体现。

得益于我们以目标为导向的策略,以及基于真实医患对话数据和知识图谱,引入LLM in the loop 和 Human in the loop的多元数据构造机制,DISC-MedLLM 有以下几个特点:

  • 可靠丰富的专业知识,我们以医学知识图谱作为信息源,通过采样三元组,并使用通用大模型的语言能力进行对话样本的构造。
  • 多轮对话的问询能力,我们以真实咨询对话纪录作为信息源,使用大模型进行对话重建,构建过程中要求模型完全对齐对话中的医学信息。
  • 对齐人类偏好的回复,病人希望在咨询的过程中获得更丰富的支撑信息和背景知识,但人类医生的回答往往简练;我们通过人工筛选,构建符合人类偏好的高质量的小规模行为微调样本,对齐病人的需求。

data-construction

模型效果演示

疾病问诊

sample1

治疗方案咨询

sample2

数据集

为了训练 DISC-MedLLM ,我们构建了一个高质量的数据集,命名为 DISC-Med-SFT,其中包含了超过47万个衍生于现有的医疗数据集重新构建得到的样本。我们采用了目标导向的策略,通过对于精心选择的几个数据源进行重构来得到SFT数据集。这些数据的作用在于帮助模型学习医疗领域知识,将行为模式与人类偏好对齐,并对齐真实世界在线医疗对话的分布情况。


数据集

数据来源

样本量
重构AI医患对话 MedDialog 400k
cMedQA2 20k
知识图谱问答对 CMeKG 50k
行为偏好数据集 人为筛选 2k
其他 MedMCQA 8k
MOSS-SFT 33k
Alpaca-GPT4-zh 1k

下载

我们总共发布了近47万条训练数据,其中包括重构AI医患对话和知识图谱问答对。您可以访问这个链接下载数据集。


部署

当前版本的 DISC-MedLLM 是基于Baichuan-13B-Base训练得到的。您可以直接从 Hugging Face 上下载我们的模型权重,或者根据下列代码样例中的方式自动获取。

首先,您需要安装项目的依赖环境。

pip install -r requirements.txt

利用Hugging Face的transformers模块来进行推理

>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from transformers.generation.utils import GenerationConfig
>>> tokenizer = AutoTokenizer.from_pretrained("Flmc/DISC-MedLLM", use_fast=False, trust_remote_code=True)
>>> model = AutoModelForCausalLM.from_pretrained("Flmc/DISC-MedLLM", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
>>> model.generation_config = GenerationConfig.from_pretrained("Flmc/DISC-MedLLM")
>>> messages = []
>>> messages.append({"role": "user", "content": "我感觉自己颈椎非常不舒服,每天睡醒都会头痛"})
>>> response = model.chat(tokenizer, messages)
>>> print(response)

运行命令行Demo

python cli_demo.py

运行网页版Demo

streamlit run web_demo.py --server.port 8888

此外,由于目前版本的 DISC-MedLLM 是以 Baichuan-13B 作为基座的,您可以参考 Baichuan-13B 项目的介绍来进行 int8 或 int4 量化推理部署。然而需要注意的是,使用模型量化可能会导致性能的下降。

对模型进行微调

您可以使用与我们的数据集结构相同的数据对我们的模型进行微调。我们的训练代码在 Firefly 的基础上进行了修改,使用了不同的数据结构和对话格式。这里我们只提供全参数微调的代码:

deepspeed --num_gpus={num_gpus} ./train/train.py --train_args_file ./train/train_args/sft.json

请您在开始进行模型训练前检查 sft.json 中的设置。


如果您想使用其他训练代码来微调我们的模型,请使用如下对话格式。

<\b><$user_token>content<$assistant_token>content<\s><$user_token>content ...

我们使用的 user_tokenassistant_token 分别为 195 and 196,这和 Baichuan-13B-Chat 是相同的。

模型评测

我们从两个角度评估了模型的性能,包括在单轮QA问题中提供准确答案的能力以及在多轮对话中完成系统性问诊、解决咨询需求的能力。

  • 在单轮对话评测中,我们构建了一个基准测试数据集,其中包含从两个公开医疗数据集中收集的多项选择题,并评估模型回答的准确性。
  • 对于多轮对话评测,我们首先构建了一些高质量的诊疗对话案例,然后让 GPT-3.5 扮演这些案例中的患者角色,并与扮演医生角色的模型进行对话。我们利用 GPT-4 来评估整段每段对话的主动性准确性, 帮助性语言质量

您可以在 eval/ 目录下查看测试数据集、各个模型生成的对话结果以及 GPT-4 提供的打分结果。

单轮QA评测

我们在评测中选用了 MLEC-QA 和考研306(西医综合)的单项选择题。

Few-shot

模型 MLEC-QA 临床 MLEC-QA 中西医结合 MLEC-QA 公共卫生 MLEC-QA 口腔 MLEC-QA 中医 考研306西医综合 平均
GPT-3.5 58.63 45.9 53.51 51.52 43.47 44.81 49.64
Baichuan-13b-Chat 31.25 37.69 28.65 27.27 29.77 24.81 29.91
Huatuo(13B) 31.85 25 32.43 32.95 26.54 24.44 28.87
DISC-MedLLM 44.64 41.42 41.62 38.26 39.48 33.33 39.79

Zero-shot

模型 MLEC-QA 临床 MLEC-QA 中西医结合 MLEC-QA 公共卫生 MLEC-QA 口腔 MLEC-QA 中医 考研306西医综合 平均
GPT-3.5 47.32 33.96 48.11 39.77 38.83 33.33 40.22
Baichuan-13b-Chat 44.05 43.28 39.92 31.06 41.42 32.22 38.66
Huatuo(13B) 27.38 21.64 25.95 25.76 24.92 20.37 24.34
DISC-MedLLM 44.64 37.31 35.68 34.85 41.75 31.11 37.56

多轮对话能力评测

我们的评测基于三个不同的数据集:Chinese Medical Benchmark (CMB-Clin)、Chinese Medical Dialogue Dataset (CMD) 和 Chinese Medical Intent Dataset (CMID),其中 CMB-Clin 模拟了现实世界的问诊过程,而 CMD 和 CMID 则分别着重从科室专业性和用户意图的角度进行评估。

CMB-clin数据集的评测结果:

模型 主动性 准确性 帮助性 语言质量 平均
GPT3.5 4.30 4.53 4.55 5.00 4.60
GPT4 4.15 4.70 4.75 4.96 4.64
Baichuan-13b-Caht 4.30 4.58 4.73 4.95 4.64
BianQue-2 3.97 4.36 4.37 4.81 4.38
Huatuo(13B) 4.40 4.62 4.74 4.96 4.68
DISC-MedLLM 4.64 4.47 4.66 4.99 4.69

CMD数据集的评测结果

cmd

CMID数据集的评测结果

cmid

致谢

本项目基于如下开源项目展开,在此对相关项目和开发人员表示诚挚的感谢:

同样感谢其他限于篇幅未能列举的为本项目提供了重要帮助的工作。

声明

由于语言模型固有的局限性,我们无法保证 DISC-MedLLM 模型所生成的信息的准确性或可靠性。该模型仅为个人和学术团体的研究和测试而设计。我们敦促用户以批判性的眼光对模型输出的任何信息或医疗建议进行评估,并且强烈建议不要盲目信任此类信息结果。我们不对因使用该模型所引发的任何问题、风险或不良后果承担责任。

引用

如果我们的工作有帮助到您的研究,请引用我们:

@misc{bao2023discmedllm,
      title={DISC-MedLLM: Bridging General Large Language Models and Real-World Medical Consultation}, 
      author={Zhijie Bao and Wei Chen and Shengze Xiao and Kuang Ren and Jiaao Wu and Cheng Zhong and Jiajie Peng and Xuanjing Huang and Zhongyu Wei},
      year={2023},
      eprint={2308.14346},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Star History

Star History Chart

disc-medllm's People

Contributors

eltociear avatar f1mc avatar lemuria-wchen avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar

disc-medllm's Issues

Eval dataset benchmark opensource

作者你好,repo 里面只有 multi-turn 的 eval 数据集,请问 single-turn 中,tech report sample 后的 eval dataset 会开源吗?以及所有的评测脚本和代码是否会开源呢?

Create a demo on Hugging Face?

Hi!

Very cool work! It would be nice to create a demo on the Hugging Face Hub!

It will help to wider reach of your work to the ecosystem.

This is a step-by-step guide explaining the process in case you're interested.

Please let us know if you would be interested and if you have any questions. 😊

偏好对齐数据集

您好,请问stage2里使用的2k条偏好对齐的数据集是怎么构造的呀

模型的配置和分词器

image
这段代码中的模型配置文件和分词器是必须用自己微调后的才行吗?可以直接调用你们训练好的模型吗

运行时报错,看不懂啥原因

2023-09-28 16:14:51.050 Uncaught app exception
Traceback (most recent call last):
File "/usr/local/lib/python3.9/site-packages/streamlit/runtime/caching/cache_utils.py", line 263, in _get_or_create_cached_value
cached_result = cache.read_result(value_key)
File "/usr/local/lib/python3.9/site-packages/streamlit/runtime/caching/cache_resource_api.py", line 500, in read_result
raise CacheKeyNotFoundError()
streamlit.runtime.caching.cache_errors.CacheKeyNotFoundError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/usr/local/lib/python3.9/site-packages/streamlit/runtime/caching/cache_utils.py", line 311, in _handle_cache_miss
cached_result = cache.read_result(value_key)
File "/usr/local/lib/python3.9/site-packages/streamlit/runtime/caching/cache_resource_api.py", line 500, in read_result
raise CacheKeyNotFoundError()
streamlit.runtime.caching.cache_errors.CacheKeyNotFoundError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/usr/local/lib/python3.9/site-packages/streamlit/runtime/scriptrunner/script_runner.py", line 541, in _run_script
exec(code, module.dict)
File "/Users/renfeng/works/idea/workspace_zz/DISC-MedLLM-main/web_demo.py", line 72, in
main()
File "/Users/renfeng/works/idea/workspace_zz/DISC-MedLLM-main/web_demo.py", line 51, in main
model, tokenizer = init_model()
File "/usr/local/lib/python3.9/site-packages/streamlit/runtime/caching/cache_utils.py", line 211, in wrapper
return cached_func(*args, **kwargs)
File "/usr/local/lib/python3.9/site-packages/streamlit/runtime/caching/cache_utils.py", line 240, in call
return self._get_or_create_cached_value(args, kwargs)
File "/usr/local/lib/python3.9/site-packages/streamlit/runtime/caching/cache_utils.py", line 266, in _get_or_create_cached_value
return self._handle_cache_miss(cache, value_key, func_args, func_kwargs)
File "/usr/local/lib/python3.9/site-packages/streamlit/runtime/caching/cache_utils.py", line 320, in _handle_cache_miss
computed_value = self._info.func(*func_args, **func_kwargs)
File "/Users/renfeng/works/idea/workspace_zz/DISC-MedLLM-main/web_demo.py", line 14, in init_model
model = AutoModelForCausalLM.from_pretrained(
File "/usr/local/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 494, in from_pretrained
resolved_config_file = cached_file(
File "/usr/local/lib/python3.9/site-packages/transformers/utils/hub.py", line 429, in cached_file
resolved_file = hf_hub_download(
File "/usr/local/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py", line 110, in _inner_fn
validate_repo_id(arg_value)
File "/usr/local/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py", line 158, in validate_repo_id
raise HFValidationError(
huggingface_hub.utils._validators.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/root/output/baichuan13b-sft-after300k-artificial-v2-add/final'. Use repo_type argument if needed.

medical_dialogue数据集问题

大佬,您好!刚刚拜读了您的DISC-MedLLM这篇论文,您不仅介绍了数据集构建,还分享了处理后的数据集。我今天也拿到了medical_dialogue数据集,我想问一下您在处理这部分数据的时候,保留了哪些字段,丢弃了哪些字段。
image

我看到您的数据集,在第一轮对话中,只抽取了【疾病:,病情描述:, 】希望获得的帮助这个字段您保留了吗, 我看到剩下的几个字段包括【患病多久:,过敏史:】都没有保留

希望大佬可以回答我的问题,感谢!!!!

Torch版本问题

想请问一下对应的pytorch是哪个版本?2.0.0吗?requirements.txt没给。谢谢

加载时 tokenizer 出错?

运行下面这句话时报错
tokenizer = AutoTokenizer.from_pretrained("Flmc/DISC-MedLLM", use_fast=False, trust_remote_code=True)

`---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
in
----> 1 tokenizer = AutoTokenizer.from_pretrained("Flmc/DISC-MedLLM", use_fast=False, trust_remote_code=True)

/usr/local/lib/python3.10/dist-packages/transformers/models/auto/tokenization_auto.py in from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
753 if os.path.isdir(pretrained_model_name_or_path):
754 tokenizer_class.register_for_auto_class()
--> 755 return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
756 elif config_tokenizer_class is not None:
757 tokenizer_class = None

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in from_pretrained(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, *init_inputs, **kwargs)
2022 logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
2023
-> 2024 return cls._from_pretrained(
2025 resolved_vocab_files,
2026 pretrained_model_name_or_path,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in _from_pretrained(cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, token, cache_dir, local_files_only, _commit_hash, _is_local, *init_inputs, **kwargs)
2254 # Instantiate the tokenizer.
2255 try:
-> 2256 tokenizer = cls(*init_inputs, **init_kwargs)
2257 except OSError:
2258 raise OSError(

~/.cache/huggingface/modules/transformers_modules/Flmc/DISC-MedLLM/c63decba7cb81129fba4157e1d2cc86eca3da44f/tokenization_baichuan.py in init(self, vocab_file, unk_token, bos_token, eos_token, pad_token, sp_model_kwargs, add_bos_token, add_eos_token, clean_up_tokenization_spaces, **kwargs)
53 unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
54 pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
---> 55 super().init(
56 bos_token=bos_token,
57 eos_token=eos_token,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils.py in init(self, **kwargs)
365 # 4. If some of the special tokens are not part of the vocab, we add them, at the end.
366 # the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following tokenizers
--> 367 self._add_tokens(
368 [token for token in self.all_special_tokens_extended if token not in self._added_tokens_encoder],
369 special_tokens=True,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils.py in _add_tokens(self, new_tokens, special_tokens)
465 return added_tokens
466 # TODO this is fairly slow to improve!
--> 467 current_vocab = self.get_vocab().copy()
468 new_idx = len(current_vocab) # only call this once, len gives the last index + 1
469 for token in new_tokens:

~/.cache/huggingface/modules/transformers_modules/Flmc/DISC-MedLLM/c63decba7cb81129fba4157e1d2cc86eca3da44f/tokenization_baichuan.py in get_vocab(self)
87 def get_vocab(self):
88 """Returns vocab as a dict"""
---> 89 vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
90 vocab.update(self.added_tokens_encoder)
91 return vocab

~/.cache/huggingface/modules/transformers_modules/Flmc/DISC-MedLLM/c63decba7cb81129fba4157e1d2cc86eca3da44f/tokenization_baichuan.py in vocab_size(self)
83 def vocab_size(self):
84 """Returns vocab size"""
---> 85 return self.sp_model.get_piece_size()
86
87 def get_vocab(self):

AttributeError: 'BaichuanTokenizer' object has no attribute 'sp_model'
`

运行出错

特别感谢您的无私贡献,我使用baichaunChat进行qlora微调的时候出现了keyerror的问题,具体如下:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /Work/disc/train/train.py:185 in <module>                                        │
│                                                                                                  │
│   182                                                                                            │
│   183                                                                                            │
│   184 if __name__ == "__main__":                                                                 │
│ ❱ 185 │   main()                                                                                 │
│   186                                                                                            │
│                                                                                                  │
│ /Work/disc/train/train.py:173 in main                                            │
│                                                                                                  │
│   170 │   trainer = init_components(args, training_args)                                         │
│   171 │   # 开始训练                                                                             │
│   172 │   logger.info("*** starting training ***")                                               │
│ ❱ 173 │   train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkp   │
│   174 │   # 保存最好的checkpoint                                                                 │
│   175 │   final_save_path = join(training_args.output_dir, 'final')                              │
│   176 │   trainer.save_model(final_save_path)  # Saves the tokenizer too                         │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/transformers/trainer.py │
│ :1645 in train                                                                                   │
│                                                                                                  │
│   1642 │   │   inner_training_loop = find_executable_batch_size(                                 │
│   1643 │   │   │   self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size  │
│   1644 │   │   )                                                                                 │
│ ❱ 1645 │   │   return inner_training_loop(                                                       │
│   1646 │   │   │   args=args,                                                                    │
│   1647 │   │   │   resume_from_checkpoint=resume_from_checkpoint,                                │
│   1648 │   │   │   trial=trial,                                                                  │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/transformers/trainer.py │
│ :1916 in _inner_training_loop                                                                    │
│                                                                                                  │
│   1913 │   │   │   │   rng_to_sync = True                                                        │
│   1914 │   │   │                                                                                 │
│   1915 │   │   │   step = -1                                                                     │
│ ❱ 1916 │   │   │   for step, inputs in enumerate(epoch_iterator):                                │
│   1917 │   │   │   │   total_batched_samples += 1                                                │
│   1918 │   │   │   │   if rng_to_sync:                                                           │
│   1919 │   │   │   │   │   self._load_rng_state(resume_from_checkpoint)                          │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/datalo │
│ ader.py:633 in __next__                                                                          │
│                                                                                                  │
│    630 │   │   │   if self._sampler_iter is None:                                                │
│    631 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   │
│    632 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   │
│ ❱  633 │   │   │   data = self._next_data()                                                      │
│    634 │   │   │   self._num_yielded += 1                                                        │
│    635 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and \                          │
│    636 │   │   │   │   │   self._IterableDataset_len_called is not None and \                    │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/datalo │
│ ader.py:1345 in _next_data                                                                       │
│                                                                                                  │
│   1342 │   │   │   │   self._task_info[idx] += (data,)                                           │
│   1343 │   │   │   else:                                                                         │
│   1344 │   │   │   │   del self._task_info[idx]                                                  │
│ ❱ 1345 │   │   │   │   return self._process_data(data)                                           │
│   1346 │                                                                                         │
│   1347 │   def _try_put_index(self):                                                             │
│   1348 │   │   assert self._tasks_outstanding < self._prefetch_factor * self._num_workers        │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/datalo │
│ ader.py:1371 in _process_data                                                                    │
│                                                                                                  │
│   1368 │   │   self._rcvd_idx += 1                                                               │
│   1369 │   │   self._try_put_index()                                                             │
│   1370 │   │   if isinstance(data, ExceptionWrapper):                                            │
│ ❱ 1371 │   │   │   data.reraise()                                                                │
│   1372 │   │   return data                                                                       │
│   1373 │                                                                                         │
│   1374 │   def _mark_worker_as_unavailable(self, worker_id, shutdown=False):                     │
│                                                                                                  │
│ /home/anaconda3/envs/lib/python3.10/site-packages/torch/_utils.py:644 in  │
│ reraise                                                                                          │
│                                                                                                  │
│   641 │   │   │   # If the exception takes multiple arguments, don't try to                      │
│   642 │   │   │   # instantiate since we don't know how to                                       │
│   643 │   │   │   raise RuntimeError(msg) from None                                              │
│ ❱ 644 │   │   raise exception                                                                    │
│   645                                                                                            │
│   646                                                                                            │
│   647 def _get_available_device_type():                                                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/anaconda3/envs/lib/python3.10/site-packages/transformers/trainer_utils.py", line 706, in __call__
    return self.data_collator(features)
  File "/Work/disc/train/component/collator.py", line 23, in __call__
    target_mask = x['target_mask']
KeyError: 'target_mask'

参数详情:

gpu_vis=4,5
MASTER_PORT=1942
deepspeed  --include localhost:$gpu_vis --master_port $MASTER_PORT disc/train/train.py \
    --deepspeed disc/train/train_args/ds_z3_config.json \
    --output_dir disc/out \
    --model_name_or_path pre_model/Baichuan-13B-Chat-v2 \
    --train_file disc/train/data/DISC-Med-SFT_released.jsonl \
    --overwrite_cache \
    --overwrite_output_dir \
    --num_train_epochs 1.0 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 3 \
    --learning_rate 1e-5 \
    --max_seq_length 1200 \
    --logging_steps 50 \
    --save_steps 2000 \
    --save_total_limit 3 \
    --lr_scheduler_type cosine \
    --warmup_steps 800 \
    --gradient_checkpointing false \
    --disable_tqdm false \
    --optim adamw_hf \
    --seed 42 \
    --fp16 false \
    --bf16 true \
    --report_to tensorboard \
    --dataloader_num_workers 5  \
    --save_strategy steps \
    --weight_decay 0 \
    --max_grad_norm 1.0 \
    --quantization_bit 4

我发现在transformers包中trainer_utils.py文件中的__call__()方法,会将target_mask属性给移除掉,具体如下:

def _remove_columns(self, feature: dict) -> dict:
     if not isinstance(feature, dict):
         return feature
     if not self.message_logged and self.logger and self.model_name:
         ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
         if len(ignored_columns) > 0:
             dset_description = "" if self.description is None else f"in the {self.description} set"
             self.logger.info(
                 f"The following columns {dset_description} don't have a corresponding argument in "
                 f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
                 f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
                 " you can safely ignore this message."
             )
             self.message_logged = True
     return {k: v for k, v in feature.items() if k in self.signature_columns}
def __call__(self, features: List[dict]):
    features = [self._remove_columns(feature) for feature in features]
    return self.data_collator(features)

我对源文件进行了修改,注释掉了features = [self._remove_columns(feature) for feature in features],但是却发生其他错误。因此,我想知到您的transformers和transformers_stream_generator的具体版本是多少,又或者是代码逻辑哪里有什么疏漏,万分感谢!!!(我的版本:transformers== 4.30.1 ,transformers-stream-generator ==0.0.4

我这边运行你们这边提供的demo报错了。'BaichuanTokenizer' object has no attribute 'sp_model'

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("Flmc/DISC-MedLLM", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Flmc/DISC-MedLLM", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
model.generation_config = GenerationConfig.from_pretrained("Flmc/DISC-MedLLM")
messages = []
messages.append({"role": "user", "content": "我感觉自己颈椎非常不舒服,每天睡醒都会头痛"})
response = model.chat(tokenizer, messages)
print(response)

/usr/local/lib/python3.11/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.
warnings.warn(
Traceback (most recent call last):
File "/hy-tmp/11.py", line 4, in
tokenizer = AutoTokenizer.from_pretrained("Flmc/DISC-MedLLM", use_fast=False, trust_remote_code=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/transformers/models/auto/tokenization_auto.py", line 847, in from_pretrained
return tokenizer_class.from_pretrained(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/transformers/tokenization_utils_base.py", line 2089, in from_pretrained
return cls._from_pretrained(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/transformers/tokenization_utils_base.py", line 2311, in _from_pretrained
tokenizer = cls(*init_inputs, **init_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/Flmc/DISC-MedLLM/c63decba7cb81129fba4157e1d2cc86eca3da44f/tokenization_baichuan.py", line 55, in init
super().init(
File "/usr/local/lib/python3.11/dist-packages/transformers/tokenization_utils.py", line 367, in init
self._add_tokens(
File "/usr/local/lib/python3.11/dist-packages/transformers/tokenization_utils.py", line 467, in _add_tokens
current_vocab = self.get_vocab().copy()
^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/Flmc/DISC-MedLLM/c63decba7cb81129fba4157e1d2cc86eca3da44f/tokenization_baichuan.py", line 89, in get_vocab
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/Flmc/DISC-MedLLM/c63decba7cb81129fba4157e1d2cc86eca3da44f/tokenization_baichuan.py", line 85, in vocab_size
return self.sp_model.get_piece_size()
^^^^^^^^^^^^^
AttributeError: 'BaichuanTokenizer' object has no attribute 'sp_model'

AttributeError: 'BaichuanTokenizer' object has no attribute 'sp_model' 就是你们这边提供的demo代码。

微调

我能否用类似格式的数据集,和脚本去微调baichuan-13b-base的模型呢?

算法评估里,第二轮对话用户的提问是怎么生成的?

麻烦请教一下,在这个评估文件里eval/dialogues/DISC-MedLLM_cmd.json,第一轮对话的用户提问是从数据集里得到的,第二轮对话的用户提问是怎么生成的?

是你们自己写的么?我看不同的模型,第二轮开始的体问都不一样。

谢谢

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.