dotnet / machinelearning

ML.NET is an open source and cross-platform machine learning framework for .NET.
https://dot.net/ml
MIT License

Get Loss During Training for Visualization (Learning Curve Graph) #7140

Open chrisevans9629 opened 2 months ago


Is your feature request related to a problem? Please describe.
I need a way to visualize how my model learns during training: a learning curve that compares training loss with test loss.

Describe the solution you'd like
An event handler (or callback) that exposes the loss during training, for example after each iteration.

Describe alternatives you've considered
Training the model for x epochs, evaluating it, and then continuing training in a loop. Unfortunately, this doesn't work for every model; LightGbm, for example, can't be retrained.

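```csharp
using System.Linq;
using Microsoft.ML;

// Cross-validate the estimator and keep the fold model with the highest accuracy.
var kfold = ctx.BinaryClassification.CrossValidate(training, estimator, numberOfFolds: param.kfold);
var bestModel = kfold.OrderByDescending(p => p.Metrics.Accuracy).Select(p => p.Model).First();

// Score the held-out test set with the best fold model and evaluate it.
var testOutput = bestModel.Transform(test);
var metrics = ctx.BinaryClassification.Evaluate(testOutput);

// This doesn't compile: estimator is an IEstimator<ITransformer> and bestModel is an ITransformer,
// so a fitted model can't simply be assigned back to the estimator. Not sure how you would do this...
// estimator = bestModel;
```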
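For trainers that do support continued training, the epoch loop described above can be written by passing the previous model's parameters to the trainer's `Fit(trainData, modelParameters)` overload instead of assigning the fitted transformer back to the estimator. The sketch below assumes the `Microsoft.ML` NuGet package and `LbfgsLogisticRegression`, one of the ML.NET trainers that accepts initial linear model parameters; the `Row` schema, the synthetic data, the `ContinuedTrainingCurve` helper, and the cap of two L-BFGS iterations per step are hypothetical choices for illustration. Each call restarts L-BFGS from the previous weights without its optimizer history, so the curve approximates one long run rather than reproducing it exactly.

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

// Hypothetical input schema: a boolean label and a two-element feature vector.
public class Row
{
    public bool Label { get; set; }

    [VectorType(2)]
    public float[] Features { get; set; }
}

public static class ContinuedTrainingCurve
{
    // Hypothetical synthetic data: the label is (x + y > 0), flipped for roughly 10% of rows.
    public static IDataView MakeData(MLContext ctx, int count, int seed)
    {
        var rng = new Random(seed);
        var rows = new List<Row>();
        for (int i = 0; i < count; i++)
        {
            float x = (float)(rng.NextDouble() * 2 - 1);
            float y = (float)(rng.NextDouble() * 2 - 1);
            bool label = (x + y > 0) ^ (rng.NextDouble() < 0.1);
            rows.Add(new Row { Label = label, Features = new[] { x, y } });
        }
        return ctx.Data.LoadFromEnumerable(rows);
    }

    // Records training and test log-loss after each continued-training step.
    public static List<(int Step, double TrainLogLoss, double TestLogLoss)> Compute(
        MLContext ctx, IDataView train, IDataView test, int steps)
    {
        // Arbitrary cap on L-BFGS iterations per Fit call, so each call acts as one "step".
        var trainer = ctx.BinaryClassification.Trainers.LbfgsLogisticRegression(
            new LbfgsLogisticRegressionBinaryTrainer.Options { MaximumNumberOfIterations = 2 });

        var curve = new List<(int Step, double TrainLogLoss, double TestLogLoss)>();
        BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>> model = null;
        for (int step = 1; step <= steps; step++)
        {
            // The first call trains from scratch; later calls start from the previous linear weights.
            model = model == null
                ? trainer.Fit(train)
                : trainer.Fit(train, model.Model.SubModel);

            // The output is calibrated (it has a Probability column), so Evaluate reports LogLoss.
            double trainLoss = ctx.BinaryClassification.Evaluate(model.Transform(train)).LogLoss;
            double testLoss = ctx.BinaryClassification.Evaluate(model.Transform(test)).LogLoss;
            curve.Add((step, trainLoss, testLoss));
        }
        return curve;
    }

    public static void Main()
    {
        var ctx = new MLContext(seed: 0);
        var split = ctx.Data.TrainTestSplit(MakeData(ctx, 2000, seed: 1), testFraction: 0.2);
        foreach (var (step, trainLoss, testLoss) in Compute(ctx, split.TrainSet, split.TestSet, steps: 10))
        {
            Console.WriteLine($"step {step,2}: train log-loss {trainLoss:F4}, test log-loss {testLoss:F4}");
        }
    }
}
```

If the estimator passed to `CrossValidate` is the trainer itself, `bestModel` should be castable from `ITransformer` to the concrete `BinaryPredictionTransformer<...>` type so its parameters can be reused the same way; with a pipeline, the result is a `TransformerChain`, whose `LastTransformer` holds the trained model.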

Additional context
Ultimately, I am trying to understand what the models I'm comparing are actually doing, and so far I haven't found any documentation or a straightforward way to do it.
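LightGbm can't resume training in ML.NET, but a loss-versus-boosting-rounds curve can still be approximated by training a fresh model for each round count and evaluating log-loss on the training and test sets. This is a workaround, not a training event. A minimal sketch, assuming the `Microsoft.ML` and `Microsoft.ML.LightGbm` NuGet packages; the `Row` schema, the synthetic data, the `BoostingRoundCurve` helper, and the round counts are hypothetical. Every point costs a full training run, so the round counts should stay sparse.

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;

// Hypothetical input schema: a boolean label and a two-element feature vector.
public class Row
{
    public bool Label { get; set; }

    [VectorType(2)]
    public float[] Features { get; set; }
}

public static class BoostingRoundCurve
{
    // Hypothetical synthetic data: the label is (x + y > 0), flipped for roughly 10% of rows.
    public static IDataView MakeData(MLContext ctx, int count, int seed)
    {
        var rng = new Random(seed);
        var rows = new List<Row>();
        for (int i = 0; i < count; i++)
        {
            float x = (float)(rng.NextDouble() * 2 - 1);
            float y = (float)(rng.NextDouble() * 2 - 1);
            bool label = (x + y > 0) ^ (rng.NextDouble() < 0.1);
            rows.Add(new Row { Label = label, Features = new[] { x, y } });
        }
        return ctx.Data.LoadFromEnumerable(rows);
    }

    // Trains a fresh LightGbm model per round count and records training and test log-loss.
    public static List<(int Rounds, double TrainLogLoss, double TestLogLoss)> Compute(
        MLContext ctx, IDataView train, IDataView test, IEnumerable<int> roundCounts)
    {
        var curve = new List<(int Rounds, double TrainLogLoss, double TestLogLoss)>();
        foreach (int rounds in roundCounts)
        {
            var model = ctx.BinaryClassification.Trainers
                .LightGbm(labelColumnName: "Label", featureColumnName: "Features", numberOfIterations: rounds)
                .Fit(train);

            // The LightGbm binary trainer output is calibrated, so Evaluate reports LogLoss.
            double trainLoss = ctx.BinaryClassification.Evaluate(model.Transform(train)).LogLoss;
            double testLoss = ctx.BinaryClassification.Evaluate(model.Transform(test)).LogLoss;
            curve.Add((rounds, trainLoss, testLoss));
        }
        return curve;
    }

    public static void Main()
    {
        var ctx = new MLContext(seed: 0);
        var split = ctx.Data.TrainTestSplit(MakeData(ctx, 2000, seed: 1), testFraction: 0.2);
        foreach (var (rounds, trainLoss, testLoss) in Compute(ctx, split.TrainSet, split.TestSet, new[] { 10, 50, 100, 200 }))
        {
            Console.WriteLine($"{rounds,4} rounds: train log-loss {trainLoss:F4}, test log-loss {testLoss:F4}");
        }
    }
}
```

Plotting the two log-loss columns against the round count gives the training-versus-test comparison; a test curve that rises while the training curve keeps falling is the usual sign of overfitting.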