FLAP's Issues

AttributeError: Can't pickle local object 'add_hook_to_module.<locals>.new_forward'

Traceback (most recent call last):
File "/home/shwu/LABS/FLAP/main.py", line 147, in
main()
File "/home/shwu/LABS/FLAP/main.py", line 141, in main
torch.save(model, f'{args.save_model}/pruned_model.pt')
File "/home/shwu/LABS/FLAP/venv/lib/python3.10/site-packages/torch/serialization.py", line 629, in save
_save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
File "/home/shwu/LABS/FLAP/venv/lib/python3.10/site-packages/torch/serialization.py", line 841, in _save
pickler.dump(obj)
AttributeError: Can't pickle local object 'add_hook_to_module.<locals>.new_forward'


This occurs after pruning on multiple GPUs and saving with torch.save.
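
A possible workaround, not from the repo: the unpicklable closure comes from accelerate's dispatch hooks, which multi-GPU (device_map) loading attaches to every module. Stripping the hooks before saving, or saving only the state dict, usually avoids the error; remove_hook_from_submodules is accelerate's public API for this.

    import torch
    from accelerate.hooks import remove_hook_from_submodules

    # add_hook_to_module wraps each module's forward in a local closure that
    # pickle cannot serialize; removing the hooks restores the original forwards.
    remove_hook_from_submodules(model)
    torch.save(model, f'{args.save_model}/pruned_model.pt')

    # Alternative: save only the weights, which never touches the hooked forwards.
    # torch.save(model.state_dict(), f'{args.save_model}/pruned_state_dict.pt')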

Support for other models

Are there any parallel efforts ongoing to add support for other models in FLAP, such as Phi, Gemma, Mistral, etc.?

If so, what is the timeline?

Regarding the Choice of Calibration Set

Dear authors,

I would like to express my deep appreciation for the significant contribution your work has made to our community. It has been an invaluable resource for me.

While working through the code, I ran into a question about your choice of calibration set, and I hope you can offer some suggestions.

In your implementation, when pruning an LLM with FLAP, the calibration set is WikiText-2, as shown in this line.

If I understand correctly, comparing the PPL on WikiText-2 with previous work might be misleading, as previous pruning methods were not calibrated on WikiText-2. Could you please confirm this?

Could you specify the calibration set you used for the results reported in the paper? I cannot find this information in the paper.

Thanks for your attention, and I hope to hear from you!

Can you support baffo32/decapoda-research-llama-7B-hf?

@an-yongqi
Dear An,
Congratulations on the great job!

Since I get a 401 Unauthorized for the URL https://huggingface.co/decapoda-research/llama-7b-h/resolve/main/config.json:
loading llm model decapoda-research/llama-7b-h
/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.
warnings.warn(
Traceback (most recent call last):
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
response.raise_for_status()
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/requests/models.py", line 1021, in raise_for_status
raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 401 Client Error: Unauthorized for url: https://huggingface.co/decapoda-research/llama-7b-h/resolve/main/config.json

So I used baffo32/decapoda-research-llama-7B-hf instead, but there are errors:

pruning starts
loading calibdation data
Traceback (most recent call last):
File "/home/delight-gpu/Workspace2/azuryl/FLAP/main.py", line 109, in
main()
File "/home/delight-gpu/Workspace2/azuryl/FLAP/main.py", line 82, in main
prune_flap(args, model, tokenizer, device)
File "/home/delight-gpu/Workspace2/azuryl/FLAP/lib/prune.py", line 294, in prune_flap
dataloader, _ = get_loaders("wikitext2", nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
File "/home/delight-gpu/Workspace2/azuryl/FLAP/lib/data.py", line 159, in get_loaders
return get_wikitext2(nsamples, seed, seqlen, tokenizer)
File "/home/delight-gpu/Workspace2/azuryl/FLAP/lib/data.py", line 79, in get_wikitext2
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 1767, in load_dataset
builder_instance = load_dataset_builder(
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 1498, in load_dataset_builder
dataset_module = dataset_module_factory(
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 1215, in dataset_module_factory
raise e1 from None
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 1192, in dataset_module_factory
return HubDatasetModuleFactoryWithoutScript(
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 765, in get_module
else get_data_patterns_in_dataset_repository(hfh_dataset_info, self.data_dir)
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/data_files.py", line 675, in get_data_patterns_in_dataset_repository
return _get_data_files_patterns(resolver)
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/data_files.py", line 236, in _get_data_files_patterns
data_files = pattern_resolver(pattern)
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/data_files.py", line 486, in _resolve_single_pattern_in_dataset_repository
glob_iter = [PurePath(filepath) for filepath in fs.glob(PurePath(pattern).as_posix()) if fs.isfile(filepath)]
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/fsspec/spec.py", line 606, in glob
pattern = glob_translate(path + ("/" if ends_with_sep else ""))
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/fsspec/utils.py", line 734, in glob_translate
raise ValueError(
ValueError: Invalid pattern: '**' can only be an entire path component

Can you help me?

Thank you.
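
Two hedged notes, not from the repo: the final ValueError ("Invalid pattern: '**' can only be an entire path component") is a known incompatibility between older datasets releases and fsspec >= 2023.10, not a FLAP or model problem; upgrading datasets (or pinning fsspec <= 2023.9.2) typically clears it. Once that is fixed, the public mirror from the title should load as a stand-in for the inaccessible decapoda-research repo, e.g.:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Sketch using the mirror named in the issue title; FLAP's model argument
    # can then point at this id instead of the original decapoda-research repo.
    model_id = "baffo32/decapoda-research-llama-7B-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")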

Question on the calculation of W_metric for 'self_attn.o_proj' in prune_flap()

Thanks for your inspiring work!
I have a small question about the W_metric for self_attn.o_proj in prune_flap(): it applies a square operation, while the W_metric for mlp.down_proj does not.

        for name in subset:
            if name == 'self_attn.o_proj':
                W_metric = metrics[args.metrics](wrapped_layers, subset, name) ** 2    # square is needed
                if args.structure == "UL-UM":
                    W_metric = W_metric.reshape(-1, 128).sum(dim=1)
                    thresh = torch.sort(W_metric.cuda())[0][int(args.pruning_ratio*layer.self_attn.num_heads)].cpu()
                    W_mask = (W_metric>=thresh)
                    attn_mask.append(W_mask)
                elif args.structure == "UL-MM":
                    W_metric = W_metric.reshape(-1, 128).sum(dim=1)
                    thresh = torch.sort(W_metric.cuda())[0][args.remove_heads // len(layers)].cpu()
                    W_mask = (W_metric>=thresh)
                    attn_mask.append(W_mask)
                else:
                    attn_metric_list.append(W_metric.cpu())
                attn_baseline_inp_list.append(wrapped_layers[name].baseline_inp.type(torch.half))
            else:
                W_metric = metrics[args.metrics](wrapped_layers, subset, name)    # no square
                if args.structure == "UL-UM":
                    thresh = torch.sort(W_metric.cuda())[0][int(W_metric.numel()*args.pruning_ratio)].cpu()
                    W_mask = (W_metric>=thresh)
                    mlp_mask.append(W_mask)
                elif args.structure == "UL-MM":
                    thresh = torch.sort(W_metric.cuda())[0][cal_remove_neuron(args, model)].cpu()
                    W_mask = (W_metric>=thresh)
                    mlp_mask.append(W_mask)
                else:
                    mlp_metric_list.append(W_metric.cpu())
                mlp_baseline_inp_list.append(wrapped_layers[name].baseline_inp.type(torch.half))
            wrapped_layers[name].free()

I'm really confused. Could you help me out?
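
A toy illustration, not from the authors, of why the square can matter: squaring before the per-head sum re-weights channels, so a head with a few highly salient channels can outrank a head with uniformly moderate ones. The scores below are made up.

    import torch

    # Hypothetical per-channel scores for two attention heads, 4 channels each.
    scores = torch.tensor([[0.9, 0.1, 0.1, 0.1],   # one dominant channel
                           [0.4, 0.4, 0.4, 0.4]])  # uniformly moderate channels

    print(scores.sum(dim=1))         # tensor([1.2000, 1.6000]) -> head 1 ranks higher
    print(scores.pow(2).sum(dim=1))  # tensor([0.8400, 0.6400]) -> head 0 ranks higher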

Question about the number of parameters

Hi, thank you for sharing your impressive work.

I think the line below [link] may need to be modified.
print(f"model parameter {sum(p.numel() for p in model.parameters()) / 1024 ** 3:.2f}B")

Instead of dividing the number of parameters by 1024 ** 3 to calculate the parameters in billions, it might be more accurate to use 1000 ** 3. Specifically, when I checked the number of parameters of the following models, the results are:

  • LLaMA-7B: 6738415616 = 6.74B
  • LLM-Pruner’s code (20% pruning) [link]: 5422977024 = 5.42B
  • FLAP’s code (20% pruning) [link]: 5442514944 = 5.44B (not 5.07B)

I would appreciate it if you could share your opinion. Thank you for your time and consideration.
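
Checking the arithmetic with the numbers above:

    n = 6_738_415_616             # LLaMA-7B parameter count quoted above
    print(f"{n / 1024**3:.2f}B")  # 6.28B -- the GiB-style divisor understates it
    print(f"{n / 1000**3:.2f}B")  # 6.74B -- matches the commonly quoted size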

How much data was used in the pruning process?

I would like to know how much data was used in the pruning process. Is it as in the example code, where nsamples=1024, meaning only 1,024 calibration samples were used to determine the pruning results?

Cannot load the model after pruning

I tried
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
but it raises an error:
Traceback (most recent call last):
File "/home/ubuntu/test_scripts/benchmark_r.py", line 154, in
main()
File "/home/ubuntu/test_scripts/benchmark_r.py", line 63, in main
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True,
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 556, in from_pretrained
return model_class.from_pretrained(
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3502, in from_pretrained
) = cls._load_pretrained_model(
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3926, in _load_pretrained_model
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 805, in _load_state_dict_into_meta_model
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 348, in set_module_tensor_to_device
raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([2048, 2785]) in "weight" (which has shape torch.Size([2048, 5504])), this look incorrect.

I also tried adding the parameter ignore_mismatched_sizes=True when loading:
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True, ignore_mismatched_sizes=True)

This also raises an error:
Some weights of QWenLMHeadModel were not initialized from the model checkpoint at /data/xxxx and are newly initialized because the shapes did not match:

  • transformer.h.10.mlp.c_proj.weight: found shape torch.Size([2048, 2785]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.10.mlp.w1.weight: found shape torch.Size([2785, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.10.mlp.w2.weight: found shape torch.Size([2785, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.11.mlp.c_proj.weight: found shape torch.Size([2048, 2518]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.11.mlp.w1.weight: found shape torch.Size([2518, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.11.mlp.w2.weight: found shape torch.Size([2518, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.12.attn.c_attn.weight: found shape torch.Size([3840, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.12.attn.c_proj.weight: found shape torch.Size([2048, 1280]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.12.mlp.c_proj.weight: found shape torch.Size([2048, 2393]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.12.mlp.w1.weight: found shape torch.Size([2393, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.12.mlp.w2.weight: found shape torch.Size([2393, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.13.attn.c_attn.weight: found shape torch.Size([3072, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.13.attn.c_proj.weight: found shape torch.Size([2048, 1024]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.13.mlp.c_proj.weight: found shape torch.Size([2048, 3776]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.13.mlp.w1.weight: found shape torch.Size([3776, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.13.mlp.w2.weight: found shape torch.Size([3776, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.14.attn.c_attn.weight: found shape torch.Size([2688, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.14.attn.c_proj.weight: found shape torch.Size([2048, 896]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.14.mlp.c_proj.weight: found shape torch.Size([2048, 3594]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.14.mlp.w1.weight: found shape torch.Size([3594, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.14.mlp.w2.weight: found shape torch.Size([3594, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.15.attn.c_attn.weight: found shape torch.Size([3072, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.15.attn.c_proj.weight: found shape torch.Size([2048, 1024]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.15.mlp.c_proj.weight: found shape torch.Size([2048, 4113]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.15.mlp.w1.weight: found shape torch.Size([4113, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.15.mlp.w2.weight: found shape torch.Size([4113, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.16.attn.c_attn.weight: found shape torch.Size([3072, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.16.attn.c_proj.weight: found shape torch.Size([2048, 1024]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.17.attn.c_attn.weight: found shape torch.Size([2688, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.17.attn.c_proj.weight: found shape torch.Size([2048, 896]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.17.mlp.c_proj.weight: found shape torch.Size([2048, 3263]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.17.mlp.w1.weight: found shape torch.Size([3263, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.17.mlp.w2.weight: found shape torch.Size([3263, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.18.mlp.c_proj.weight: found shape torch.Size([2048, 3861]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.18.mlp.w2.weight: found shape torch.Size([3861, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.18.attn.c_attn.weight: found shape torch.Size([1536, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.18.attn.c_proj.weight: found shape torch.Size([2048, 512]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.18.mlp.w1.weight: found shape torch.Size([3861, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.19.attn.c_attn.weight: found shape torch.Size([2688, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.19.attn.c_proj.weight: found shape torch.Size([2048, 896]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.20.attn.c_attn.weight: found shape torch.Size([2688, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.20.attn.c_proj.weight: found shape torch.Size([2048, 896]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.20.mlp.c_proj.weight: found shape torch.Size([2048, 3291]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.20.mlp.w1.weight: found shape torch.Size([3291, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.20.mlp.w2.weight: found shape torch.Size([3291, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.21.attn.c_attn.weight: found shape torch.Size([1536, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.21.attn.c_proj.weight: found shape torch.Size([2048, 512]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.22.attn.c_attn.weight: found shape torch.Size([3072, 2048]) in the checkpoint and torch.Size([6144, 2048]) in the model instantiated
  • transformer.h.22.attn.c_proj.weight: found shape torch.Size([2048, 1024]) in the checkpoint and torch.Size([2048, 2048]) in the model instantiated
  • transformer.h.9.mlp.c_proj.weight: found shape torch.Size([2048, 2630]) in the checkpoint and torch.Size([2048, 5504]) in the model instantiated
  • transformer.h.9.mlp.w1.weight: found shape torch.Size([2630, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
  • transformer.h.9.mlp.w2.weight: found shape torch.Size([2630, 2048]) in the checkpoint and torch.Size([5504, 2048]) in the model instantiated
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    Traceback (most recent call last):
    File "/home/ubuntu/test_scripts/benchmark_r.py", line 152, in
    main()
    File "/home/ubuntu/test_scripts/benchmark_r.py", line 63, in main
    model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True, ignore_mismatched_sizes=True)
    File "/home/ubuntu/miniconda3/envs/qwen/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 556, in from_pretrained
    return model_class.from_pretrained(
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3558, in from_pretrained
    dispatch_model(model, **device_map_kwargs)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/accelerate/big_modeling.py", line 474, in dispatch_model
    model.to(device)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2556, in to
    return super().to(*args, **kwargs)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1152, in to
    return self._apply(convert)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
    [Previous line repeated 2 more times]
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 825, in _apply
    param_applied = fn(param)
    File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1150, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
    NotImplementedError: Cannot copy out of meta tensor; no data!

How do you load the model after pruning it?
I'm eager to try FLAP and look forward to your reply. Thank you.
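
A hedged sketch of one way to reload: FLAP prunes each layer to a different width, so from_pretrained, which rebuilds the architecture from config.json, can no longer match the checkpoint shapes. Since FLAP's main.py saves the whole module with torch.save, restoring it with torch.load sidesteps the mismatch (the path below is a placeholder, and any trust_remote_code model class must be importable when unpickling).

    import torch

    # The checkpoint was written as torch.save(model, f'{args.save_model}/pruned_model.pt'),
    # so torch.load restores the module with its pruned per-layer shapes intact.
    model = torch.load('path/to/pruned_model.pt', map_location='cpu')  # placeholder path
    model.eval()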

Question about the wikitext2 data loader

Thanks for your nice work. When running the sample script, I'm getting the following error message regarding the wikitext2 loader. Would you kindly check it?

File "/ssd2/bkkim/FLAP/lib/data.py", line 81, in get_wikitext2
    traindata = load_dataset('text', data_files='datasets/wikitext/wiki.train.raw', split="train")
FileNotFoundError: Unable to find '/ssd2/bkkim/FLAP/datasets/wikitext/wiki.train.raw'

By the way, I found a workaround by using the commented-out lines in the script. Is this an appropriate solution?

# traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
# testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

Thank you for your time.
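
For what it's worth, a hedged note: the commented-out Hub loaders fetch the same WikiText-2 raw corpus that the local wiki.train.raw/wiki.test.raw files contain, so uncommenting them is a reasonable drop-in substitute when the local files are missing:

    from datasets import load_dataset

    # Hub-hosted WikiText-2 (raw); same content as the local files
    # that lib/data.py expects under datasets/wikitext/.
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')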

Request for the inference speed test script

I saw the evaluation results for inference speed before and after pruning in the paper. May I ask for a copy of the test script? I'd like to reproduce the results. Thank you!
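
Not the authors' script, but a minimal sketch of one way to time generation throughput, assuming a CUDA device and a transformers causal LM:

    import time
    import torch

    @torch.inference_mode()
    def tokens_per_second(model, tokenizer, prompt="Hello, world", new_tokens=128):
        # Greedy generation, wall-clock timed after a warm-up pass; synchronize
        # so pending CUDA kernels finish before the clock is read.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        model.generate(**inputs, max_new_tokens=8, do_sample=False)  # warm-up
        torch.cuda.synchronize()
        start = time.perf_counter()
        model.generate(**inputs, max_new_tokens=new_tokens, do_sample=False)
        torch.cuda.synchronize()
        return new_tokens / (time.perf_counter() - start)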
