Closed: marcromeyn closed this PR 1 year ago.
This PR introduces a custom `Trainer` object that reduces boilerplate. With it, a model can be trained as follows:
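```python
import merlin.models.torch as mm  # assumed: the PyTorch backend of Merlin Models
from merlin.models.torch import BinaryOutput, MLPBlock, TabularInputBlock  # assumed export location
from merlin.schema import Tags

# `schema` and `dataset` are assumed to be a Merlin Schema and Dataset defined beforehand.
model = mm.Model(
    TabularInputBlock(schema),
    MLPBlock([32, 16]),
    BinaryOutput(schema.select_by_tag(Tags.TARGET).first),
)

# The Trainer owns the training loop, so no hand-written loop is needed.
trainer = mm.Trainer(max_epochs=1)
trainer.fit(model, dataset, batch_size=16)
```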
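The PR's `Trainer` implementation is not shown here. To show where the boilerplate goes, the following is a hypothetical, framework-free sketch of the same idea: a trainer that owns the epoch and batch loops, with a `fit(model, dataset, batch_size)` call that mirrors the example above. It assumes the model exposes a `training_step(batch)` method that returns a loss. `batches`, `MeanModel`, and `training_step` are illustrative names, not Merlin Models APIs.

```python
from typing import Iterator, List, Sequence


def batches(rows: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of `rows`, each with at most `batch_size` items."""
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


class Trainer:
    """Hypothetical minimal trainer: it owns the epoch/batch loop so callers don't."""

    def __init__(self, max_epochs: int = 1):
        self.max_epochs = max_epochs

    def fit(self, model, dataset: Sequence, batch_size: int = 16) -> List[float]:
        losses = []
        for _ in range(self.max_epochs):
            for batch in batches(dataset, batch_size):
                # Assumed hook: the model computes its own loss for one batch.
                losses.append(model.training_step(batch))
        return losses


class MeanModel:
    """Toy model whose single 'parameter' is the running mean of the values it has seen."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def training_step(self, batch) -> float:
        self.total += sum(batch)
        self.count += len(batch)
        mean = self.total / self.count
        # Loss: mean squared error of this batch against the current estimate.
        return sum((x - mean) ** 2 for x in batch) / len(batch)


model = MeanModel()
losses = Trainer(max_epochs=2).fit(model, list(range(40)), batch_size=16)
```

With 40 rows and `batch_size=16`, each epoch yields batches of 16, 16, and 8 rows, so two epochs produce six loss values. The caller only configures `max_epochs` and `batch_size`; the looping lives in one place.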
Documentation preview: https://nvidia-merlin.github.io/models/review/pr-1176