microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License
14k stars, 1.81k forks

My pruning gets killed #5674

Open marziye-A opened 1 year ago

marziye-A commented 1 year ago

@J-shang Hi, thanks for your great work. I am trying to prune a wav2vec2 model from Hugging Face, using the code below:

from transformers import Wav2Vec2ForCTC
from nni.contrib.compression.pruning import L1NormPruner
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
from nni.compression.pytorch.speedup.v2.external_replacer import TransformersAttentionReplacer
from nni.compression.pytorch.utils.external.huggingface import HuggingfaceModelParser
import librosa
import torch

model_id = "******"
model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="fas")

# Prune 50% of every Linear layer except the CTC output head.
# Wav2Vec2ForCTC calls its head `lm_head`; 'pre_classifier'/'classifier' are DistilBERT names.
config_list = [{
    'op_types': ['Linear'],
    'exclude_op_names': ['lm_head'],
    'sparse_ratio': 0.5
}]

pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner.unwrap_model()

# Tells the replacer where the attention and feed-forward blocks live in Wav2Vec2.
# Module names follow transformers' Wav2Vec2 implementation (the speedup log shows
# e.g. wav2vec2_encoder_layers_38_attention_q_proj), not DistilBERT's q_lin / ffn.lin1.
class HuggingfaceWav2Vec2ForCTCParser(HuggingfaceModelParser):
    TRANSFORMER_PREFIX = r'wav2vec2\.encoder\.layers\.[0-9]+\.'  # [0-9]+ matches any layer index
    QKV = ('attention.q_proj', 'attention.k_proj', 'attention.v_proj')
    QKVO = QKV + ('attention.out_proj',)
    FFN1 = ('feed_forward.intermediate_dense',)
    FFN2 = ('feed_forward.output_dense',)
    ATTENTION = ('attention',)

replacer = TransformersAttentionReplacer(model, HuggingfaceWav2Vec2ForCTCParser)

# wav2vec2 checkpoints expect 16 kHz audio; librosa resamples to 22050 Hz unless sr is given.
y, sr = librosa.load('record(3).wav', sr=16000, mono=True)
dummy_input = torch.unsqueeze(torch.from_numpy(y), 0)  # shape: (1, num_samples)

ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()
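A quick way to sanity-check a parser's TRANSFORMER_PREFIX is to test it against a module name taken from the speedup log. The log node `wav2vec2_encoder_layers_38_attention_q_proj` corresponds to the dotted module path `wav2vec2.encoder.layers.38.attention.q_proj`, assuming the usual torch.fx convention of replacing dots with underscores in node names. The minimal sketch below uses only `re`; it also shows that `[0-47]+` is a character class (the single characters 0 to 4 and 7), not the numeric range 0 to 47, so it can never match a layer index such as 38. Whether NNI applies the pattern with `re.match` or some other call is an assumption here.

```python
import re

# Module path reconstructed from the log line
# "call_module: wav2vec2_encoder_layers_38_attention_q_proj".
name = "wav2vec2.encoder.layers.38.attention.q_proj"

# Prefix built from class names plus a character class (as in the original parser).
class_based = r"Wav2Vec2ForCTC\.layers\.Wav2Vec2Attention\.[0-47]+\."
# Prefix built from module attribute names; [0-9]+ accepts any layer index.
attribute_based = r"wav2vec2\.encoder\.layers\.[0-9]+\."

# "[0-47]" is the set {0, 1, 2, 3, 4, 7}: it accepts "47" but rejects "38".
print(bool(re.fullmatch(r"[0-47]+", "47")))
print(bool(re.fullmatch(r"[0-47]+", "38")))

print(bool(re.match(class_based, name)))
print(bool(re.match(attribute_based, name)))

# What remains after the prefix is what the QKV / FFN tuples have to name.
m = re.match(attribute_based, name)
print(name[m.end():])
```

The same check can be repeated for any other node name in the log before running the (slow) speedup pass.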

But the process gets killed because it runs out of memory. Here is part of my output:

.

.0:492: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  ne_15 = size_32 != (mul_31, getitem_27, 80); size_32 = mul_31 = None
.
.
[2023-09-04 13:04:03] Propagate variables for call_function: getitem_121
[2023-09-04 13:04:03] Propagate variables for call_module: wav2vec2_encoder_layers_38_attention_q_proj
[2023-09-04 13:04:03] Propagate variables for call_function: mul_152
[2023-09-04 13:04:03] Propagate variables for call_module: wav2vec2_encoder_layers_38_attention_k_proj
[2023-09-04 13:04:03] Propagate variables for call_method: view_190
[2023-09-04 13:04:03] Propagate variables for call_method: transpose_207
[2023-09-04 13:04:03] Propagate variables for call_method: contiguous_114
[2023-09-04 13:04:03] Propagate variables for call_module: wav2vec2_encoder_layers_38_attention_v_proj
[2023-09-04 13:04:03] Propagate variables for call_method: view_191
[2023-09-04 13:04:03] Propagate variables for call_method: transpose_208
[2023-09-04 13:04:03] Propagate variables for call_method: contiguous_115
[2023-09-04 13:04:03] Propagate variables for call_function: mul_153
[2023-09-04 13:04:03] Propagate variables for call_method: view_192
[2023-09-04 13:04:03] Propagate variables for call_method: transpose_209
[2023-09-04 13:04:03] Propagate variables for call_method: contiguous_116
[2023-09-04 13:04:03] Propagate variables for call_method: view_193
[2023-09-04 13:04:03] Propagate variables for call_method: reshape_114
[2023-09-04 13:04:04] Propagate variables for call_method: reshape_115
Killed

My model is 3.86 GB, but pruning uses around 100 GB of RAM. What is the problem? Do you know the reason? Any help is really appreciated!
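One rough way to reason about the memory is to estimate how large individual activations get for a given dummy input. The sketch below reproduces the output-length arithmetic of the standard wav2vec2 convolutional feature encoder (kernel sizes 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2, the Hugging Face `Wav2Vec2Config` defaults, no padding) and then sizes one fp32 attention-score tensor and one hidden-state tensor. The defaults of 16 heads and hidden size 1280 are assumptions typical of 1B-parameter wav2vec2 variants; the actual `model_id` is elided above, so substitute the checkpoint's own config values. The helper names are hypothetical, and this counts single tensors only; it does not model what `ModelSpeedup` keeps alive per traced node.

```python
# Hypothetical helpers: estimate encoder sequence length and per-tensor sizes
# for a wav2vec2-style model. Kernel/stride values are the Hugging Face
# Wav2Vec2Config defaults; num_heads and hidden_size defaults are assumptions.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_frames(num_samples: int) -> int:
    """Output length of the conv feature encoder (unpadded Conv1d stack)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length


def attention_scores_bytes(frames: int, num_heads: int = 16,
                           batch: int = 1, bytes_per_elem: int = 4) -> int:
    """Size of one (batch, heads, T, T) attention-score tensor."""
    return batch * num_heads * frames * frames * bytes_per_elem


def hidden_states_bytes(frames: int, hidden_size: int = 1280,
                        batch: int = 1, bytes_per_elem: int = 4) -> int:
    """Size of one (batch, T, hidden) hidden-state tensor."""
    return batch * frames * hidden_size * bytes_per_elem


for seconds, rate in [(1, 16_000), (1, 22_050), (30, 16_000)]:
    t = num_frames(seconds * rate)
    print(f"{seconds:>3} s @ {rate} Hz -> {t} frames, "
          f"attention {attention_scores_bytes(t) / 1e6:.1f} MB, "
          f"hidden {hidden_states_bytes(t) / 1e6:.1f} MB per tensor")
```

Attention scores grow with the square of the frame count, and loading at 22050 Hz instead of 16 kHz yields more frames for the same clip, so the length and sample rate of the dummy input both feed into how much a full trace has to hold.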
gemsanyu commented 4 months ago

I have the same issue: my speedup process gets killed, probably due to OOM.
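To confirm that a kill like this is memory-related, one option is to run the pruning/speedup script as a child process and inspect how it exited and how much memory it peaked at. A minimal POSIX-only sketch using the standard library's `subprocess` and `resource` modules; the helper name and the script name in the comment are hypothetical. On POSIX a negative return code -N means the child was terminated by signal N, and the Linux OOM killer uses SIGKILL (9), which is also what makes the shell print "Killed".

```python
import resource
import subprocess
import sys


def run_and_report(cmd: list[str]) -> tuple[int, int]:
    """Run cmd as a child process; return (returncode, peak child RSS).

    ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    For RUSAGE_CHILDREN it reflects the largest terminated, waited-for child.
    """
    completed = subprocess.run(cmd)
    usage = resource.getrusage(resource.RUSAGE_CHILDREN)
    return completed.returncode, usage.ru_maxrss


if __name__ == "__main__":
    # Placeholder workload: allocate ~50 MB in a child interpreter.
    # Replace with e.g. [sys.executable, "prune_wav2vec2.py"] (hypothetical script).
    code, peak = run_and_report([sys.executable, "-c", "x = bytearray(50_000_000)"])
    if code == -9:
        print("child was killed by SIGKILL (consistent with the OOM killer)")
    print("returncode:", code, "peak RSS (ru_maxrss):", peak)
```

Running the same script with dummy inputs of different lengths and comparing the reported peaks is a cheap way to see whether memory tracks input size.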