dotnet / machinelearning

ML.NET is an open source and cross-platform machine learning framework for .NET.
https://dot.net/ml
MIT License

KeyNotFoundException after Bottleneck Computation phase finishes while training image classification model #5157

Closed szymenn closed 4 years ago

szymenn commented 4 years ago

System information

Issue

Source code / logs

My Program.cs file:

```csharp
class Program
{
    static void Main(string[] args)
    {
        var mlContext = new MLContext();

        var projectDirectory = Path.GetFullPath(Path.Combine(AppContext.BaseDirectory, "../../../"));
        var trainRelativePath = Path.Combine(projectDirectory, "Data\\train");
        var testRelativePath = Path.Combine(projectDirectory, "Data\\test1");
        var workspaceRelativePath = Path.Combine(projectDirectory, "Workspace");

        var trainImages = LoadImagesFromDirectory(trainRelativePath);
        var testImages = LoadImagesFromDirectory(testRelativePath);

        var trainData = mlContext.Data.LoadFromEnumerable(trainImages);
        var testData = mlContext.Data.LoadFromEnumerable(testImages);

        var shuffledData = mlContext.Data.ShuffleRows(trainData);

        var preprocessingPipeline = mlContext.Transforms.Conversion.MapValueToKey(
                inputColumnName: "Label",
                outputColumnName: "LabelAsKey")
            .Append(mlContext.Transforms.LoadRawImageBytes(
                outputColumnName: "Image",
                imageFolder: trainRelativePath,
                inputColumnName: "ImagePath"));

        var preProcessedData = preprocessingPipeline
            .Fit(shuffledData)
            .Transform(shuffledData);

        var testProcessedData = preprocessingPipeline
            .Fit(testData)
            .Transform(testData);

        var trainProcessedSet = mlContext.Data.TrainTestSplit(preProcessedData, testFraction: 0.999f);

        var testSplit = mlContext.Data.TrainTestSplit(testProcessedData, testFraction: 0.999f);
        var validationSet = testSplit.TrainSet;
        var testSet = testSplit.TestSet;

        var classifierOptions = new ImageClassificationTrainer.Options()
        {
            FeatureColumnName = "Image",
            LabelColumnName = "LabelAsKey",
            ValidationSet = validationSet,
            Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
            MetricsCallback = Console.WriteLine,
            TestOnTrainSet = false,
            ReuseTrainSetBottleneckCachedValues = false,
            ReuseValidationSetBottleneckCachedValues = false,
            WorkspacePath = workspaceRelativePath,
            Epoch = 100
        };

        var trainingPipeline = mlContext.MulticlassClassification.Trainers.ImageClassification(classifierOptions)
            .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

        var trainedModel = trainingPipeline.Fit(trainProcessedSet.TrainSet);
        mlContext.Model.Save(trainedModel, trainData.Schema, "Model.zip");

        Console.ReadLine();
    }

    private static IEnumerable<ImageData> LoadImagesFromDirectory(string folder)
    {
        var files = Directory.GetFiles(folder, "*",
            searchOption: SearchOption.AllDirectories);

        foreach (var file in files)
        {
            var label = Path.GetFileName(file).Split(".")[0];

            yield return new ImageData()
            {
                ImagePath = file,
                Label = label
            };
        }
    }
}
```

ImageData.cs:

```csharp
public class ImageData
{
    public string ImagePath { get; set; }

    public string Label { get; set; }
}
```

ModelInput.cs:

```csharp
public class ModelInput
{
    public byte[] Image { get; set; }

    public UInt32 LabelAsKey { get; set; }

    public string ImagePath { get; set; }

    public string Label { get; set; }
}
```

ModelOutput.cs:

```csharp
public class ModelOutput
{
    public string ImagePath { get; set; }

    public string Label { get; set; }

    public string PredictedLabel { get; set; }
}
```

Detailed Exception info:

    Unhandled exception. System.Collections.Generic.KeyNotFoundException: The given key '3711' was not present in the dictionary.
       at System.Collections.Generic.Dictionary`2.get_Item(TKey key)
       at Microsoft.ML.Vision.ImageClassificationTrainer.TrainAndEvaluateClassificationLayer(String trainBottleneckFilePath, Options options, String validationSetBottleneckFilePath, Int32 trainingsetSize)
       at Microsoft.ML.Vision.ImageClassificationTrainer.TrainModelCore(TrainContext trainContext)
       at Microsoft.ML.Trainers.TrainerEstimatorBase`2.TrainTransformer(IDataView trainSet, IDataView validationSet, IPredictor initPredictor)
       at Microsoft.ML.Trainers.TrainerEstimatorBase`2.Fit(IDataView input)
       at Microsoft.ML.Data.EstimatorChain`1.Fit(IDataView input)
       at ModelBuilder.Program.Main(String[] args) in C:\Users\User\RiderProjects\CatVsDogsBinaryClassification\ModelBuilder\Program.cs:line 71

szymenn commented 4 years ago

I found out that I got this error because I was using data from the test set as the validation set instead of data from the training set. Closing the issue.
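
For readers who hit the same exception, here is a minimal sketch of what the fix described above might look like. It assumes the validation set is carved out of the same preprocessed training data, so that the `LabelAsKey` values it contains come from the one `MapValueToKey` mapping fitted on the training labels. The `DataSplits` class, the `Split` method, and the `0.3` test fraction are hypothetical and chosen for illustration only; they are not taken from the original code.

```csharp
using Microsoft.ML;

public static class DataSplits
{
    // Hypothetical helper: derives train, validation, and test sets from ONE
    // preprocessed IDataView, so all three share the same label-to-key mapping.
    public static (IDataView Train, IDataView Validation, IDataView Test) Split(
        MLContext mlContext, IDataView preProcessedData)
    {
        // Hold out 30% of the preprocessed training data (illustrative fraction).
        var trainSplit = mlContext.Data.TrainTestSplit(preProcessedData, testFraction: 0.3);

        // Split the held-out part again (default testFraction of 0.1):
        // the larger piece becomes the validation set, the smaller the test set.
        var validationTestSplit = mlContext.Data.TrainTestSplit(trainSplit.TestSet);

        return (trainSplit.TrainSet, validationTestSplit.TrainSet, validationTestSplit.TestSet);
    }
}
```

With a helper like this, `ValidationSet` in `ImageClassificationTrainer.Options` would be set to the returned `Validation` view and `Fit` would be called on `Train`, rather than passing a view that was transformed by a pipeline fitted separately on the test images.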