tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Behaviour of stripped models and sequence masking #403

Open captainproton1971 opened 4 years ago

captainproton1971 commented 4 years ago

Describe the bug Training a sparse model that includes a Masking layer feeding an LSTM, and then stripping the model with strip_pruning(), produces a model that handles the masking differently from the original. As a result, models trained with pruning are not usable for inference after stripping.

System information

Python version:

Describe the expected behavior A model containing a prune_low_magnitude LSTM layer should produce the same output as its stripped counterpart.

Describe the current behavior A model in which a prune_low_magnitude LSTM layer is fed by a Masking layer appears to handle the mask differently than its stripped version does. Moreover, the difference is not simply the sequence mask being ignored.

Code to reproduce the issue

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Bidirectional, LSTM
from tensorflow_model_optimization.sparsity import keras as sparsity

# Create some fake masked sequences (I'm sure there's a cleaner way to do this)
pad_val = 7
samples = 10
timesteps = 9
features = 6

data = np.full((samples, timesteps, features),  pad_val, dtype=np.float32)
rng = np.random.default_rng()
num_len = rng.integers(low=3, high=timesteps, size=(samples,1))

for i in range(samples):
    seq_len = int(num_len[i])  # avoid shadowing the builtin `len`
    x = np.random.uniform(low=-1, high=1, size=(seq_len,))
    for j in range(seq_len):
        data[i, j, :] = x[j]

# Create a model incorporating a masking layer and a pruneable LSTM.
pruning_model = keras.Sequential(
    [keras.layers.Masking(mask_value=pad_val, input_shape=(timesteps, features)),
     sparsity.prune_low_magnitude(keras.layers.LSTM(2))])

# Get the output of the model, as well as just the LSTM (skipping the masking)
x1_a = pruning_model(data).numpy()  #LSTM output with masked input
x1_b = pruning_model.layers[1](data).numpy() #LSTM output skipping masked input

# Now strip the model, and check weights are the same
depruned_model = sparsity.strip_pruning(pruning_model)

for i, w in enumerate(depruned_model.layers[1].weights):
    assert np.allclose(w.numpy(), pruning_model.layers[1].weights[i].numpy())

x2_a = depruned_model(data).numpy()
x2_b = depruned_model.layers[1](data).numpy()

# Check behaviour of outputs, including the masking
print(np.allclose(x1_a, x2_a))  # Returns False: with masking included, the two models' outputs differ
print(np.allclose(x1_b, x2_b))  # Returns True: skipping the masking yields the same output from the two models
print(np.allclose(x1_a, x1_b))  # Returns False: the prunable model isn't simply ignoring the mask

# Check that the masks are the same:
m1 = pruning_model.layers[0](data)._keras_mask
m2 = depruned_model.layers[0](data)._keras_mask
np.all(m1==m2) # Returns True: the same masks are being passed to the LSTM layer.

A quick inspection shows that x1_a, x1_b, and x2_a all differ from one another. I've also confirmed that the layer configurations are identical in the two models.
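As an aside on the repro's data setup (the comment in the snippet admits there is probably a cleaner way): a numpy-only sketch that builds the same kind of post-padded batch without the inner loop, reusing the snippet's names (pad_val, samples, timesteps, features) and broadcasting one value per timestep across the feature axis:

```python
import numpy as np

pad_val = 7
samples, timesteps, features = 10, 9, 6
rng = np.random.default_rng(0)

# One random valid length per sample, then fill the valid prefix of each
# sequence by broadcasting a (seq_len, 1) column across the feature axis.
lengths = rng.integers(low=3, high=timesteps, size=samples)
data = np.full((samples, timesteps, features), pad_val, dtype=np.float32)
for i, n in enumerate(lengths):
    data[i, :n, :] = rng.uniform(low=-1, high=1, size=(int(n), 1))
```

Every timestep at or beyond a sample's length remains exactly pad_val, so the Masking layer will mask it.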

Additional context I didn't see anything in the documentation about the behaviour of masked inputs to LSTM layers, but the current behaviour (handling the mask differently from both the usual masked path and the unmasked path) seems counter-intuitive.
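For reference, here is a minimal sketch of the expected Masking + LSTM contract in plain Keras (no pruning wrapper involved): with post-padding, the stack should return the same final output for a padded sequence as for its truncated version, since the LSTM carries its state through masked timesteps. This is the baseline behaviour the pruned model appears to deviate from.

```python
import numpy as np
from tensorflow import keras

pad_val = 7.0
timesteps, features = 9, 6
rng = np.random.default_rng(0)

# A single sequence of 5 valid timesteps, then post-padded with pad_val.
seq = rng.uniform(-1, 1, size=(1, 5, features)).astype(np.float32)
padded = np.full((1, timesteps, features), pad_val, dtype=np.float32)
padded[:, :5] = seq

model = keras.Sequential([
    keras.Input(shape=(None, features)),  # variable time dimension
    keras.layers.Masking(mask_value=pad_val),
    keras.layers.LSTM(2),
])

out_padded = model(padded).numpy()  # masked steps don't advance the state
out_trunc = model(seq).numpy()      # same sequence without the padding
assert np.allclose(out_padded, out_trunc, atol=1e-6)
```

If the prune_low_magnitude wrapper were merely dropping the mask, the wrapped model's output would match the unmasked path (x1_b); the repro shows it matches neither, which is what makes the behaviour hard to reason about.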

I found this problem after training and pruning a masked LSTM model, stripping it, and finding that its output was not consistent with that of the pruned model.

Thank you in advance for any help you can offer.

captainproton1971 commented 4 years ago

Hi, just wondering if anyone has started to look at this?

alanchiao commented 4 years ago

Hi @captainproton1971. Unfortunately haven't had the time to start this and don't personally expect to for a while.

@liyunlu0618

Unassigning self since there should be an assignment only if there is active work on it.

teijeong commented 3 years ago

Hi @captainproton1971, sorry for the really late response.

Just want to check if this still bugs you before we start triage.

captainproton1971 commented 3 years ago

Hi, thanks @teijeong . I'll check this weekend to see if it's still causing problems and post a reply.