NVIDIA-Merlin / models

Merlin Models is a collection of deep learning recommender system model reference implementations
https://nvidia-merlin.github.io/models/main/index.html
Apache License 2.0

[BUG] Error when using class_weight with DLRMModel #1221

Open PaulSteffen-betclic opened 8 months ago

PaulSteffen-betclic commented 8 months ago

Bug description

When I pass the class_weight argument to the .fit() method of DLRMModel, I get the following error:

[Screenshot of the error traceback]

Does anyone know why? According to the following notebook, it seemed to work with earlier versions: https://github.com/NVIDIA-Merlin/publications/blob/2761dade6d725615f0dd2c9491c54f8b397912e0/tutorials/RecSys22tutorial/04-Building-multi-stage-RecSys.ipynb#L731

Thanks

Steps/Code to reproduce bug

import nvtabular as nvt
import tensorflow as tf
import merlin.models.tf as mm

from merlin.models.tf.transforms.negative_sampling import InBatchNegatives
from merlin.dataloader.tensorflow import Loader
from merlin.schema import Tags  # needed for Tags.TARGET below

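# output_path: directory holding the NVTabular-processed train/valid data
processed_train = nvt.Dataset(f"{output_path}/train/*.parquet")
processed_valid = nvt.Dataset(f"{output_path}/valid/*.parquet")  # assumed layout, mirroring train
schema = processed_train.schema
target_column = schema.select_by_tag(Tags.TARGET).column_names[0]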

batch_size, n_per_positive = 2048, 64
add_negatives = InBatchNegatives(schema, n_per_positive, seed=42, prep_features=True, run_when_testing=True)

train_loader = Loader(processed_train, batch_size=batch_size).map(add_negatives)
valid_loader = Loader(processed_valid, batch_size=batch_size).map(add_negatives)

ranking_model = mm.DLRMModel(
    schema,
    embedding_dim=16,
    bottom_block=mm.MLPBlock([32, 16]),
    top_block=mm.MLPBlock([32, 16, 8]),
    prediction_tasks=mm.BinaryClassificationTask(target_column),
)

ranking_model.compile(
    optimizer="adam",
    run_eagerly=False,
    metrics=[],
    weighted_metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.AUC()],
)
# Raises the error when class_weight is passed
ranking_model.fit(train_loader, class_weight={0: 1, 1: n_per_positive}, epochs=2, train_metrics_steps=100)

Expected behavior

Using class_weight should work. It is required to rebalance the classes after the negative sampling step.
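Why the weighting matters: if each positive row is accompanied by n_per_positive sampled negatives (an assumption based on the parameter name), negatives outnumber positives n_per_positive to 1, and weighting positives by n_per_positive evens out their total contribution to the loss. The sketch below roughly mimics how Keras turns a class_weight dict into per-sample weights by looking up each label's weight. The helper `class_weight_to_sample_weight` and the toy batch are hypothetical, for illustration only.

```python
import numpy as np


def class_weight_to_sample_weight(labels, class_weight):
    """Hypothetical helper: map each integer label to its class weight."""
    labels = np.asarray(labels, dtype=np.int64)
    # Lookup table indexed by class id; assumes keys are 0..max(key).
    table = np.array(
        [class_weight[c] for c in range(max(class_weight) + 1)], dtype=np.float64
    )
    return table[labels]


# Toy batch (assumed): 2 positives, each followed by n_per_positive negatives.
n_per_positive = 4
labels = np.array(([1] + [0] * n_per_positive) * 2)
weights = class_weight_to_sample_weight(labels, {0: 1, 1: n_per_positive})

pos_total = weights[labels == 1].sum()
neg_total = weights[labels == 0].sum()
print(pos_total, neg_total)  # 8.0 8.0
```

With {0: 1, 1: n_per_positive}, the positives and negatives each carry a total weight of 8 in this toy batch, which is the balance the class_weight argument is meant to provide.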

Environment details

rnyak commented 8 months ago

@PaulSteffen-betclic can you share your train.schema file?