google-research / jaxpruner


Request for Optimal Brain Surgeon -- SparseGPT #9

Open · opooladz opened 2 months ago

opooladz commented 2 months ago

Hi and thanks for the amazing repo.

I have a bit of a tall request. SparseGPT uses a per-layer Optimal Brain Surgeon (OBS) approach to pruning. Here is the PyTorch code.

Having this in JAX would really help push the boundaries of what we can do.
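
To make the ask concrete, here is a rough JAX sketch of the per-layer OBS update that SparseGPT builds on, using the usual H = XᵀX Hessian approximation over calibration inputs. The function name, the per-column mask selection, and the damping are my simplifications, not the jaxpruner API or the exact SparseGPT algorithm:

```python
import jax.numpy as jnp


def obs_prune_layer(W, X, sparsity=0.5, damp=1e-2):
    """Per-layer OBS pruning sketch in the spirit of SparseGPT.

    W: (out_features, in_features) dense weights.
    X: (n_samples, in_features) calibration inputs for this layer.
    Returns a pruned copy of W. Hypothetical helper, not jaxpruner API.
    """
    W = jnp.asarray(W)
    n_rows, n_cols = W.shape
    # Layer Hessian proxy H = X^T X, damped for numerical stability.
    H = X.T @ X
    H = H + damp * jnp.mean(jnp.diag(H)) * jnp.eye(n_cols)
    # Upper Cholesky factor of H^{-1}, as in the SparseGPT recurrence.
    Hinv_u = jnp.linalg.cholesky(jnp.linalg.inv(H)).T
    # Entries pruned per column (simplified: SparseGPT chooses masks
    # over blocks of columns rather than one column at a time).
    k = int(sparsity * n_rows)
    for i in range(n_cols):
        d = Hinv_u[i, i]
        col = W[:, i]
        # OBS saliency: squared weight over squared inverse-Hessian diag.
        salience = (col / d) ** 2
        thresh = jnp.sort(salience)[k]
        pruned = jnp.where(salience >= thresh, col, 0.0)
        # Fold the pruning error into the not-yet-processed columns;
        # since Hinv_u[i, i] == d, this also writes `pruned` into col i.
        err = (col - pruned) / d
        W = W.at[:, i:].add(-jnp.outer(err, Hinv_u[i, i:]))
    return W
```

The real algorithm adds blockwise mask selection and lazy batched updates for speed, but the column-by-column error compensation above is the core of it.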

Thank you, Omead

evcu commented 2 months ago

Hi, Elias did implement GPTQ in JAX, and I plan to open-source it sometime soon. I'll keep this issue open for now and update it when I have the chance.

opooladz commented 2 months ago

GPTQ is for quantization, whereas layerwise OBS/SparseGPT is for pruning, unless I'm missing something. I know Elias is involved in both works, so maybe he has SparseGPT in JAX as well, which would be nice.

There is also Hessian-Aware Pruning (https://openaccess.thecvf.com/content/WACV2022/papers/Yu_Hessian-Aware_Pruning_and_Optimal_Neural_Implant_WACV_2022_paper.pdf), which would be nice to have for LLMs.

Another nice feature would be letting the user specify what percentage of weights they want pruned and then pruning directly to that target (see the sketch below). This would mean the library handling sharding etc., but it would probably make things easier on the user.
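
For that last point, here is a minimal sketch of what "prune directly to a target percentage" could look like over a parameter pytree. The `magnitude_prune` helper is hypothetical, not the jaxpruner API, and all the sharding concerns mentioned above are left out:

```python
import jax
import jax.numpy as jnp


def magnitude_prune(params, fraction=0.5):
    """One-shot global magnitude pruning to a user-chosen sparsity.

    `params` is any pytree of weight arrays. Hypothetical helper, not
    the jaxpruner API; sharded arrays would need extra care.
    """
    mags = jnp.concatenate(
        [jnp.abs(x).ravel() for x in jax.tree_util.tree_leaves(params)])
    # Global threshold: the `fraction`-quantile of all weight magnitudes.
    thresh = jnp.sort(mags)[int(fraction * mags.size)]
    return jax.tree_util.tree_map(
        lambda w: jnp.where(jnp.abs(w) >= thresh, w, 0.0), params)
```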