tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0
1.49k stars 319 forks source link

CQAT fails to preserve clusters on ResNet-50 #1056

Open funkyyyyyy opened 1 year ago

funkyyyyyy commented 1 year ago

Prior to filing: check that this should be a bug instead of a feature request. Everything supported, including the compatible versions of TensorFlow, is listed in the overview page of each technique. For example, the overview page of quantization-aware training is here. An issue for anything not supported should be a feature request.

Describe the bug: CQAT does not preserve clusters when training ResNet-50 on CIFAR-100.

System information

TensorFlow version (installed from source or binary): TensorFlow 2.5

TensorFlow Model Optimization version (installed from source or binary): 0.7.3

Python version: 3.7.13

Describe the expected behavior: Model weight clusters are preserved after cluster-preserving quantization-aware training (CQAT).

Describe the current behavior: Model weight clusters are not preserved for some of the kernels after cluster-preserving quantization-aware training.

Code to reproduce the issue: Provide reproducible code that is the bare minimum necessary to generate the problem.

import tempfile
import os

import tensorflow as tf
import numpy as np

from tensorflow import keras
from tensorflow.keras import datasets
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import zipfile

# Carve the validation set out of CIFAR-100's 'train' split with tfds split
# slicing. Splitting a single dataset with take()/skip() is not guaranteed to
# yield disjoint sets when it was loaded with shuffle_files=True, because the
# file read order can differ between the two iterations.
VAL_PERCENT = 10

(ds_train, ds_val, ds_test), ds_info = tfds.load(
    'cifar100',
    split=[f'train[{VAL_PERCENT}%:]', f'train[:{VAL_PERCENT}%]', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

batch_size = 32

# Example counts of the two slices (used below for shuffle buffer sizes).
num_val = ds_info.splits['train'].num_examples * VAL_PERCENT // 100
num_train = ds_info.splits['train'].num_examples - num_val

def normalize_img(image, label):
    """Convert a `uint8` image to `float32` (scaled to [0, 1]); pass the label through."""
    float_image = tf.image.convert_image_dtype(image, tf.float32)
    return float_image, label

def augment_img(image, label):
    """Randomly augment one training image; the label is returned unchanged.

    Applies, in order:
      * a random horizontal flip,
      * with probability 0.5, a random brightness shift of up to +/-0.1,
        clipped back into [0, 1],
      * with probability 0.5, a random 90% crop resized back to 32x32.

    Assumes `image` is a 32x32x3 float32 tensor in [0, 1] (CIFAR-100 after
    `normalize_img`); the hard-coded crop/resize sizes depend on that.
    """
    image = tf.image.random_flip_left_right(image)
    # One uniform draw per optional augmentation. tf.data traces this function
    # with AutoGraph, which turns these tensor-valued `if`s into graph conds.
    rand = tf.random.uniform([2], minval=0, maxval=1)
    if rand[0] > 0.5:
        image = tf.image.random_brightness(image, 0.1)
        # random_brightness only adds a delta, so pixels can leave [0, 1].
        image = tf.clip_by_value(image, 0.0, 1.0)
    if rand[1] > 0.5:
        crop_factor = 0.9
        crop_size = int(32 * crop_factor)
        image = tf.image.random_crop(image, (crop_size, crop_size, 3))
        image = tf.image.resize(image, (32, 32))

    return image, label

def resize_img(image, label):
    """Upscale `image` to 224x224 (the model input size used below); keep `label` as-is."""
    return tf.image.resize(image, (224, 224)), label

# Training pipeline. Caching and shuffling happen on the small 32x32x3 float32
# images (~12 KB each) *before* resize_img upscales them to 224x224x3
# (~600 KB each); shuffling after the resize would need a ~27 GB buffer for
# num_train elements. Augmentation is per-element random, so applying it after
# the shuffle does not change the data distribution.
ds_train = (
    ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(num_train)
    .map(augment_img, num_parallel_calls=tf.data.AUTOTUNE)
    .map(resize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)


def _eval_pipeline(ds):
    """Build a deterministic evaluation pipeline (no augmentation, no shuffle).

    The in-memory cache holds the 32x32 normalized images, not the 224x224
    resized ones, which keeps its footprint roughly 49x smaller; the
    deterministic resize is simply recomputed each epoch.
    """
    return (
        ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .map(resize_img, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )


ds_val = _eval_pipeline(ds_val)
ds_test = _eval_pipeline(ds_test)

# Baseline: ImageNet-pretrained ResNet-50 backbone + a 100-way linear head.
# NOTE(review): the Keras docs say ResNet50 expects inputs passed through
# `tf.keras.applications.resnet50.preprocess_input`; this script feeds the
# [0, 1]-scaled floats from normalize_img instead. Confirm this is intended.
inputs = tf.keras.Input(shape=(224, 224, 3))
base_model = tf.keras.applications.resnet50.ResNet50(
    include_top=False,
    weights='imagenet',
    input_tensor=inputs,
)

x = tf.keras.layers.Flatten()(base_model.output)
outputs = tf.keras.layers.Dense(100)(x)  # logits; the loss uses from_logits=True
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Fine-tune the whole backbone (True is already the Keras default).
base_model.trainable = True

model.summary()

initial_lr = 0.001

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=initial_lr),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

initial_epochs = 10

# Per-epoch cosine decay from initial_lr towards alpha * initial_lr.
# alpha = 0.027 = 0.3 ** 3, comparable to three step drops by a factor of 0.3.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=initial_lr,
    decay_steps=initial_epochs,
    alpha=0.027,
)
lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=1)

history = model.fit(
    ds_train,
    epochs=initial_epochs,
    validation_data=ds_val,
    callbacks=[lr_callback],
)

import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster

# Weight clustering of the trained baseline, using the experimental
# clustering entry point.
cluster_weights = cluster.cluster_weights

CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = dict(
    number_of_clusters=16,
    cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
    # Optional, left disabled: cluster_per_channel=True,
)

model_for_clustering = cluster_weights(model, **clustering_params)

# Fine-tune at a constant alpha * initial_lr, the floor of the baseline's
# cosine schedule. (The schedule reaches that floor only at step
# `initial_epochs`, so the baseline's last epoch ran at ~0.051 * initial_lr.)
lr = 0.027 * initial_lr

model_for_clustering.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

model_for_clustering.summary()

model_for_clustering.fit(
    ds_train,
    epochs=initial_epochs,
    validation_data=ds_val,
)

# Remove the clustering wrappers before annotating the model for CQAT.
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(model_for_clustering)

def print_model_weight_clusters(model, pre_layer_name=""):
    """Print how many unique values each kernel weight of `model` contains.

    After weight clustering, each kernel is expected to hold at most
    `number_of_clusters` distinct values, so this is a quick check of whether
    clusters survived a training step.

    Any layer exposing a `layers` attribute (a nested model) is walked
    recursively, and its name is appended to the prefix with "/".

    Args:
      model: A Keras model (anything with a `layers` list).
      pre_layer_name: "/"-joined names of the enclosing nested models; ""
        at the top level.
    """
    for layer in model.layers:
        if hasattr(layer, 'layers'):
            # Derive the child's prefix without mutating `pre_layer_name`, so
            # layers visited later in this loop keep their own, correct prefix.
            child_prefix = (f"{pre_layer_name}/{layer.name}"
                            if pre_layer_name else layer.name)
            print_model_weight_clusters(layer, child_prefix)
            continue
        # For wrapper layers only trainable weights are inspected, presumably
        # to skip non-trainable bookkeeping variables the wrapper adds.
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        layer_path = (f"{pre_layer_name}/{layer.name}"
                      if pre_layer_name else layer.name)
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(f"{layer_path}/{weight.name}: {unique_count} clusters ")

print_model_weight_clusters(stripped_clustered_model)

# CQAT: annotate every layer for quantization, then apply the
# cluster-preserving 8-bit scheme (sparsity preservation disabled).
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
    stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
    quant_aware_annotate_model,
    tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(
        preserve_sparsity=False))

lr = 0.027 * initial_lr

cqat_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

cqat_model.summary()

subset_size = 2000

# Train CQAT on `subset_size` random training examples per epoch.
# ds_train is already shuffled with a buffer spanning the whole training
# split (reshuffled every epoch) before batching, so its first `subset_size`
# unbatched examples are already a fresh random sample each epoch. No extra
# shuffle here: a buffer over unbatched 224x224x3 float32 images costs
# ~600 KB per element (~27 GB for `num_train` elements).
subset_ds_train = (
    ds_train.unbatch()
    .take(subset_size)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

history = cqat_model.fit(
    subset_ds_train,
    epochs=initial_epochs,
    validation_data=ds_val,
)

print_model_weight_clusters(cqat_model)

Screenshots If applicable, add screenshots to help explain your problem.

Additional context Add any other context about the problem here.

Xhark commented 1 year ago

It might be a bug, since this feature is experimental. @MatteoArm, do you have any idea? Thanks!