adap / flower

Flower: A Friendly Federated Learning Framework
https://flower.ai
Apache License 2.0
4.51k stars 792 forks source link

TFRS Model with simulation #894

Open tylershumaker opened 2 years ago

tylershumaker commented 2 years ago

I am trying to combine Tensorflow recommenders (TFRS) quickstart example with the Flower quickstart simulation notebook in Colab.

I can't get past this error:

ValueError: Weights for model sequential have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.

model.build() is not required to train original tfrs model.
Link to Colab here

class FlwrClient(fl.client.NumPyClient):
    def __init__(self, model, train_ds) -> None:
        super().__init__()
        self.model = model
        self.train_ds = train_ds

    def get_parameters(self):
        return self.model.get_weights()

    def fit(self, parameters, config):
        self.model.set_weights(parameters)
        self.model.fit(self.train_ds,
                        epochs=3,
                        verbose=0)
        return self.model.get_weights(), len(list(self.train_ds)), {}

    def evaluate(self, parameters, config):
        self.model.set_weights(parameters)
        loss = self.model.compute_loss(self.train_ds)
        return loss, len(list(self.train_ds)), {}

@tf.autograph.experimental.do_not_convert
def client_fn(cid: str) -> fl.client.Client:

    # ratings data
    rating, rating_info = tfds.load('movielens/latest-small-ratings', 
                                    split='train', with_info=True,)
    # features of all the movies
    movies = tfds.load('movielens/latest-small-movies', split='train')

    rating_map = lambda x:{'movie_title':x['movie_title'], 
                           'user_id':x['user_id']}
    rating = rating.map(rating_map)
    movies_map = lambda x: x['movie_title']
    movies = movies.map(movies_map)

    user_id_vocabulary = tf.keras.layers.experimental.preprocessing.StringLookup(mask_token=None)
    rating_map = lambda x: x['user_id']
    user_id_vocabulary.adapt(rating.map(rating_map))

    movies_title_vocabulary = tf.keras.layers.experimental.preprocessing.StringLookup(mask_token=None)
    movies_title_vocabulary.adapt(movies)

    rating = rating.cache()
    rating_size = rating_info.splits['train'].num_examples
    rating = rating.shuffle(rating_size)

    train_split = 0.6
    train_part = int(train_split * rating_size)
    rating = rating.take(train_part)
    rating = rating.batch(4096)
    rating = rating.prefetch(tf.data.AUTOTUNE)

    users_model = tf.keras.Sequential([user_id_vocabulary,
                                      tf.keras.layers.Embedding(user_id_vocabulary.vocabulary_size(),64)])
    movie_model = tf.keras.Sequential([movies_title_vocabulary,
                                      tf.keras.layers.Embedding(movies_title_vocabulary.vocabulary_size(),64)])

    task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(
        movies.batch(128).map(movie_model)))

    model = MovieLensModel(users_model,movie_model,task)
    model.compile(optimizer=tf.keras.optimizers.Adagrad(0.5),
                  loss='categorical_crossentropy')
    # model.build(input_shape=(None, rating_size)) 

    # Create and return client
    return FlwrClient(model, rating)
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    client_resources={"num_cpus": 2},
    num_rounds=2,
    strategy=fl.server.strategy.FedAvg(
        fraction_fit=0.1,
        fraction_eval=0.1,
        min_fit_clients=10,
        min_eval_clients=10,
        min_available_clients=NUM_CLIENTS,
    ),
)
danieljanes commented 2 years ago

Hi @tylershumaker , thanks for reaching out. The error you're seeing seems to be a TensorFlow error. Could it be related to the commented-out line model.build(...)? Maybe it's caused by the fact that when you create the MovieLensModel and pass it to the FlwrClient, the first thing that happens in get_parameters is model.get_weights (or in fit the call to model.set_weights)? The error message suggests that the model should be called on inputs first (or use build()) to create weights.