CASIA-IVA-Lab / FLAP
[AAAI 2024] Fluctuation-based Adaptive Structured Pruning for Large Language Models
Home Page: https://arxiv.org/abs/2312.11983
License: Apache License 2.0
Traceback (most recent call last):
File "/home/shwu/LABS/FLAP/main.py", line 147, in <module>
main()
File "/home/shwu/LABS/FLAP/main.py", line 141, in main
torch.save(model, f'{args.save_model}/pruned_model.pt')
File "/home/shwu/LABS/FLAP/venv/lib/python3.10/site-packages/torch/serialization.py", line 629, in save
_save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
File "/home/shwu/LABS/FLAP/venv/lib/python3.10/site-packages/torch/serialization.py", line 841, in _save
pickler.dump(obj)
AttributeError: Can't pickle local object 'add_hook_to_module.<locals>.new_forward'
This occurs after pruning on multiple GPUs and saving the model with torch.save.
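For context: torch.save pickles the whole model object, and (if I understand correctly) multi-GPU dispatch attaches a hook that replaces each module's forward with a function defined inside another function, which pickle cannot serialize. A minimal pure-Python reproduction of the failure mode, with a workaround sketch; the Module and add_hook_to_module names here are stand-ins, not accelerate's real internals:

```python
import pickle

class Module:
    """Stand-in for an nn.Module that a dispatch hook has been attached to."""
    def forward(self):
        return "original"

def add_hook_to_module(module):
    old_forward = module.forward
    def new_forward():  # local closure, like a hooked forward
        return "hooked " + old_forward()
    module.forward = new_forward  # instance attribute -> object is now unpicklable
    return module

m = add_hook_to_module(Module())
try:
    pickle.dumps(m)
except (AttributeError, pickle.PicklingError) as e:
    print("pickle failed:", e)

# Workaround sketch: drop the hooked attribute before saving, so the
# class-level forward applies again and the object pickles cleanly.
del m.__dict__["forward"]
data = pickle.dumps(m)
print("pickle ok after removing hook:", len(data) > 0)
```

In practice, saving model.state_dict() instead of the full model object sidesteps the error entirely; removing the dispatch hooks before torch.save is the other common route.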
Are there any parallel efforts ongoing to add support for other models in FLAP?
Models such as Phi, Gemma, Mistral, etc.
If so, what is the timeline for this?
check_sparsity returns the ratio of the count of non-zero weights to the total parameters in the model. How is this non-zero ratio calculated?
Line 43 in 3bb57db
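I don't have the exact implementation in front of me, but a check of this kind usually reduces to counting non-zero entries; a minimal pure-Python sketch (hypothetical, not FLAP's actual code) would be:

```python
def nonzero_ratio(weight_matrices):
    """Ratio of non-zero weights to total parameters.
    `weight_matrices` is a list of flat lists standing in for layer weights."""
    total = sum(len(w) for w in weight_matrices)
    nonzero = sum(1 for w in weight_matrices for v in w if v != 0)
    return nonzero / total

print(nonzero_ratio([[1.0, 0.0, 2.0], [0.0, 0.0, 3.0]]))  # 0.5
```

Sparsity in the usual sense would then be 1 minus this ratio.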
Dear authors,
I would like to express my deep appreciation for the significant contribution your work has made to our community. It has been an invaluable resource for me.
I have encountered a question regarding the calibration set you chose while trying to understand the code. I hope you can offer some suggestions.
In your implementation, when pruning LLMs with FLAP, the calibration set is WikiText-2, as shown in this line.
If I understand correctly, comparing the PPL on WikiText-2 with previous work might be misleading, as previous pruning methods did not use WikiText-2 for calibration. Could you please confirm this?
Can you specify the calibration set you used for the results reported in your paper? I cannot find this in the paper.
Thanks for your attention, and I hope to hear from you!
@an-yongqi
Dear An,
Congratulations on the great job!
I am unauthorized for the URL https://huggingface.co/decapoda-research/llama-7b-h/resolve/main/config.json. When loading the LLM model decapoda-research/llama-7b-h I get:
/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Traceback (most recent call last):
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
response.raise_for_status()
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/requests/models.py", line 1021, in raise_for_status
raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 401 Client Error: Unauthorized for url: https://huggingface.co/decapoda-research/llama-7b-h/resolve/main/config.json
So I use baffo32/decapoda-research-llama-7B-hf instead, but there are errors:
pruning starts
loading calibdation data
Traceback (most recent call last):
File "/home/delight-gpu/Workspace2/azuryl/FLAP/main.py", line 109, in <module>
main()
File "/home/delight-gpu/Workspace2/azuryl/FLAP/main.py", line 82, in main
prune_flap(args, model, tokenizer, device)
File "/home/delight-gpu/Workspace2/azuryl/FLAP/lib/prune.py", line 294, in prune_flap
dataloader, _ = get_loaders("wikitext2", nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
File "/home/delight-gpu/Workspace2/azuryl/FLAP/lib/data.py", line 159, in get_loaders
return get_wikitext2(nsamples, seed, seqlen, tokenizer)
File "/home/delight-gpu/Workspace2/azuryl/FLAP/lib/data.py", line 79, in get_wikitext2
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 1767, in load_dataset
builder_instance = load_dataset_builder(
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 1498, in load_dataset_builder
dataset_module = dataset_module_factory(
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 1215, in dataset_module_factory
raise e1 from None
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 1192, in dataset_module_factory
return HubDatasetModuleFactoryWithoutScript(
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/load.py", line 765, in get_module
else get_data_patterns_in_dataset_repository(hfh_dataset_info, self.data_dir)
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/data_files.py", line 675, in get_data_patterns_in_dataset_repository
return _get_data_files_patterns(resolver)
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/data_files.py", line 236, in _get_data_files_patterns
data_files = pattern_resolver(pattern)
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/datasets/data_files.py", line 486, in _resolve_single_pattern_in_dataset_repository
glob_iter = [PurePath(filepath) for filepath in fs.glob(PurePath(pattern).as_posix()) if fs.isfile(filepath)]
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/fsspec/spec.py", line 606, in glob
pattern = glob_translate(path + ("/" if ends_with_sep else ""))
File "/home/azuryl/anaconda3/envs/flap/lib/python3.9/site-packages/fsspec/utils.py", line 734, in glob_translate
raise ValueError(
ValueError: Invalid pattern: '**' can only be an entire path component
Can you help me? Thank you!
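For what it's worth, this ValueError is commonly reported when an older datasets release clashes with a newer fsspec (this is an assumption based on similar reports elsewhere, not something confirmed by the FLAP authors). Either upgrading datasets or pinning fsspec usually resolves it:

```shell
# Either upgrade datasets to a release compatible with recent fsspec ...
pip install -U datasets
# ... or pin fsspec to a version the older datasets release understands.
pip install "fsspec<=2023.9.2"
```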
Thanks for your inspiring work!
I have a small question about the W_metric for self_attn.o_proj in prune_flap(): there is a square operation, while the W_metric for mlp.down_proj has none.
for name in subset:
    if name == 'self_attn.o_proj':
        W_metric = metrics[args.metrics](wrapped_layers, subset, name) ** 2  # square is needed
        if args.structure == "UL-UM":
            W_metric = W_metric.reshape(-1, 128).sum(dim=1)
            thresh = torch.sort(W_metric.cuda())[0][int(args.pruning_ratio*layer.self_attn.num_heads)].cpu()
            W_mask = (W_metric >= thresh)
            attn_mask.append(W_mask)
        elif args.structure == "UL-MM":
            W_metric = W_metric.reshape(-1, 128).sum(dim=1)
            thresh = torch.sort(W_metric.cuda())[0][args.remove_heads // len(layers)].cpu()
            W_mask = (W_metric >= thresh)
            attn_mask.append(W_mask)
        else:
            attn_metric_list.append(W_metric.cpu())
            attn_baseline_inp_list.append(wrapped_layers[name].baseline_inp.type(torch.half))
    else:
        W_metric = metrics[args.metrics](wrapped_layers, subset, name)  # no square
        if args.structure == "UL-UM":
            thresh = torch.sort(W_metric.cuda())[0][int(W_metric.numel()*args.pruning_ratio)].cpu()
            W_mask = (W_metric >= thresh)
            mlp_mask.append(W_mask)
        elif args.structure == "UL-MM":
            thresh = torch.sort(W_metric.cuda())[0][cal_remove_neuron(args, model)].cpu()
            W_mask = (W_metric >= thresh)
            mlp_mask.append(W_mask)
        else:
            mlp_metric_list.append(W_metric.cpu())
            mlp_baseline_inp_list.append(wrapped_layers[name].baseline_inp.type(torch.half))
    wrapped_layers[name].free()
I'm really confused. Could you help me out?
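As an illustration of why the square can matter (a hypothetical toy example, not the authors' stated rationale): when a per-channel metric is summed into per-head scores in groups of 128 (shrunk to 4 below for brevity), squaring first emphasizes heads with a few high-fluctuation channels and can flip which head ranks lowest.

```python
# Two heads with four channels each (head_dim shrunk from 128 to 4 for brevity).
head_a = [1.0, 1.0, 1.0, 1.0]   # uniform metric across channels
head_b = [0.0, 0.0, 0.0, 3.0]   # one outlier channel, smaller total

plain_a, plain_b = sum(head_a), sum(head_b)
sq_a = sum(x ** 2 for x in head_a)
sq_b = sum(x ** 2 for x in head_b)

print(plain_a, plain_b)  # 4.0 3.0 -> without the square, head_b scores lower
print(sq_a, sq_b)        # 4.0 9.0 -> with the square, head_a scores lower
```

So the squared metric prunes the uniformly mediocre head and keeps the one with an outlier channel, while the plain sum does the opposite.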
Hi, thank you for sharing your impressive work.
I think there might be a need to modify the below line [link].
print(f"model parameter {sum(p.numel() for p in model.parameters()) / 1024 ** 3:.2f}B")
Instead of dividing the number of parameters by 1024 ** 3 to calculate the parameters in billions, it might be more accurate to use 1000 ** 3. Specifically, when I checked the number of parameters of several models, the counts printed with 1024 ** 3 did not match the models' advertised sizes.
I would appreciate it if you could share your opinion. Thank you for your time and consideration.
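To make the discrepancy concrete (the parameter count below is the commonly cited figure for LLaMA-7B, used here only for illustration):

```python
n_params = 6_738_415_616  # commonly cited LLaMA-7B parameter count

print(f"{n_params / 1000**3:.2f}B")  # 6.74B -- decimal billions, matches "7B"
print(f"{n_params / 1024**3:.2f}B")  # 6.28B -- 1024**3 undercounts by ~7%
```

1024 ** 3 is the GiB convention for bytes of memory; model sizes are conventionally quoted in decimal billions of parameters.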
I would like to know how much data was used in the pruning process. Is it, as in the example code where nsamples=1024, that only 1024 calibration samples were used to determine the pruning results?
I tried using
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
but it raises an error:
Traceback (most recent call last):
File "/home/ubuntu/test_scripts/benchmark_r.py", line 154, in <module>
main()
File "/home/ubuntu/test_scripts/benchmark_r.py", line 63, in main
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True,
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 556, in from_pretrained
return model_class.from_pretrained(
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3502, in from_pretrained
) = cls._load_pretrained_model(
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3926, in _load_pretrained_model
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/transformers/modeling_utils.py", line 805, in _load_state_dict_into_meta_model
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
File "/home/ubuntu/miniconda3/envs/xxx/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 348, in set_module_tensor_to_device
raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([2048, 2785]) in "weight" (which has shape torch.Size([2048, 5504])), this look incorrect.
I also tried adding the parameter ignore_mismatched_sizes=True when loading:
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True, ignore_mismatched_sizes=True)
but it also raises an error:
Some weights of QWenLMHeadModel were not initialized from the model checkpoint at /data/xxxx and are newly initialized because the shapes did not match:
How do you load the model after pruning it?
I'm eager to try FLAP and look forward to your reply. Thank you.
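A sketch of the likely cause and a workaround (an inference from the traceback, not the authors' confirmed answer): from_pretrained rebuilds the architecture from the original config, so the pruned checkpoint's shrunken weight shapes no longer fit. Since FLAP saves the entire pruned model object with torch.save, loading it back with torch.load bypasses the config-driven shape check. A self-contained toy version:

```python
import os
import tempfile

import torch

# A stand-in "pruned" layer whose shape no longer matches the original config
# (2048 x 5504 shrunk to 2048 x 2785, as in the error message above).
pruned = torch.nn.Linear(2785, 2048)

path = os.path.join(tempfile.mkdtemp(), "pruned_model.pt")
torch.save(pruned, path)  # saves the full object, architecture included

# weights_only=False is required on torch>=2.6 to unpickle a full module.
loaded = torch.load(path, map_location="cpu", weights_only=False)
print(tuple(loaded.weight.shape))  # (2048, 2785)
```

The same pattern applied to the real checkpoint would be torch.load on the saved pruned_model.pt rather than AutoModelForCausalLM.from_pretrained on its directory.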
Facing ModuleNotFoundError: No module named 'models' during inference.
Thanks for your nice work. When running the sample script, I'm getting the following error message regarding the wikitext2 loader. Would you kindly check it?
File "/ssd2/bkkim/FLAP/lib/data.py", line 81, in get_wikitext2
traindata = load_dataset('text', data_files='datasets/wikitext/wiki.train.raw', split="train")
FileNotFoundError: Unable to find '/ssd2/bkkim/FLAP/datasets/wikitext/wiki.train.raw'
By the way, I found a workaround for the issue by using the commented lines in the script. Is this an appropriate solution?
# traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
# testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
Thank you for your time.