dotnet / machinelearning

ML.NET is an open source and cross-platform machine learning framework for .NET.
https://dot.net/ml
MIT License
9.01k stars 1.88k forks source link

How to predict text type based on input text? #7134

Closed CodingOctocat closed 5 months ago

CodingOctocat commented 5 months ago

I'm a newbie. I read the github-issue-classification tutorial, but I don't quite understand it; it seems that my problem should also be treated as a classification problem.

My specific requirement is to build a database query feature driven by a single text box (search box): the program automatically predicts which database column the entered keyword belongs to, and then filters the query on that column. For example:

Here is my data model, where ProductCategory, ProductName, Models, Enterprise, CertificateNumber, and ReportNumbers are the fields that I need to include in the training. Note that Models and ReportNumbers are list types.

My database is SQLite accessed through EF Core, and the table holds about 150,000 rows.

Can you all help me to implement this feature? I want to learn and understand ML.NET with this example.

/// <summary>
/// Entity stored in the <c>FooSet</c> table. Each mapped property names its
/// column through a <c>[Column]</c> attribute; <see cref="CertificateNumber"/> is the key.
/// </summary>
/// <remarks>
/// The <c>[Index]</c> attribute lists the certificate number, report numbers, models,
/// enterprise and product name properties.
/// </remarks>
[Table("FooSet")]
[Index(nameof(CertificateNumber), nameof(ReportNumbers), nameof(Models), nameof(Enterprise), nameof(ProductName))]
public partial record Foo
{
    // Backing store for ProductCategory; the accessors keep it from ever surfacing null.
    private string _productCategory = "";

    /// <summary>Certificate expiry date (column <c>cert_exp</c>).</summary>
    [Column("cert_exp")]
    public DateOnly CertificateExpires { get; set; }

    /// <summary>Certificate issue date (column <c>cert_iss</c>).</summary>
    [Column("cert_iss")]
    public DateOnly CertificateIssue { get; set; }

    /// <summary>Certificate number; the entity key (column <c>cert_no</c>).</summary>
    [Key, Column("cert_no")]
    public string CertificateNumber { get; set; }

    /// <summary>Enterprise name (column <c>ent</c>).</summary>
    [Column("ent")]
    public string Enterprise { get; set; }

    /// <summary>Model identifiers (column <c>models</c>).</summary>
    [Column("models")]
    public List<string> Models { get; set; }

    /// <summary>
    /// Product category (column <c>cat</c>). A null value is replaced by an
    /// empty string on both read and write.
    /// </summary>
    [Column("cat")]
    public string ProductCategory
    {
        get
        {
            return _productCategory ?? "";
        }
        set
        {
            _productCategory = value ?? "";
        }
    }

    /// <summary>Product name (column <c>name</c>).</summary>
    [Column("name")]
    public string ProductName { get; set; }

    /// <summary>Report numbers (column <c>rpt_nos</c>).</summary>
    [Column("rpt_nos")]
    public List<string> ReportNumbers { get; set; }

    /// <summary>Certificate status text (column <c>status</c>).</summary>
    [Column("status")]
    public string Status { get; set; }

    /// <summary>
    /// Creates a record with every mapped property supplied by the caller.
    /// </summary>
    /// <param name="productCategory">Product category; null is stored as an empty string.</param>
    /// <param name="productName">Product name.</param>
    /// <param name="models">Model identifiers.</param>
    /// <param name="enterprise">Enterprise name.</param>
    /// <param name="certificateNumber">Certificate number (key).</param>
    /// <param name="reportNumbers">Report numbers.</param>
    /// <param name="certificateIssue">Certificate issue date.</param>
    /// <param name="certificateExpires">Certificate expiry date.</param>
    /// <param name="status">Certificate status text.</param>
    public Foo(string productCategory, string productName, List<string> models, string enterprise,
        string certificateNumber, List<string> reportNumbers,
        DateOnly certificateIssue, DateOnly certificateExpires, string status)
    {
        // Descriptive fields first, in the same order as the parameter list.
        (ProductCategory, ProductName, Models, Enterprise, CertificateNumber, ReportNumbers) =
            (productCategory, productName, models, enterprise, certificateNumber, reportNumbers);

        // Certificate life-cycle fields.
        (CertificateIssue, CertificateExpires, Status) =
            (certificateIssue, certificateExpires, status);
    }
}

refer link: machinelearning-samples/issues/1028

CodingOctocat commented 5 months ago

I found a solution. The focus is on the LoadData() part.

/// <summary>
/// Trains, evaluates, persists and applies a multiclass text classifier that guesses
/// which <c>Foo</c> field a free-text search keyword refers to.
/// </summary>
/// <remarks>
/// Training data is produced by labelling the existing column values of every row in
/// <c>FooSet</c> with the field they came from (see <see cref="LoadData"/>).
/// </remarks>
public class FooFieldClassifier
{
    // Column names shared by the pipeline stages. "Label" is the name under which
    // FooModelInput publishes its label property via [ColumnName("Label")].
    private const string LabelColumn = "Label";
    private const string FeatureColumn = "Feature";
    private const string PredictedLabelColumn = "PredictedLabel";

    private readonly FooDbContext _fooDbContext;

    private readonly IConfiguration _configuration;

    private readonly MLContext _mlContext;

    private readonly Lazy<ITransformer> _trainedModelLazy;

    // Model currently in use: either trained in-process by Train() or loaded from disk.
    private ITransformer? _trainedModel;

    // Train/test split produced by the last successful Train(); null until then.
    private TrainTestData? _trainTestData;

    /// <summary>
    /// True once a model is available in memory, whether it was trained by
    /// <see cref="Train"/> or loaded from <see cref="TrainedModelPath"/>.
    /// </summary>
    public bool IsTrainedModelLoaded => _trainedModel is not null || _trainedModelLazy.IsValueCreated;

    /// <summary>
    /// Path of the persisted model; read from configuration key <c>TrainedModelPath</c>,
    /// falling back to <c>Foo.model.zip</c>.
    /// </summary>
    public string TrainedModelPath => _configuration["TrainedModelPath"] ?? "Foo.model.zip";

    /// <summary>Creates the classifier; no model is loaded or trained yet.</summary>
    /// <param name="mlContext">ML.NET context used for all catalog operations.</param>
    /// <param name="FooDbContext">Database context providing the <c>FooSet</c> training rows.</param>
    /// <param name="configuration">Configuration source for <see cref="TrainedModelPath"/>.</param>
    public FooFieldClassifier(MLContext mlContext, FooDbContext FooDbContext, IConfiguration configuration)
    {
        _mlContext = mlContext;
        _fooDbContext = FooDbContext;
        _configuration = configuration;

        // Loading from disk is deferred until the first LoadModel()/Predict() call.
        _trainedModelLazy = new Lazy<ITransformer>(() => _mlContext.Model.Load(TrainedModelPath, out _));
    }

    /// <summary>
    /// Scores the held-out test split with the current model and writes the
    /// multiclass metrics to the debug output.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// <see cref="Train"/> has not completed, so there is no model or no test split.
    /// </exception>
    public void Evaluate()
    {
        var model = _trainedModel;
        var data = _trainTestData;

        if (model is null || data is null)
        {
            throw new InvalidOperationException("Evaluate() requires a model and a test set; call Train() first.");
        }

        Debug.WriteLine($"=============== Evaluating to get model's accuracy metrics - Starting time: {DateTime.Now} ===============");

        var testMetrics = _mlContext.MulticlassClassification.Evaluate(model.Transform(data.TestSet));

        Debug.WriteLine($"=============== Evaluating to get model's accuracy metrics - Ending time: {DateTime.Now} ===============");
        Debug.WriteLine($"*************************************************************************************************************");
        Debug.WriteLine($"*       Metrics for Multi-class Classification model - Test Data     ");
        Debug.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Debug.WriteLine($"*       MicroAccuracy:    {testMetrics.MicroAccuracy:0.###}");
        Debug.WriteLine($"*       MacroAccuracy:    {testMetrics.MacroAccuracy:0.###}");
        Debug.WriteLine($"*       LogLoss:          {testMetrics.LogLoss:#.###}");
        Debug.WriteLine($"*       LogLossReduction: {testMetrics.LogLossReduction:#.###}");
        Debug.WriteLine($"*************************************************************************************************************");
    }

    /// <summary>
    /// Ensures a model is in memory, loading it from <see cref="TrainedModelPath"/>
    /// if none has been trained or loaded yet.
    /// </summary>
    public void LoadModel()
    {
        GetOrLoadModel();
    }

    /// <summary>Runs the model on a single keyword.</summary>
    /// <param name="field">Keyword text to classify.</param>
    /// <returns>The raw prediction row produced by the model.</returns>
    public FooFieldTypePrediction Predict(string field)
    {
        var model = GetOrLoadModel();

        var example = new FooModelInput(field);

        // The engine is disposable and is created per call, so release it when done.
        using var predEngine = _mlContext.Model.CreatePredictionEngine<FooModelInput, FooFieldTypePrediction>(model);

        var prediction = predEngine.Predict(example);

        Debug.WriteLine($"=============== Single Prediction - Result: {field}: {prediction.FooFieldType} ===============");

        return prediction;
    }

    /// <summary>
    /// Resolves the field type of a keyword, trying cheap deterministic rules first
    /// and falling back to the trained model.
    /// </summary>
    /// <param name="field">Raw keyword text; it is passed through <c>CleanText()</c> first.</param>
    /// <returns>
    /// <c>SmartMode</c> for blank input, <c>Status</c> or <c>TestingCenter</c> when the text
    /// is one of the known descriptions, <c>CertDateStart</c> when it parses as a date,
    /// otherwise the model's prediction.
    /// </returns>
    public FooFieldType PredictFooFieldType(string field)
    {
        field = field.CleanText();

        if (String.IsNullOrWhiteSpace(field))
        {
            return FooFieldType.SmartMode;
        }

        if (FooCertificateStatusFields.Descriptions.Value.Contains(field))
        {
            return FooFieldType.Status;
        }

        if (DateOnly.TryParse(field, out _))
        {
            return FooFieldType.CertDateStart;
        }

        if (FooTestingCenterFields.Descriptions.Value.Contains(field))
        {
            return FooFieldType.TestingCenter;
        }

        var prediction = Predict(field);

        return prediction.FooFieldType;
    }

    /// <summary>Saves the given model to <see cref="TrainedModelPath"/>.</summary>
    /// <param name="model">Model to persist.</param>
    /// <param name="trainingDataViewSchema">Schema of the data the model was trained on.</param>
    public void SaveModelAsFile(ITransformer model, DataViewSchema trainingDataViewSchema)
    {
        _mlContext.Model.Save(model, trainingDataViewSchema, TrainedModelPath);

        Debug.WriteLine($"The model is saved to {TrainedModelPath}");
    }

    /// <summary>
    /// Saves the model produced by the last <see cref="Train"/> call to
    /// <see cref="TrainedModelPath"/>, using the training split's schema.
    /// </summary>
    /// <exception cref="InvalidOperationException"><see cref="Train"/> has not completed.</exception>
    public void SaveModelAsFile()
    {
        var model = _trainedModel;
        var data = _trainTestData;

        if (model is null || data is null)
        {
            throw new InvalidOperationException("There is no trained model to save; call Train() first.");
        }

        SaveModelAsFile(model, data.TrainSet.Schema);
    }

    /// <summary>
    /// Loads the training data, fits the pipeline and keeps the resulting model and
    /// train/test split for <see cref="Predict"/> and <see cref="Evaluate"/>.
    /// </summary>
    public void Train()
    {
        var data = LoadData();
        var pipeline = ProcessData();
        var model = BuildAndTrainModel(data.TrainSet, pipeline);

        // Publish both only after fitting succeeded so they always belong together.
        _trainTestData = data;
        _trainedModel = model;
    }

    /// <summary>
    /// Appends the SDCA maximum-entropy trainer and the key-to-value mapping of the
    /// predicted label to <paramref name="pipeline"/>, then fits it.
    /// </summary>
    private TransformerChain<KeyToValueMappingTransformer> BuildAndTrainModel(IDataView splitTrainSet, IEstimator<ITransformer> pipeline)
    {
        var trainingPipeline = pipeline
            .Append(_mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(LabelColumn, FeatureColumn))
            .Append(_mlContext.Transforms.Conversion.MapKeyToValue(PredictedLabelColumn));

        return trainingPipeline.Fit(splitTrainSet);
    }

    /// <summary>Returns the in-memory model, loading it from disk on first use.</summary>
    private ITransformer GetOrLoadModel()
    {
        return _trainedModel ??= _trainedModelLazy.Value;
    }

    /// <summary>
    /// Builds one labelled example per column value (and per list element for
    /// <c>Models</c> and <c>ReportNumbers</c>) and splits them 80/20 into train and test sets.
    /// </summary>
    private TrainTestData LoadData()
    {
        // Query the table exactly once. A lazy sequence would hit the database again
        // for each projection below and every time ML.NET re-enumerates the data view.
        var foos = _fooDbContext.FooSet.ToList();

        var modelInputs = new List<FooModelInput>();

        modelInputs.AddRange(foos.Select(x => new FooModelInput(x.ProductCategory, FooFieldType.ProductCategory)));
        modelInputs.AddRange(foos.Select(x => new FooModelInput(x.Enterprise, FooFieldType.EnterpriseName)));
        modelInputs.AddRange(foos.Select(x => new FooModelInput(x.ProductName, FooFieldType.ProductName)));
        modelInputs.AddRange(foos
            .Where(x => x.Models is not null) // a row without a list contributes no examples
            .SelectMany(x => x.Models)
            .Select(x => new FooModelInput(x, FooFieldType.Model)));
        modelInputs.AddRange(foos.Select(x => new FooModelInput(x.CertificateNumber, FooFieldType.CertificateNo)));
        modelInputs.AddRange(foos
            .Where(x => x.ReportNumbers is not null)
            .SelectMany(x => x.ReportNumbers)
            .Select(x => new FooModelInput(x, FooFieldType.ReportNo)));

        var dataView = _mlContext.Data.LoadFromEnumerable(modelInputs);

        return _mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
    }

    /// <summary>
    /// Builds the data-preparation part of the pipeline: label to key, text featurization,
    /// then a cache checkpoint.
    /// </summary>
    private EstimatorChain<ITransformer> ProcessData()
    {
        // FooModelInput exposes its label under the column name "Label" (via
        // [ColumnName]), not under the property name, so that is the column to read.
        // The key-typed output replaces it under the same name.
        var pipeline = _mlContext.Transforms.Conversion
            .MapValueToKey(inputColumnName: LabelColumn, outputColumnName: LabelColumn)
            .Append(_mlContext.Transforms.Text.FeaturizeText(inputColumnName: nameof(FooModelInput.Field), outputColumnName: FeatureColumn))
            .AppendCacheCheckpoint(_mlContext);

        return pipeline;
    }
}
/// <summary>
/// One input row for the field classifier: a piece of text and, for training rows,
/// the numeric value of the field type it belongs to.
/// </summary>
public class FooModelInput
{
    /// <summary>
    /// Field type stored as its integer value; annotated with the column name <c>Label</c>.
    /// Left at its default when the row is built with the single-argument constructor.
    /// </summary>
    [ColumnName("Label")]
    public int FooFieldType { get; set; }

    /// <summary>The text to classify or to learn from.</summary>
    public string Field { get; set; }

    /// <summary>Creates a labelled row.</summary>
    /// <param name="field">The text value.</param>
    /// <param name="FooFieldType">The field type the text belongs to.</param>
    public FooModelInput(string field, FooFieldType FooFieldType)
    {
        Field = field;

        // The parameter has the same name as the property and shadows it inside this
        // constructor; without "this." the assignment would target the parameter and
        // the label would never be stored.
        this.FooFieldType = (int)FooFieldType;
    }

    /// <summary>Creates an unlabelled row.</summary>
    /// <param name="field">The text value.</param>
    public FooModelInput(string field)
    {
        Field = field;
    }
}
/// <summary>
/// Output row of the field classifier: the predicted label as an integer plus a
/// typed view of it.
/// </summary>
public class FooFieldTypePrediction
{
    /// <summary><see cref="Prediction"/> cast to the field-type enum.</summary>
    public FooFieldType FooFieldType
    {
        get
        {
            return (FooFieldType)Prediction;
        }
    }

    /// <summary>Predicted label value; annotated with the column name <c>PredictedLabel</c>.</summary>
    [ColumnName("PredictedLabel")]
    public int Prediction { get; set; }
}