opooladz opened this issue 2 months ago
Hi, Elias did implement GPTQ in JAX, and I plan to open-source it sometime soon. I'll keep this issue open for now and update it when I have the chance.
GPTQ is for quantization, whereas layerwise OBS/SparseGPT is for pruning, unless I'm missing something. I know Elias is on both works, so maybe he has SparseGPT in JAX as well, which would be nice.
There is also Hessian-Aware Pruning, https://openaccess.thecvf.com/content/WACV2022/papers/Yu_Hessian-Aware_Pruning_and_Optimal_Neural_Implant_WACV_2022_paper.pdf
which would be nice to have for LLMs. Another nice feature would be letting the user specify what percentage of weights they want pruned and then pruning directly. This would mean taking care of sharding etc., but it would probably make things easier for the user.
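To illustrate the "specify a percentage and prune directly" idea, here is a minimal NumPy sketch of one-shot magnitude pruning to a target sparsity. The function name and interface are hypothetical (not from any existing library), and it ignores sharding entirely; a real JAX implementation would have to handle that.

```python
import numpy as np

def prune_to_sparsity(w, sparsity):
    """Zero out the smallest-magnitude fraction `sparsity` of weights.

    Hypothetical helper sketching the requested "prune X% directly" API.
    Ties at the threshold may prune slightly more than requested.
    """
    flat = np.abs(w).ravel()
    k = int(sparsity * flat.size)  # number of weights to remove
    if k == 0:
        return w.copy()
    # k-th smallest magnitude becomes the pruning threshold
    threshold = np.partition(flat, k - 1)[k - 1]
    mask = np.abs(w) > threshold
    return w * mask

w = np.array([[0.1, -2.0, 0.3],
              [1.5, -0.05, 0.7]])
pruned = prune_to_sparsity(w, 0.5)  # half the weights zeroed
```

The same logic ports to JAX almost verbatim (`jnp.partition`, `jnp.abs`), which is part of why a user-facing sparsity knob seems feasible.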
Hi, and thanks for the amazing repo.
I have a bit of a tall request: SparseGPT uses a per-layer Optimal Brain Surgeon (OBS) approach to pruning. Here is the PyTorch code.
Having this in JAX would really help push the boundaries of what we can do.
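For context, the per-layer OBS step behind SparseGPT can be sketched in a few lines of NumPy. This is a simplified one-weight-at-a-time version for a single weight row, not the blocked algorithm from the SparseGPT paper: given the inverse of the layer-wise Hessian, it repeatedly removes the weight with the smallest saliency and compensates the remaining weights.

```python
import numpy as np

def obs_prune_row(w, hinv, n_prune):
    """Optimal Brain Surgeon pruning of one weight row (simplified sketch).

    `hinv` is the inverse of the layer-wise Hessian (H = 2 X X^T in
    SparseGPT). Removes `n_prune` weights one at a time, updating the
    survivors to compensate for each removal.
    """
    w = w.astype(float).copy()
    hinv = hinv.astype(float).copy()
    pruned = np.zeros(w.shape, dtype=bool)
    for _ in range(n_prune):
        # saliency of removing weight q: w_q^2 / [H^-1]_qq
        diag = np.diag(hinv).copy()
        diag[pruned] = 1.0  # guard: these entries are masked out below
        scores = w ** 2 / diag
        scores[pruned] = np.inf  # never re-select a removed weight
        q = int(np.argmin(scores))
        # OBS update: shift the remaining weights to absorb the error
        w -= (w[q] / hinv[q, q]) * hinv[:, q]
        w[q] = 0.0
        pruned[q] = True
        # drop q from the inverse Hessian (Gaussian elimination step)
        hinv -= np.outer(hinv[:, q], hinv[q, :]) / hinv[q, q]
    return w
```

With an identity Hessian this degenerates to magnitude pruning with no compensation; the interesting behavior appears with correlated inputs. A JAX port would mainly need to replace the in-place updates with functional ones (`at[...].set`).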
Thank you, Omead