tensorflow / decision-forests

A collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models in Keras.
Apache License 2.0

Training new tensorflow random forest gets "The model is already trained" #165

Closed. josseossa closed this issue 1 year ago.

josseossa commented 1 year ago

Hi, I'm just running the following commands to train a model:

# Convert the dataset into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="y_class")
val_ds   = tfdf.keras.pd_dataframe_to_tf_dataset(val_df, label="y_class")
test_ds  = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="y_class")

# Train the model
model = tfdf.keras.RandomForestModel(num_trees=n_estimators,
                                     max_depth=16)
model.fit(train_ds, 
          class_weight=class_weights,
          validation_data=val_ds)

and I'm getting the following error. Any ideas? Is there a cache folder to clean?

Reading training dataset...
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [21], in <cell line: 23>()
     20 # Train the model
     21 model = tfdf.keras.RandomForestModel(num_trees=n_estimators,
     22                                      max_depth=16)
---> 23 model.fit(train_ds, 
     24           class_weight=class_weights,
     25           validation_data=val_ds)
     27 # Look at the model.
     28 model.summary()

File ~/anaconda3/envs/deeplearning/lib/python3.9/site-packages/tensorflow_decision_forests/keras/core.py:1153, in CoreModel.fit(self, x, y, callbacks, verbose, validation_steps, validation_data, sample_weight, steps_per_epoch, class_weight, **kwargs)
   1150 # Reset the training status.
   1151 self._is_trained.assign(False)
-> 1153 return self._fit_implementation(
   1154     x=x,
   1155     y=y,
   1156     callbacks=callbacks,
   1157     verbose=verbose,
   1158     validation_steps=validation_steps,
   1159     validation_data=validation_data,
   1160     sample_weight=sample_weight,
   1161     steps_per_epoch=steps_per_epoch,
   1162     class_weight=class_weight)

File ~/anaconda3/envs/deeplearning/lib/python3.9/site-packages/tensorflow_decision_forests/keras/core.py:1294, in CoreModel._fit_implementation(self, x, y, verbose, callbacks, sample_weight, validation_data, validation_steps, steps_per_epoch, class_weight)
   1287 iterator = iter(data_handler._dataset)  # pylint: disable=protected-access
   1289 if steps_per_epoch is None:
   1290   # Local training with finite dataset.
   1291 
   1292   # Load the entire training dataset in Yggdrasil in a single TensorFlow
   1293   # step.
-> 1294   self._num_training_examples = self._consumes_training_examples_until_eof(
   1295       iterator)
   1297 else:
   1298   # Local training with number of steps.
   1299 
   1300   # TODO: Make this case an error and remove this code.
   1301   tf_logging.warning(
   1302       "You are using non-distributed training with steps_per_epoch. "
   1303       "This solution will lead to a sub-optimal model. Instead, "
   1304       "use a finite training dataset (e.g. a dataset without "
   1305       "repeat operation) and remove the `steps_per_epoch` argument. "
   1306       "This warning will be turned into an error in the future.")

File ~/anaconda3/envs/deeplearning/lib/python3.9/site-packages/tensorflow/python/trackable/base.py:205, in no_automatic_dependency_tracking.<locals>._method_wrapper(self, *args, **kwargs)
    203 self._self_setattr_tracking = False  # pylint: disable=protected-access
    204 try:
--> 205   result = method(self, *args, **kwargs)
    206 finally:
    207   self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

File ~/anaconda3/envs/deeplearning/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File ~/anaconda3/envs/deeplearning/lib/python3.9/site-packages/tensorflow_decision_forests/keras/core.py:1181, in CoreModel._consumes_training_examples_until_eof(self, iterator)
   1179 num_examples = 0
   1180 for data in iterator:
-> 1181   num_examples += self.train_step(data)
   1182 return num_examples

File ~/anaconda3/envs/deeplearning/lib/python3.9/site-packages/tensorflow_decision_forests/keras/core.py:593, in CoreModel.train_step(self, data)
    588 @no_automatic_dependency_tracking
    589 @tf.function(reduce_retracing=True)
    590 def train_step(self, data):
    591   """Collects training examples."""
--> 593   return self.collect_data_step(data, is_training_example=True)

File ~/anaconda3/envs/deeplearning/lib/python3.9/site-packages/tensorflow_decision_forests/keras/core.py:691, in CoreModel.collect_data_step(self, data, is_training_example)
    689 if is_training_example:
    690   if self._semantics is not None:
--> 691     raise ValueError("The model is already trained")
    693   # Save the semantic for later re-use.
    694   self._semantics = semantics

ValueError: The model is already trained
rstz commented 1 year ago

Hi,

TF-DF supports weights, but the syntax is a bit different from Keras' class_weights syntax. The weights are stored as a separate channel of the input dataset. In practice, this means that your weights should be a column of the pandas DataFrame you're working with, say my_weights. Then you can point pd_dataframe_to_tf_dataset() at this column explicitly via its weight argument.

Furthermore, Random Forests do not require validation datasets by design: they self-evaluate effectively (out-of-bag) and do not use early stopping. You can therefore skip the validation dataset.

Try the following:

# Convert the dataset into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="y_class", weight="my_weights")
test_ds  = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="y_class", weight="my_weights")

# Train the model
model = tfdf.keras.RandomForestModel(num_trees=n_estimators, max_depth=16)
model.fit(train_ds)
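
As a side note, the self-evaluation mentioned above can be read from the trained model through its inspector; a minimal sketch, assuming the model from the snippet just above:

# Sketch: read the Random Forest's out-of-bag self-evaluation after
# training, instead of providing a validation dataset.
inspector = model.make_inspector()
print(inspector.evaluation())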
josseossa commented 1 year ago

Thanks @rstz for your help… in my case I'm working with binary classification. Do you have any advice on how the weights column should be defined?

Would it be reasonable to assign the same weight to every example of the positive class, and likewise for the negative class? I was using this code to compute the weights based on the class balance of my dataset:

import numpy as np
from sklearn.utils import class_weight

class_weights = dict(zip(
    np.unique(y_class),
    class_weight.compute_class_weight(
        class_weight='balanced', classes=np.unique(y_class), y=y_class),
))
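
One way to turn this dict into the per-example weight column rstz describes; a sketch, assuming train_df carries the y_class label, and reusing the illustrative my_weights column name from above:

# Sketch: give each row the weight of its class, then expose the column
# to TF-DF via the `weight` argument of pd_dataframe_to_tf_dataset().
train_df["my_weights"] = train_df["y_class"].map(class_weights)
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(
    train_df, label="y_class", weight="my_weights")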
josseossa commented 1 year ago

I removed the weights to try the suggestion and hit the same error. I have already trained a model on a similar dataset; the only difference is that this one has new features and more columns. Reading the code, I found this function:

def _consumes_training_examples_until_eof(self, iterator):
    """Consumes all the available training examples.

    The examples are either loaded in memory (local training) or indexed
    (distributed training). Returns the total number of consumed examples.

    Args:
      iterator: Iterator of examples.

    Returns:
      Number of read examples.
    """

    num_examples = 0
    for data in iterator:
      num_examples += self.train_step(data)
    return num_examples

As you can see, it accumulates num_examples in a for loop, but train_step(data) ends up calling collect_data_step(), which expects self._semantics to be None and then sets it.

Running that check inside a loop should raise the exception: the attribute is set on the first iteration, so any second iteration would fail.

I'm not expert enough to suggest a fix (apart from removing the exception, which I did on my side and it worked)... Could this be a bug in the code?
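
To make the pattern being described concrete, a minimal standalone toy (not TF-DF code; the names are illustrative):

# Standalone toy of the pattern described above: a guard attribute set
# on the first pass of a per-batch loop makes the same check fail on
# the second pass.
class Collector:
    def __init__(self):
        self._semantics = None

    def collect_data_step(self, data, is_training_example=True):
        if is_training_example:
            if self._semantics is not None:
                raise ValueError("The model is already trained")
            self._semantics = "inferred from first batch"
        return len(data)

c = Collector()
for batch in [[1, 2], [3, 4]]:  # the second batch raises ValueError
    c.collect_data_step(batch)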

rstz commented 1 year ago

Could you please provide a minimal example of your issue?

josseossa commented 1 year ago

Sure, this is the df structure:

df_var.columns
    [last_pred,
    max_0_10,
    max_0_120,
    max_0_30,
    max_0_300,
    max_0_450,
    max_0_60,
    max_0_600,
    max_10_10,
    max_10_30,
    max_11_10,
    max_11_30,
    max_12_10,
    max_12_30,
    max_13_10,
    max_13_30,
    max_14_10,
    max_14_30,
    max_15_10,
    max_15_30,
    max_16_10,
    max_16_30,
    max_17_10,
    max_17_30,
    max_18_10,
    max_18_30,
    max_19_10,
    max_19_30,
    max_1_10,
    max_1_120,
    max_1_30,
    max_1_300,
    max_1_60,
    max_20_10,
    max_21_10,
    max_22_10,
    max_23_10,
    max_24_10,
    max_25_10,
    max_26_10,
    max_27_10,
    max_28_10,
    max_29_10,
    max_2_10,
    max_2_120,
    max_2_30,
    max_2_60,
    max_30_10,
    max_31_10,
    max_32_10,
    max_33_10,
    max_34_10,
    max_35_10,
    max_36_10,
    max_37_10,
    max_38_10,
    max_39_10,
    max_3_10,
    max_3_120,
    max_3_30,
    max_3_60,
    max_40_10,
    max_41_10,
    max_42_10,
    max_43_10,
    max_44_10,
    max_45_10,
    max_46_10,
    max_47_10,
    max_48_10,
    max_49_10,
    max_4_10,
    max_4_120,
    max_4_30,
    max_4_60,
    max_50_10,
    max_51_10,
    max_52_10,
    max_53_10,
    max_54_10,
    max_55_10,
    max_56_10,
    max_57_10,
    max_58_10,
    max_59_10,
    max_5_10,
    max_5_30,
    max_5_60,
    max_6_10,
    max_6_30,
    max_6_60,
    max_7_10,
    max_7_30,
    max_7_60,
    max_8_10,
    max_8_30,
    max_8_60,
    max_9_10,
    max_9_30,
    max_9_60,
    mean_0_10,
    mean_0_120,
    mean_0_30,
    mean_0_300,
    mean_0_450,
    mean_0_60,
    mean_0_600,
    mean_10_10,
    mean_10_30,
    mean_11_10,
    mean_11_30,
    mean_12_10,
    mean_12_30,
    mean_13_10,
    mean_13_30,
    mean_14_10,
    mean_14_30,
    mean_15_10,
    mean_15_30,
    mean_16_10,
    mean_16_30,
    mean_17_10,
    mean_17_30,
    mean_18_10,
    mean_18_30,
    mean_19_10,
    mean_19_30,
    mean_1_10,
    mean_1_120,
    mean_1_30,
    mean_1_300,
    mean_1_60,
    mean_20_10,
    mean_21_10,
    mean_22_10,
    mean_23_10,
    mean_24_10,
    mean_25_10,
    mean_26_10,
    mean_27_10,
    mean_28_10,
    mean_29_10,
    mean_2_10,
    mean_2_120,
    mean_2_30,
    mean_2_60,
    mean_30_10,
    mean_31_10,
    mean_32_10,
    mean_33_10,
    mean_34_10,
    mean_35_10,
    mean_36_10,
    mean_37_10,
    mean_38_10,
    mean_39_10,
    mean_3_10,
    mean_3_120,
    mean_3_30,
    mean_3_60,
    mean_40_10,
    mean_41_10,
    mean_42_10,
    mean_43_10,
    mean_44_10,
    mean_45_10,
    mean_46_10,
    mean_47_10,
    mean_48_10,
    mean_49_10,
    mean_4_10,
    mean_4_120,
    mean_4_30,
    mean_4_60,
    mean_50_10,
    mean_51_10,
    mean_52_10,
    mean_53_10,
    mean_54_10,
    mean_55_10,
    mean_56_10,
    mean_57_10,
    mean_58_10,
    mean_59_10,
    mean_5_10,
    mean_5_30,
    mean_5_60,
    mean_6_10,
    mean_6_30,
    mean_6_60,
    mean_7_10,
    mean_7_30,
    mean_7_60,
    mean_8_10,
    mean_8_30,
    mean_8_60,
    mean_9_10,
    mean_9_30,
    mean_9_60,
    min_0_10,
    min_0_120,
    min_0_30,
    min_0_300,
    min_0_450,
    min_0_60,
    min_0_600,
    min_10_10,
    min_10_30,
    min_11_10,
    min_11_30,
    min_12_10,
    min_12_30,
    min_13_10,
    min_13_30,
    min_14_10,
    min_14_30,
    min_15_10,
    min_15_30,
    min_16_10,
    min_16_30,
    min_17_10,
    min_17_30,
    min_18_10,
    min_18_30,
    min_19_10,
    min_19_30,
    min_1_10,
    min_1_120,
    min_1_30,
    min_1_300,
    min_1_60,
    min_20_10,
    min_21_10,
    min_22_10,
    min_23_10,
    min_24_10,
    min_25_10,
    min_26_10,
    min_27_10,
    min_28_10,
    min_29_10,
    min_2_10,
    min_2_120,
    min_2_30,
    min_2_60,
    min_30_10,
    min_31_10,
    min_32_10,
    min_33_10,
    min_34_10,
    min_35_10,
    min_36_10,
    min_37_10,
    min_38_10,
    min_39_10,
    min_3_10,
    min_3_120,
    min_3_30,
    min_3_60,
    min_40_10,
    min_41_10,
    min_42_10,
    min_43_10,
    min_44_10,
    min_45_10,
    min_46_10,
    min_47_10,
    min_48_10,
    min_49_10,
    min_4_10,
    min_4_120,
    min_4_30,
    min_4_60,
    min_50_10,
    min_51_10,
    min_52_10,
    min_53_10,
    min_54_10,
    min_55_10,
    min_56_10,
    min_57_10,
    min_58_10,
    min_59_10,
    min_5_10,
    min_5_30,
    min_5_60,
    min_6_10,
    min_6_30,
    min_6_60,
    min_7_10,
    min_7_30,
    min_7_60,
    min_8_10,
    min_8_30,
    min_8_60,
    min_9_10,
    min_9_30,
    min_9_60,
    levels_var_0_back,
    levels_var_100_back,
    levels_var_101_back,
    levels_var_102_back,
    levels_var_103_back,
    levels_var_104_back,
    levels_var_105_back,
    levels_var_106_back,
    levels_var_107_back,
    levels_var_108_back,
    levels_var_109_back,
    levels_var_10_back,
    levels_var_110_back,
    levels_var_111_back,
    levels_var_112_back,
    levels_var_113_back,
    levels_var_114_back,
    levels_var_115_back,
    levels_var_116_back,
    levels_var_117_back,
    levels_var_118_back,
    levels_var_119_back,
    levels_var_11_back,
    levels_var_12_back,
    levels_var_13_back,
    levels_var_14_back,
    levels_var_15_back,
    levels_var_16_back,
    levels_var_17_back,
    levels_var_18_back,
    levels_var_19_back,
    levels_var_1_back,
    levels_var_20_back,
    levels_var_21_back,
    levels_var_22_back,
    levels_var_23_back,
    levels_var_24_back,
    levels_var_25_back,
    levels_var_26_back,
    levels_var_27_back,
    levels_var_28_back,
    levels_var_29_back,
    levels_var_2_back,
    levels_var_30_back,
    levels_var_31_back,
    levels_var_32_back,
    levels_var_33_back,
    levels_var_34_back,
    levels_var_35_back,
    levels_var_36_back,
    levels_var_37_back,
    levels_var_38_back,
    levels_var_39_back,
    levels_var_3_back,
    levels_var_40_back,
    levels_var_41_back,
    levels_var_42_back,
    levels_var_43_back,
    levels_var_44_back,
    levels_var_45_back,
    levels_var_46_back,
    levels_var_47_back,
    levels_var_48_back,
    levels_var_49_back,
    levels_var_4_back,
    levels_var_50_back,
    levels_var_51_back,
    levels_var_52_back,
    levels_var_53_back,
    levels_var_54_back,
    levels_var_55_back,
    levels_var_56_back,
    levels_var_57_back,
    levels_var_58_back,
    levels_var_59_back,
    levels_var_5_back,
    levels_var_60_back,
    levels_var_61_back,
    levels_var_62_back,
    levels_var_63_back,
    levels_var_64_back,
    levels_var_65_back,
    levels_var_66_back,
    levels_var_67_back,
    levels_var_68_back,
    levels_var_69_back,
    levels_var_6_back,
    levels_var_70_back,
    levels_var_71_back,
    levels_var_72_back,
    levels_var_73_back,
    levels_var_74_back,
    levels_var_75_back,
    levels_var_76_back,
    levels_var_77_back,
    levels_var_78_back,
    levels_var_79_back,
    levels_var_7_back,
    levels_var_80_back,
    levels_var_81_back,
    levels_var_82_back,
    levels_var_83_back,
    levels_var_84_back,
    levels_var_85_back,
    levels_var_86_back,
    levels_var_87_back,
    levels_var_88_back,
    levels_var_89_back,
    levels_var_8_back,
    levels_var_90_back,
    levels_var_91_back,
    levels_var_92_back,
    levels_var_93_back,
    levels_var_94_back,
    levels_var_95_back,
    levels_var_96_back,
    levels_var_97_back,
    levels_var_98_back,
    levels_var_99_back,
    levels_var_9_back,
        y_class]

len(df_var)
10221999

With this, I run:

train_ratio = 0.9
train_df = df_var[:math.floor(len(df_var)*train_ratio)]
test_df = df_var[-(len(df_var) - len(train_df) + steps):]

# Convert the dataset into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="y_class")
test_ds  = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="y_class")

# Train the model
model = tfdf.keras.RandomForestModel(num_trees=n_estimators, max_depth=16)
model.fit(train_ds)

The fit method raises ValueError: The model is already trained. You can see the full traceback in my original message.

BTW: I ran the same commands with only 121 columns without issues; it's only after adding the 200+ columns that I get the error.

rstz commented 1 year ago

This is an interesting issue and I would like to get to the bottom of it.

I adapted the code you shared to run with a randomly generated dataframe, but it works just fine. Can you please create a Google Colab that reproduces the error and share it?

import math

import numpy as np
import pandas as pd
import tensorflow_decision_forests as tfdf

col_list = [
  'last_pred',
  'max_0_10',
  'max_0_120',
  'max_0_30',
  'max_0_300',
  'max_0_450',
  'max_0_60',
  'max_0_600',
  'max_10_10',
  'max_10_30',
  'max_11_10',
  'max_11_30',
  'max_12_10',
  'max_12_30',
  'max_13_10',
  'max_13_30',
  'max_14_10',
  'max_14_30',
  'max_15_10',
  'max_15_30',
  'max_16_10',
  'max_16_30',
  'max_17_10',
  'max_17_30',
  'max_18_10',
  'max_18_30',
  'max_19_10',
  'max_19_30',
  'max_1_10',
  'max_1_120',
  'max_1_30',
  'max_1_300',
  'max_1_60',
  'max_20_10',
  'max_21_10',
  'max_22_10',
  'max_23_10',
  'max_24_10',
  'max_25_10',
  'max_26_10',
  'max_27_10',
  'max_28_10',
  'max_29_10',
  'max_2_10',
  'max_2_120',
  'max_2_30',
  'max_2_60',
  'max_30_10',
  'max_31_10',
  'max_32_10',
  'max_33_10',
  'max_34_10',
  'max_35_10',
  'max_36_10',
  'max_37_10',
  'max_38_10',
  'max_39_10',
  'max_3_10',
  'max_3_120',
  'max_3_30',
  'max_3_60',
  'max_40_10',
  'max_41_10',
  'max_42_10',
  'max_43_10',
  'max_44_10',
  'max_45_10',
  'max_46_10',
  'max_47_10',
  'max_48_10',
  'max_49_10',
  'max_4_10',
  'max_4_120',
  'max_4_30',
  'max_4_60',
  'max_50_10',
  'max_51_10',
  'max_52_10',
  'max_53_10',
  'max_54_10',
  'max_55_10',
  'max_56_10',
  'max_57_10',
  'max_58_10',
  'max_59_10',
  'max_5_10',
  'max_5_30',
  'max_5_60',
  'max_6_10',
  'max_6_30',
  'max_6_60',
  'max_7_10',
  'max_7_30',
  'max_7_60',
  'max_8_10',
  'max_8_30',
  'max_8_60',
  'max_9_10',
  'max_9_30',
  'max_9_60',
  'mean_0_10',
  'mean_0_120',
  'mean_0_30',
  'mean_0_300',
  'mean_0_450',
  'mean_0_60',
  'mean_0_600',
  'mean_10_10',
  'mean_10_30',
  'mean_11_10',
  'mean_11_30',
  'mean_12_10',
  'mean_12_30',
  'mean_13_10',
  'mean_13_30',
  'mean_14_10',
  'mean_14_30',
  'mean_15_10',
  'mean_15_30',
  'mean_16_10',
  'mean_16_30',
  'mean_17_10',
  'mean_17_30',
  'mean_18_10',
  'mean_18_30',
  'mean_19_10',
  'mean_19_30',
  'mean_1_10',
  'mean_1_120',
  'mean_1_30',
  'mean_1_300',
  'mean_1_60',
  'mean_20_10',
  'mean_21_10',
  'mean_22_10',
  'mean_23_10',
  'mean_24_10',
  'mean_25_10',
  'mean_26_10',
  'mean_27_10',
  'mean_28_10',
  'mean_29_10',
  'mean_2_10',
  'mean_2_120',
  'mean_2_30',
  'mean_2_60',
  'mean_30_10',
  'mean_31_10',
  'mean_32_10',
  'mean_33_10',
  'mean_34_10',
  'mean_35_10',
  'mean_36_10',
  'mean_37_10',
  'mean_38_10',
  'mean_39_10',
  'mean_3_10',
  'mean_3_120',
  'mean_3_30',
  'mean_3_60',
  'mean_40_10',
  'mean_41_10',
  'mean_42_10',
  'mean_43_10',
  'mean_44_10',
  'mean_45_10',
  'mean_46_10',
  'mean_47_10',
  'mean_48_10',
  'mean_49_10',
  'mean_4_10',
  'mean_4_120',
  'mean_4_30',
  'mean_4_60',
  'mean_50_10',
  'mean_51_10',
  'mean_52_10',
  'mean_53_10',
  'mean_54_10',
  'mean_55_10',
  'mean_56_10',
  'mean_57_10',
  'mean_58_10',
  'mean_59_10',
  'mean_5_10',
  'mean_5_30',
  'mean_5_60',
  'mean_6_10',
  'mean_6_30',
  'mean_6_60',
  'mean_7_10',
  'mean_7_30',
  'mean_7_60',
  'mean_8_10',
  'mean_8_30',
  'mean_8_60',
  'mean_9_10',
  'mean_9_30',
  'mean_9_60',
  'min_0_10',
  'min_0_120',
  'min_0_30',
  'min_0_300',
  'min_0_450',
  'min_0_60',
  'min_0_600',
  'min_10_10',
  'min_10_30',
  'min_11_10',
  'min_11_30',
  'min_12_10',
  'min_12_30',
  'min_13_10',
  'min_13_30',
  'min_14_10',
  'min_14_30',
  'min_15_10',
  'min_15_30',
  'min_16_10',
  'min_16_30',
  'min_17_10',
  'min_17_30',
  'min_18_10',
  'min_18_30',
  'min_19_10',
  'min_19_30',
  'min_1_10',
  'min_1_120',
  'min_1_30',
  'min_1_300',
  'min_1_60',
  'min_20_10',
  'min_21_10',
  'min_22_10',
  'min_23_10',
  'min_24_10',
  'min_25_10',
  'min_26_10',
  'min_27_10',
  'min_28_10',
  'min_29_10',
  'min_2_10',
  'min_2_120',
  'min_2_30',
  'min_2_60',
  'min_30_10',
  'min_31_10',
  'min_32_10',
  'min_33_10',
  'min_34_10',
  'min_35_10',
  'min_36_10',
  'min_37_10',
  'min_38_10',
  'min_39_10',
  'min_3_10',
  'min_3_120',
  'min_3_30',
  'min_3_60',
  'min_40_10',
  'min_41_10',
  'min_42_10',
  'min_43_10',
  'min_44_10',
  'min_45_10',
  'min_46_10',
  'min_47_10',
  'min_48_10',
  'min_49_10',
  'min_4_10',
  'min_4_120',
  'min_4_30',
  'min_4_60',
  'min_50_10',
  'min_51_10',
  'min_52_10',
  'min_53_10',
  'min_54_10',
  'min_55_10',
  'min_56_10',
  'min_57_10',
  'min_58_10',
  'min_59_10',
  'min_5_10',
  'min_5_30',
  'min_5_60',
  'min_6_10',
  'min_6_30',
  'min_6_60',
  'min_7_10',
  'min_7_30',
  'min_7_60',
  'min_8_10',
  'min_8_30',
  'min_8_60',
  'min_9_10',
  'min_9_30',
  'min_9_60',
  'levels_var_0_back',
  'levels_var_100_back',
  'levels_var_101_back',
  'levels_var_102_back',
  'levels_var_103_back',
  'levels_var_104_back',
  'levels_var_105_back',
  'levels_var_106_back',
  'levels_var_107_back',
  'levels_var_108_back',
  'levels_var_109_back',
  'levels_var_10_back',
  'levels_var_110_back',
  'levels_var_111_back',
  'levels_var_112_back',
  'levels_var_113_back',
  'levels_var_114_back',
  'levels_var_115_back',
  'levels_var_116_back',
  'levels_var_117_back',
  'levels_var_118_back',
  'levels_var_119_back',
  'levels_var_11_back',
  'levels_var_12_back',
  'levels_var_13_back',
  'levels_var_14_back',
  'levels_var_15_back',
  'levels_var_16_back',
  'levels_var_17_back',
  'levels_var_18_back',
  'levels_var_19_back',
  'levels_var_1_back',
  'levels_var_20_back',
  'levels_var_21_back',
  'levels_var_22_back',
  'levels_var_23_back',
  'levels_var_24_back',
  'levels_var_25_back',
  'levels_var_26_back',
  'levels_var_27_back',
  'levels_var_28_back',
  'levels_var_29_back',
  'levels_var_2_back',
  'levels_var_30_back',
  'levels_var_31_back',
  'levels_var_32_back',
  'levels_var_33_back',
  'levels_var_34_back',
  'levels_var_35_back',
  'levels_var_36_back',
  'levels_var_37_back',
  'levels_var_38_back',
  'levels_var_39_back',
  'levels_var_3_back',
  'levels_var_40_back',
  'levels_var_41_back',
  'levels_var_42_back',
  'levels_var_43_back',
  'levels_var_44_back',
  'levels_var_45_back',
  'levels_var_46_back',
  'levels_var_47_back',
  'levels_var_48_back',
  'levels_var_49_back',
  'levels_var_4_back',
  'levels_var_50_back',
  'levels_var_51_back',
  'levels_var_52_back',
  'levels_var_53_back',
  'levels_var_54_back',
  'levels_var_55_back',
  'levels_var_56_back',
  'levels_var_57_back',
  'levels_var_58_back',
  'levels_var_59_back',
  'levels_var_5_back',
  'levels_var_60_back',
  'levels_var_61_back',
  'levels_var_62_back',
  'levels_var_63_back',
  'levels_var_64_back',
  'levels_var_65_back',
  'levels_var_66_back',
  'levels_var_67_back',
  'levels_var_68_back',
  'levels_var_69_back',
  'levels_var_6_back',
  'levels_var_70_back',
  'levels_var_71_back',
  'levels_var_72_back',
  'levels_var_73_back',
  'levels_var_74_back',
  'levels_var_75_back',
  'levels_var_76_back',
  'levels_var_77_back',
  'levels_var_78_back',
  'levels_var_79_back',
  'levels_var_7_back',
  'levels_var_80_back',
  'levels_var_81_back',
  'levels_var_82_back',
  'levels_var_83_back',
  'levels_var_84_back',
  'levels_var_85_back',
  'levels_var_86_back',
  'levels_var_87_back',
  'levels_var_88_back',
  'levels_var_89_back',
  'levels_var_8_back',
  'levels_var_90_back',
  'levels_var_91_back',
  'levels_var_92_back',
  'levels_var_93_back',
  'levels_var_94_back',
  'levels_var_95_back',
  'levels_var_96_back',
  'levels_var_97_back',
  'levels_var_98_back',
  'levels_var_99_back',
  'levels_var_9_back',
  'y_class'
    ]

df_var = pd.DataFrame(np.random.randint(10, size=(5, len(col_list))), columns=col_list)

len(df_var)

steps = 5
n_estimators = 10

train_ratio = 0.9
train_df = df_var[:math.floor(len(df_var)*train_ratio)]
test_df = df_var[-(len(df_var) - len(train_df) + steps):]

# Convert the dataset into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="y_class")
test_ds  = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="y_class")

# Train the model
model = tfdf.keras.RandomForestModel(num_trees=n_estimators, max_depth=16)
model.fit(train_ds)
josseossa commented 1 year ago

Hi Richard, thanks for the interest. I don't think the DataFrame will fit in a free runtime (considering the memory limitations), but I will give it a try. Keep in mind the length of the DataFrame: it contains 10+ million examples. I'll come back as soon as I have an example to show you, either in Colab or somewhere else.

rstz commented 1 year ago

Great! If Colab is not the right tool, feel free to use another way to share a debuggable example :)

rstz commented 1 year ago

I was able to train a RF model on 10,000,000 random integer examples with your feature names using TF-DF, without issues.
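
A sketch of how the earlier repro scales to that size; an assumption, since the exact code used was not shared in the thread:

# Assumption: scaling the random repro above from 5 rows to 10,000,000.
df_var = pd.DataFrame(
    np.random.randint(10, size=(10_000_000, len(col_list))),
    columns=col_list)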

josseossa commented 1 year ago

Hi Richard, just to let you know that I'm setting up a Jupyter service so you can check the issue. As soon as it's public, I'll let you know.

josseossa commented 1 year ago

Hi Richard, I think there was an issue in my env, because creating a new one and installing the packages made it work. Thanks a lot for your interest and your help; it's very much appreciated.

rstz commented 1 year ago

Cool, thank you for letting me know. Happy to answer more questions if needed. If you happen to have any success stories with TF-DF one day, we'd be happy to learn about them!