dotnet / machinelearning

ML.NET is an open source and cross-platform machine learning framework for .NET.
https://dot.net/ml
MIT License
9.05k stars 1.89k forks source link

CreatePredictionEngine does not work after retrain the modal without column Transforms #4383

Closed volitantns closed 4 years ago

volitantns commented 5 years ago

System information

Issue

I saved the LinearBinaryModelParameters , and retrain the modal as :

transf = mlc.BinaryClassification.Trainers.LbfgsLogisticRegression().Fit(newdata, lbmparameters);

and call

pe = mlc.Model.CreatePredictionEngine<ViewItem, ViewPred>(transf);

Source code / logs

using System;
using System.Linq;
using System.IO;
using Microsoft.ML;
using System.Collections.Generic;

using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

class Program
{
    static bool STUDY_CONTINUE_MODE = true;

    const int MEM_SIZE = 9;

    static Random rand = new Random();

    static int total_study = 0;

    static List<ViewItem> happymem = new List<ViewItem>();
    static List<ViewItem> dangermem = new List<ViewItem>();
    static int GetTotalMemCount()
    {
        return happymem.Count + dangermem.Count;
    }

    static string feeling = new string('0', MEM_SIZE);

    static MLContext mlc = new MLContext();

    static long totalticks = 0;

    static long totalright = 0;
    static long totalwrong = 0;

    static void Main()
    {

        while (true)
        {
            RunTick();
            totalticks++;
            if (totalticks % 10000 == 0 && transf != null && totalright + totalwrong > 2)
            {
                Console.Title = totalticks + " - "
                    + " " + 100 * safe_mem / totalmem + "%"
                    + " CV=" + 100 * totalright / (totalright + totalwrong) + "%"
                    + " " + totalright + "/" + totalwrong
                    + " PV=" + 100 * predright / Math.Max(1, predtimes) + "%"
                    + " " + predright + "/" + (predtimes - predright)
                    + " HQ=" + Hao_Qi.ToString("00000")

                    + " ML=" + total_study + " " + (100 * dangermem.Count / GetTotalMemCount()) + "%" + " " + dangermem.Count + "/" + GetTotalMemCount()
                      ;
            }
        }

    }

    static ITransformer transf;
    static ITransformer firsttransf;
    static LinearBinaryModelParameters lbmparameters = null;

    static void StudyOnce(ViewItem[] items)
    {
        var data = mlc.Data.LoadFromEnumerable(items);

        if (lbmparameters == null)
        {
            if (!STUDY_CONTINUE_MODE)
                mlc = new MLContext();

            var mytransf = mlc.Transforms.Text.FeaturizeText("Features", "TextData")
               .Append(mlc.BinaryClassification.Trainers.LbfgsLogisticRegression()).Fit(data);

            if (STUDY_CONTINUE_MODE)
                lbmparameters = mytransf.LastTransformer.Model.SubModel;

            firsttransf = mytransf;
            transf = mytransf;
        }
        else
        {
            //TODO:可以重新训练模型,但是在使用上存在问题.
            //因为使用了.Fit(newdata, lbmparameters)生成的Transformer不具备 mlc.Transforms.Text.FeaturizeText("Features", "TextData") 这个步骤.
            //这会导致 CreatePredictionEngine(transf) 生成的函数无法把ViewItem.TextData转换为Features 

            var newdata = firsttransf.Transform(data);

            var mytransf = mlc.BinaryClassification.Trainers.LbfgsLogisticRegression().Fit(newdata, lbmparameters);

            if (STUDY_CONTINUE_MODE)
                lbmparameters = mytransf.Model.SubModel;

            transf = mytransf;
        }

    }

    static void StudyOnce()
    {
        ViewItem[] items = new ViewItem[GetTotalMemCount()];
        happymem.CopyTo(items, 0);
        dangermem.CopyTo(items, happymem.Count);
        StudyOnce(items);
    }

    class ViewItem
    {
        public string TextData;

        [ColumnName("Label")]
        public bool IsDanger;
    }
    class ViewPred
    {
        public bool PredictedLabel;

        [ColumnName("Score")]
        public float ScoreValue;

        public float Probability;
    }

    static int Hao_Qi = 100;

    static long totalmem = 0;
    static long safe_mem = 0;

    static bool feeling_isdanger;

    static void RunTick()
    {
        totalmem++;

        feeling = feeling.Remove(0, 1) + rand.Next(0, 10);

        feeling_isdanger = feeling.Contains("9");

        if (!feeling_isdanger)
            safe_mem++;

        Hao_Qi += 10;

        if (Hao_Qi > 100)
        {
            float danger = GetDangerLevel();
            if (Hao_Qi > 20000 * danger)
            {
                DoClick();
            }
        }

    }

    static long predtimes = 0;
    static long predright = 0;

    static PredictionEngine<ViewItem, ViewPred> pe;
    static ITransformer petransf;

    static float GetDangerLevel()
    {
        if (transf == null)
            return 0;

        if (pe == null || petransf != transf)
        {
            if (pe != null)
                pe.Dispose();//prevent memory leak ???
            pe = mlc.Model.CreatePredictionEngine<ViewItem, ViewPred>(transf);
            petransf = transf;
        }

        ViewPred pred = pe.Predict(new ViewItem() { TextData = feeling });

        bool isdanger = pred.PredictedLabel;

        predtimes++;
        if (isdanger == feeling_isdanger)
        {
            predright++;
        }

        return pred.Probability;
    }

    static void DoClick()
    {
        if (!feeling_isdanger)
        {
            Hao_Qi -= 40;
            totalright++;

            happymem.Add(new ViewItem { TextData = feeling, IsDanger = feeling_isdanger });
        }
        else
        {
            Hao_Qi -= 90;
            totalwrong++;

            dangermem.Add(new ViewItem { TextData = feeling, IsDanger = feeling_isdanger });
        }

        if (happymem.Count > maxcount || dangermem.Count > maxcount)
        {
            total_study++;
            StudyOnce();

            if (happymem.Count > maxcount)
                happymem.RemoveRange(0, happymem.Count / 10);
            if (dangermem.Count > maxcount)
                dangermem.RemoveRange(0, dangermem.Count / 10);
            if (maxcount < 20000) maxcount = (int)(maxcount * 1.2);
        }

    }

    static int maxcount = 999;

}
volitantns commented 5 years ago

I can use the seperated pipeline to calculate the result ,

But it's very slow , only 1% peformance of the PredictionEngine

using System;
using System.Linq;
using System.IO;
using Microsoft.ML;
using System.Collections.Generic;

using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

class Program
{
    static bool STUDY_CONTINUE_MODE = true;

    const int MEM_SIZE = 9;

    static Random rand = new Random();

    static int total_study = 0;

    static List<ViewItem> happymem = new List<ViewItem>();
    static List<ViewItem> dangermem = new List<ViewItem>();
    static int GetTotalMemCount()
    {
        return happymem.Count + dangermem.Count;
    }

    static string feeling = new string('0', MEM_SIZE);

    static MLContext mlc = new MLContext();

    static long totalticks = 0;

    static long totalright = 0;
    static long totalwrong = 0;

    static void Main()
    {

        while (true)
        {
            RunTick();
            totalticks++;
            if (totalticks % 10000 == 0 && transf != null && totalright + totalwrong > 2)
            {
                Console.Title = totalticks + " - "
                    + " " + 100 * safe_mem / totalmem + "%"
                    + " CV=" + 100 * totalright / (totalright + totalwrong) + "%"
                    + " " + totalright + "/" + totalwrong
                    + " PV=" + 100 * predright / Math.Max(1, predtimes) + "%"
                    + " " + predright + "/" + (predtimes - predright)
                    + " HQ=" + Hao_Qi.ToString("00000")

                    + " ML=" + total_study + " " + (100 * dangermem.Count / GetTotalMemCount()) + "%" + " " + dangermem.Count + "/" + GetTotalMemCount()
                      ;
            }
        }

    }

    class ViewItem
    {
        public string TextData;

        [ColumnName("Label")]
        public bool IsDanger;
    }
    class ViewPred
    {
        public bool PredictedLabel;

        [ColumnName("Score")]
        public float ScoreValue;

        public float Probability;
    }

    static int Hao_Qi = 100;

    static long totalmem = 0;
    static long safe_mem = 0;

    static bool feeling_isdanger;

    static void RunTick()
    {
        totalmem++;

        feeling = feeling.Remove(0, 1) + rand.Next(0, 10);

        feeling_isdanger = feeling.Contains("9");

        if (!feeling_isdanger)
            safe_mem++;

        Hao_Qi += 10;

        float danger = GetDangerLevel();
        if (Hao_Qi > 20000 * danger)
        {
            DoClick();
        }
    }

    static long predtimes = 0;
    static long predright = 0;

    static PredictionEngine<ViewItem, ViewPred> pe;
    static ITransformer petransf;

    static float GetDangerLevel()
    {
        if (transf == null)
            return 0;

        ViewPred pred;

        if (!STUDY_CONTINUE_MODE)
        {
            if (pe == null || petransf != transf)
            {
                if (pe != null)
                    pe.Dispose();//prevent memory leak ???
                pe = mlc.Model.CreatePredictionEngine<ViewItem, ViewPred>(transf);
                petransf = transf;
            }

            pred = pe.Predict(new ViewItem() { TextData = feeling });

        }
        else
        {
            var data = datatransf.Transform(mlc.Data.LoadFromEnumerable(new ViewItem[] { new ViewItem() { TextData = feeling } }));
            data = transf.Transform(data);
            pred = new ViewPred();
            pred.PredictedLabel = data.GetColumn<bool>("PredictedLabel").First();
            pred.Probability = data.GetColumn<float>("Probability").First();
        }

        bool isdanger = pred.PredictedLabel;

        predtimes++;
        if (isdanger == feeling_isdanger)
        {
            predright++;
        }

        return pred.Probability;
    }

    static ITransformer transf;
    static ITransformer datatransf;
    static LinearBinaryModelParameters lbmparameters = null;

    static void StudyOnce(ViewItem[] items)
    {
        var data = mlc.Data.LoadFromEnumerable(items);

        if (!STUDY_CONTINUE_MODE)
        {
            mlc = new MLContext();
            transf = mlc.Transforms.Text.FeaturizeText("Features", "TextData")
               .Append(mlc.BinaryClassification.Trainers.LbfgsLogisticRegression()).Fit(data);

        }
        else
        {
            if (datatransf == null)
                datatransf = mlc.Transforms.Text.FeaturizeText("Features", "TextData").Fit(data);
            data = datatransf.Transform(data);

            var trainer = mlc.BinaryClassification.Trainers.LbfgsLogisticRegression();

            var mytransf = lbmparameters == null ? trainer.Fit(data) : trainer.Fit(data, lbmparameters);

            lbmparameters = mytransf.Model.SubModel;

            transf = mytransf;
        }
    }

    static void StudyOnce()
    {
        ViewItem[] items = new ViewItem[GetTotalMemCount()];
        happymem.CopyTo(items, 0);
        dangermem.CopyTo(items, happymem.Count);
        StudyOnce(items);
    }

    static void DoClick()
    {
        if (!feeling_isdanger)
        {
            Hao_Qi -= 40;
            totalright++;

            happymem.Add(new ViewItem { TextData = feeling, IsDanger = feeling_isdanger });
        }
        else
        {
            Hao_Qi -= 90;
            totalwrong++;

            dangermem.Add(new ViewItem { TextData = feeling, IsDanger = feeling_isdanger });
        }

        if (happymem.Count > maxcount || dangermem.Count > maxcount)
        {
            total_study++;
            StudyOnce();

            if (happymem.Count > maxcount)
                happymem.RemoveRange(0, happymem.Count / 10);
            if (dangermem.Count > maxcount)
                dangermem.RemoveRange(0, dangermem.Count / 10);
            if (maxcount < 20000) maxcount = (int)(maxcount * 1.2);
        }

    }

    static int maxcount = 999;

}
antoniovs1029 commented 4 years ago

@volitantns Can you please provide a dataset to test this issue? Thanks!

antoniovs1029 commented 4 years ago

Hi @volitantns . Sorry for the delay. I've just realized that no dataset was needed, since you're actually generating data inside your code.

I've taken a closer look and the problem isn't that you're retraining. CreatePredictionEngine does work with retrained models, so this isn't a bug. The problem is that the model that you're passing to the CreatePredictionEngine isn't the correct one.

To exemplify why there's an error in your code, I've made this sample, that replicates your issue. I will later post an update to your original code, fixing it, following the same basic principle of this sample. Notice that in the sample, using model2 to create the PE won't work, as it doesn't have the transformers to create the Features' column it needs to work... model3 on the other hand does have the transformer to do that, so there's no problem there. The PE works as expected, even if model3 was created using a retrained transformer.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace is4383myrepro
{
    public class DataPoint
    {
        public bool Label { get; set; }
        [VectorType(50)]
        public float[] X { get; set; }

        [VectorType(50)]
        public float[] Y { get; set; }
    }

    public class Prediction
    {
        public bool Label { get; set; }

        public bool PredictedLabel { get; set; }

        public float Probability { get; set; }
    }

    public static class Program
    {
        public static void Main()
        {
            var mlContext = new MLContext(seed: 0);

            var dataPoints = GenerateRandomDataPoints(1000);
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            var concatenatorEstimator = mlContext.Transforms.Concatenate("Features", "X", "Y");

            // Original pipeline and PE
            var pipeline = concatenatorEstimator.Append(mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression());
            var model = pipeline.Fit(trainingData);
            var pe = mlContext.Model.CreatePredictionEngine<DataPoint, Prediction>(model);
            var pred = pe.Predict(dataPoints.First());

            // This will throw, because model2 won't have the necessary steps to create the Features vector
            // This is what you did on your original code.
            var trainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression();
            var preprocessedData = concatenatorEstimator.Fit(trainingData).Transform(trainingData);
            var model2 = trainer.Fit(preprocessedData, model.LastTransformer.Model.SubModel);
            // pe = mlContext.Model.CreatePredictionEngine<DataPoint, Prediction>(model2); // Throw because it can't find Features Column

            // This won't throw, because model3 will actually contain the transformer to create the Features vector
            var concatenatorTransformer = concatenatorEstimator.Fit(trainingData);
            var model3 = concatenatorTransformer.Append(model2);
            pe = mlContext.Model.CreatePredictionEngine<DataPoint, Prediction>(model3); // Won't throw
            var pred2 = pe.Predict(dataPoints.First()); // Result is as expected
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)

        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat() > 0.5f;
                yield return new DataPoint
                {
                    Label = label,

                    X = Enumerable.Repeat(label, 50)
                        .Select(x => x ? randomFloat() : randomFloat() +
                        0.1f).ToArray(),

                    Y = Enumerable.Repeat(label, 50)
                        .Select(x => x ? randomFloat() : randomFloat() +
                        0.1f).ToArray()
                };
            }
        }
    }
}
antoniovs1029 commented 4 years ago

Now, here I have an updated version of your original code, which fixes the problem following the principles on the sample above.

using System;
using System.Linq;
using System.IO;
using Microsoft.ML;
using System.Collections.Generic;

using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

class Program
{
    static bool STUDY_CONTINUE_MODE = true;

    const int MEM_SIZE = 9;

    static Random rand = new Random();

    static int total_study = 0;

    static List<ViewItem> happymem = new List<ViewItem>();
    static List<ViewItem> dangermem = new List<ViewItem>();
    static int GetTotalMemCount()
    {
        return happymem.Count + dangermem.Count;
    }

    static string feeling = new string('0', MEM_SIZE);

    static MLContext mlc = new MLContext();

    static long totalticks = 0;

    static long totalright = 0;
    static long totalwrong = 0;

    static void Main()
    {

        while (true)
        {
            RunTick();
            totalticks++;
            if (totalticks % 10000 == 0 && transf != null && totalright + totalwrong > 2)
            {
                Console.Title = totalticks + " - "
                    + " " + 100 * safe_mem / totalmem + "%"
                    + " CV=" + 100 * totalright / (totalright + totalwrong) + "%"
                    + " " + totalright + "/" + totalwrong
                    + " PV=" + 100 * predright / Math.Max(1, predtimes) + "%"
                    + " " + predright + "/" + (predtimes - predright)
                    + " HQ=" + Hao_Qi.ToString("00000")

                    + " ML=" + total_study + " " + (100 * dangermem.Count / GetTotalMemCount()) + "%" + " " + dangermem.Count + "/" + GetTotalMemCount()
                      ;
            }
        }

    }

    static ITransformer transf;
    static ITransformer datatransf;
    static LinearBinaryModelParameters lbmparameters = null;

    static void StudyOnce(ViewItem[] items)
    {
        var data = mlc.Data.LoadFromEnumerable(items);

        if (lbmparameters == null)
        {
            datatransf = mlc.Transforms.Text.FeaturizeText("Features", "TextData").Fit(data);

            if (!STUDY_CONTINUE_MODE)
                mlc = new MLContext();

            var newdata = datatransf.Transform(data);
            var mytransf = mlc.BinaryClassification.Trainers.LbfgsLogisticRegression().Fit(newdata);

            if (STUDY_CONTINUE_MODE)
                lbmparameters = mytransf.Model.SubModel;

            transf = mytransf;
        }
        else
        {
            //TODO:可以重新训练模型,但是在使用上存在问题.
            //因为使用了.Fit(newdata, lbmparameters)生成的Transformer不具备 mlc.Transforms.Text.FeaturizeText("Features", "TextData") 这个步骤.
            //这会导致 CreatePredictionEngine(transf) 生成的函数无法把ViewItem.TextData转换为Features 

            var newdata = datatransf.Transform(data);

            var mytransf = mlc.BinaryClassification.Trainers.LbfgsLogisticRegression().Fit(newdata, lbmparameters);

            if (STUDY_CONTINUE_MODE)
                lbmparameters = mytransf.Model.SubModel;

            transf = mytransf;
        }

    }

    static void StudyOnce()
    {
        ViewItem[] items = new ViewItem[GetTotalMemCount()];
        happymem.CopyTo(items, 0);
        dangermem.CopyTo(items, happymem.Count);
        StudyOnce(items);
    }

    class ViewItem
    {
        public string TextData;

        [ColumnName("Label")]
        public bool IsDanger;
    }
    class ViewPred
    {
        public bool PredictedLabel;

        [ColumnName("Score")]
        public float ScoreValue;

        public float Probability;
    }

    static int Hao_Qi = 100;

    static long totalmem = 0;
    static long safe_mem = 0;

    static bool feeling_isdanger;

    static void RunTick()
    {
        totalmem++;

        feeling = feeling.Remove(0, 1) + rand.Next(0, 10);

        feeling_isdanger = feeling.Contains("9");

        if (!feeling_isdanger)
            safe_mem++;

        Hao_Qi += 10;

        if (Hao_Qi > 100)
        {
            float danger = GetDangerLevel();
            if (Hao_Qi > 20000 * danger)
            {
                DoClick();
            }
        }

    }

    static long predtimes = 0;
    static long predright = 0;

    static PredictionEngine<ViewItem, ViewPred> pe;
    static ITransformer petransf;

    static float GetDangerLevel()
    {
        if (transf == null)
            return 0;

        if (pe == null || petransf != transf)
        {
            if (pe != null)
                pe.Dispose();//prevent memory leak ???
            pe = mlc.Model.CreatePredictionEngine<ViewItem, ViewPred>(datatransf.Append(transf));
            petransf = transf;
        }

        ViewPred pred = pe.Predict(new ViewItem() { TextData = feeling });

        bool isdanger = pred.PredictedLabel;

        predtimes++;
        if (isdanger == feeling_isdanger)
        {
            predright++;
        }

        return pred.Probability;
    }

    static void DoClick()
    {
        if (!feeling_isdanger)
        {
            Hao_Qi -= 40;
            totalright++;

            happymem.Add(new ViewItem { TextData = feeling, IsDanger = feeling_isdanger });
        }
        else
        {
            Hao_Qi -= 90;
            totalwrong++;

            dangermem.Add(new ViewItem { TextData = feeling, IsDanger = feeling_isdanger });
        }

        if (happymem.Count > maxcount || dangermem.Count > maxcount)
        {
            total_study++;
            StudyOnce();

            if (happymem.Count > maxcount)
                happymem.RemoveRange(0, happymem.Count / 10);
            if (dangermem.Count > maxcount)
                dangermem.RemoveRange(0, dangermem.Count / 10);
            if (maxcount < 20000) maxcount = (int)(maxcount * 1.2);
        }

    }

    static int maxcount = 999;

}
antoniovs1029 commented 4 years ago

I will close this issue, since it wasn't really a bug. Feel free to reopen it if you're still having problems with the solution I suggested. Thanks.

volitantns commented 4 years ago

@antoniovs1029 thanks. It works. I am very happy to play ML.Net Now. thank you very much .