EIDOSLAB / simplify

Simplification of pruned models for accelerated inference | SoftwareX
https://doi.org/10.1016/j.softx.2021.100907
BSD 3-Clause "New" or "Revised" License

Are transformer layers supported? #6

Closed OriAlpha closed 2 years ago

OriAlpha commented 2 years ago

I am trying to pass a transformer model, but I run into an error while doing so:

simplified_model = simplify(model, dummy_input, fuse_bn=False)

I get an AssertionError. Are transformer layers supported at the moment?

AndreaBrg commented 2 years ago

Hi, could you please provide a minimal reproducible example? And, if possible, the whole error you encounter.

OriAlpha commented 2 years ago

I would be happy to create an example, which should help improve the library.

OriAlpha commented 2 years ago

Please refer to this example. I am using the gelectra model; you can get it from https://huggingface.co/deepset/gelectra-base/tree/main, or you can use a small version of BERT.

Steps: first, load the model:

from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("./gelectra")
model = AutoModelForSequenceClassification.from_pretrained("./gelectra")

then you can call the PyTorch pruning utilities:

import torch as t
import torch.nn.utils.prune as prune

for name, module in model.named_modules():
    # prune 20% of connections in all embedding layers
    if isinstance(module, t.nn.Embedding):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 10% of connections in all linear layers
    elif isinstance(module, t.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.1)
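As background on what that call does (a minimal sketch on a toy layer, not the gelectra model): `prune.l1_unstructured` does not delete anything, it reparameterizes the pruned tensor.

```python
import torch
import torch.nn.utils.prune as prune

lin = torch.nn.Linear(16, 8)
prune.l1_unstructured(lin, name='weight', amount=0.1)

# After pruning, the original values live in `weight_orig`, a binary
# `weight_mask` is registered, and `weight` is recomputed on the fly
# as weight_orig * weight_mask.
print(hasattr(lin, 'weight_orig'))   # True
print(hasattr(lin, 'weight_mask'))   # True
```

This is why the `prune.remove` loop below is needed: it folds the mask back in so `weight` becomes a plain tensor again.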

Then remove the pruning reparameterization so the zeroed weights become permanent:

for name, module in model.named_modules():
    if isinstance(module, t.nn.Embedding):
        prune.remove(module, 'weight')
    elif isinstance(module, t.nn.Linear):
        prune.remove(module, 'weight')

which results in zeroing some weights inside the model. Now comes the simplify step, and this is where loading the model fails. If you run into any issues, I can assist you.
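The zeroed weights can be verified with a quick sparsity check (again a sketch on a toy Linear layer, not the gelectra model):

```python
import torch
import torch.nn.utils.prune as prune

lin = torch.nn.Linear(16, 8)
prune.l1_unstructured(lin, name='weight', amount=0.1)
prune.remove(lin, 'weight')

# Fraction of exactly-zero entries in the (now plain) weight tensor;
# it should sit close to the requested pruning amount.
sparsity = float((lin.weight == 0).float().mean())
print(f"sparsity: {sparsity:.2f}")  # roughly 0.10
```

Note that this only zeroes entries; the tensors keep their original shapes, which is exactly the situation simplify is meant to exploit.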

AndreaBrg commented 2 years ago

Ok, thank you. We will get back to you asap.

AndreaBrg commented 2 years ago

@OriAlpha I'm not really familiar with NLP models; could you please give me an example of the dummy_input you are using?

OriAlpha commented 2 years ago

This is where it gets confusing: passing the data. Usually you can pass inputs as below:

input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

out = model(input_ids)

simplified_model = simplify(model, input_ids)  # fails here

AndreaBrg commented 2 years ago

@OriAlpha Ok, so from what I could find there are a couple of problems with this model:

I currently wouldn't know how to solve these issues, but you are welcome to propose a PR in the meantime.

OriAlpha commented 2 years ago

Thanks for looking into this. While testing, I also came across this issue; you could try passing:

x = torch.LongTensor(1, 512).random_(0, 2^53)  # note: in Python, ^ is XOR, so this bound is 55, not 2**53
model(x)
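A side note on that upper bound (an observation, not from the original thread): in Python, `^` is bitwise XOR, not exponentiation, so `2^53` is a small integer rather than two to the 53rd power:

```python
# ^ is bitwise XOR; ** is exponentiation.
xor_result = 2 ^ 53    # 0b10 XOR 0b110101 == 0b110111
pow_result = 2 ** 53

print(xor_result)  # 55
print(pow_result)  # 9007199254740992
```

That XOR accident is arguably harmless here, since dummy token ids must stay below the model's vocabulary size anyway, and 55 does.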

I will look into this TorchScript issue.