abdelaziz-mahdy / pytorch_lite

Flutter package to help run PyTorch Lite models: classification, YOLOv5, and YOLOv8.
MIT License
53 stars 23 forks

Quantized Custom Image Classification Model #86

Open ZaZra03 opened 1 day ago

ZaZra03 commented 1 day ago

I have a fine-tuned EfficientNet Lite0 model from timm, which I quantized using QAT FX Graph Mode in PyTorch. When I test images containing the same object in both PyTorch and Dart, I get different classification results. In Dart, I always receive a single classification label. Also, do these packages automatically handle image preprocessing, such as resizing and normalization?

[This is the notebook where I trained my model in Colab.](https://colab.research.google.com/drive/15z8RkL-bNaQe0TwRzEXNHbb1ulG17dpI?usp=sharing)

abdelaziz-mahdy commented 1 day ago

Yes, the lib handles the resizing and normalization.

Please try your exported model in Python to see whether it works correctly, since someone else has reported a problem with an exported model not working correctly.

ZaZra03 commented 1 day ago

Hi @abdelaziz-mahdy, thank you so much for your response. Could you please check my Dart code here:

import 'dart:io';
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:image_picker/image_picker.dart';
import 'package:pytorch_lite/pytorch_lite.dart';
import 'package:camera/camera.dart';
import 'package:provider/provider.dart';
import 'package:urecycle_app/constants.dart';
import 'package:urecycle_app/view/widget/loading_widget.dart';
import '../../../provider/user_provider.dart';
import '../../screen/user_screen.dart';

class Scan extends StatefulWidget {
  const Scan({super.key});

  @override
  State<Scan> createState() => _ScanState();
}

class _ScanState extends State<Scan> {
  CameraController? _cameraController;
  late String _classificationResult;
  bool _isProcessing = true;
  ClassificationModel? _classificationModel;

  @override
  void initState() {
    super.initState();
    WidgetsBinding.instance.addPostFrameCallback((_) async {
      await _initializeComponents();
    });
  }

  Future<void> _initializeComponents() async {
    await Future.wait([
      _initializeCamera(),
      _initializeUserProvider(),
      _loadModel(),
    ]);
    await _takePictureAndProcess();
  }

  Future<void> _initializeUserProvider() async {
    final userProvider = Provider.of<UserProvider>(context, listen: false);
    if (userProvider.user == null || userProvider.lbUser == null) {
      await userProvider.fetchUserData();
    }
  }

  Future<void> _loadModel() async {
    const String modelPath = "assets/models/efficientnet_lite0_quantized.pt";
    try {
      _classificationModel = await PytorchLite.loadClassificationModel(
        modelPath,
        224, 224, // Note: newer versions of the package also take the number of classes here.
        labelPath: "assets/labels/model.txt",
      );
    } on PlatformException {
      print("Model loading is only supported on Android.");
    }
  }

  Future<void> _initializeCamera() async {
    try {
      final cameras = await availableCameras();
      _cameraController = CameraController(
        cameras.first,
        ResolutionPreset.high,
        enableAudio: false,
      );
      await _cameraController?.initialize();
    } catch (e) {
      _setProcessingState('Error initializing camera: $e');
    }
  }

  Future<void> _takePictureAndProcess() async {
    final pickedFile = await _pickImage();
    if (pickedFile == null) {
      _setProcessingState('No image selected.');
      return;
    }

    try {
      final imageBytes = await File(pickedFile.path).readAsBytes();
      _classificationResult = await _getClassificationResult(imageBytes);
      print(_classificationResult);
      _navigateToRecycleScreen();
    } catch (e) {
      _setProcessingState('Error during processing: $e');
    } finally {
      // Clear the busy flag without wiping the classification result.
      setState(() => _isProcessing = false);
    }
  }

  Future<XFile?> _pickImage() async {
    final picker = ImagePicker();
    return await picker.pickImage(source: ImageSource.camera);
  }

  Future<String> _getClassificationResult(Uint8List imageBytes) async {
    // Define normalization parameters
    final mean = [0.485, 0.456, 0.406];
    final std = [0.229, 0.224, 0.225];

    // Get prediction with normalization
    final result = await _classificationModel!.getImagePrediction(imageBytes, mean: mean, std: std);

    print(result);
    return result;
  }

  void _navigateToRecycleScreen() {
    Navigator.push(
      context,
      MaterialPageRoute(
        builder: (context) => Recycle(result: _classificationResult),
      ),
    );
  }

  void _setProcessingState(String message) {
    setState(() {
      _isProcessing = false;
      _classificationResult = message;
    });
  }

  @override
  Widget build(BuildContext context) {
    return _isProcessing || _cameraController == null || !_cameraController!.value.isInitialized
        ? const LoadingPage()
        : Scaffold(
      body: Center(child: CameraPreview(_cameraController!)),
    );
  }

  @override
  void dispose() {
    _cameraController?.dispose();
    super.dispose();
  }
}

I also tested several images with the model in Colab; the results are in the attached screenshots.

ZaZra03 commented 1 day ago

This is the result from the Dart code (screenshots attached). I used the camera package to capture the image, which I then pass to the model.

ZaZra03 commented 1 day ago

I have tried it several times in Dart, but it keeps predicting the image as paper.

abdelaziz-mahdy commented 1 day ago

Camera images need a different function; to use a camera image directly you can use those functions, and most probably that's the reason for the wrong classification.

I will try to provide you with fixed code if I get time.

abdelaziz-mahdy commented 1 day ago

Actually, after adding an example page with the same configuration, it works correctly, so I'm not sure where the problem is.

you can find my code here: https://github.com/abdelaziz-mahdy/pytorch_lite/blob/master/example/lib/run_model_by_image_picker_camera_demo.dart

abdelaziz-mahdy commented 1 day ago

Also, as I mentioned here, can you try these: https://github.com/abdelaziz-mahdy/pytorch_lite/issues/82#issuecomment-2408585385

And check this case: https://github.com/abdelaziz-mahdy/pytorch_lite/issues/82#issuecomment-2411429918. Maybe it's a new problem with PyTorch optimization?
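For context, the two export variants being compared in those #82 comments (tracing alone vs. tracing plus `optimize_for_mobile`) can be sketched as below. The model and output path are placeholders; the helper name `export` is mine, not part of either codebase.

```python
# Sketch of the two export paths under discussion: tracing alone vs.
# tracing plus optimize_for_mobile. Model and output path are placeholders.
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

def export(model, out_path: str, optimize: bool = True) -> None:
    """Trace the model and save it in the lite-interpreter format."""
    model.eval()
    example = torch.rand(1, 3, 224, 224)  # dummy input for tracing
    traced = torch.jit.trace(model, example)
    if optimize:
        # This is the step being skipped in the "skip optimize_for_mobile" test.
        traced = optimize_for_mobile(traced)
    traced._save_for_lite_interpreter(out_path)
```

Exporting the same model both ways and comparing results in the app is a quick way to tell whether the mobile optimization pass itself changes the predictions.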

ZaZra03 commented 23 hours ago

Hi, I tried skipping the optimized_traced_model, but unfortunately I get an error when I run it in the app. Also, how can I add preProcessingMethod to the getImagePredictionList method, since the version of PyTorch Lite I'm using doesn't have a parameter for it?

  ///predicts image but returns the raw net output
  Future<List<double?>?> getImagePredictionList(Uint8List imageAsBytes,
      {List<double> mean = torchVisionNormMeanRGB,
      List<double> std = torchVisionNormSTDRGB}) async {
    // Assert mean std
    assert(mean.length == 3, "Mean should have size of 3");
    assert(std.length == 3, "STD should have size of 3");
    final List<double?>? prediction = await ModelApi().getImagePredictionList(
        _index, imageAsBytes, null, null, null, mean, std);
    return prediction;
  }

I also have a problem with specifying the number of classes in loadClassificationModel:

  ///Sets pytorch model path and returns Model
  static Future<ClassificationModel> loadClassificationModel(
      String path, int imageWidth, int imageHeight,
      {String? labelPath}) async {
    String absPathModelPath = await _getAbsolutePath(path);
    int index = await ModelApi()
        .loadModel(absPathModelPath, null, imageWidth, imageHeight, null);
    List<String> labels = [];
    if (labelPath != null) {
      if (labelPath.endsWith(".txt")) {
        labels = await _getLabelsTxt(labelPath);
      } else {
        labels = await _getLabelsCsv(labelPath);
      }
    }

    return ClassificationModel(index, labels);
  }
abdelaziz-mahdy commented 22 hours ago

#82 (comment)

can you explain the problem with the labels?

also are you sure you are using the latest version?

abdelaziz-mahdy commented 22 hours ago

Also can you provide the error you faced?

ZaZra03 commented 21 hours ago

Sorry, my mistake. I'm using an old version of this library. I'm working on fixing the dependency issue in my code and will update you soon.

abdelaziz-mahdy commented 21 hours ago

Okay, let me know when you do.

ZaZra03 commented 2 hours ago

Hi @abdelaziz-mahdy, I updated the library to the latest version, but I'm still getting wrong results. I tried everything you suggested, like skipping optimize_for_mobile, but it makes no difference; I always get class 3 ("paper") as the result. I also added preProcessingMethod to getImagePredictionList, and even tried setting the mean and std arguments, but the result is the same. Here's my Dart code:

import 'dart:io';
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';
import 'package:pytorch_lite/pytorch_lite.dart';

class Scan extends StatefulWidget {
  const Scan({super.key});

  @override
  State<Scan> createState() => _ScanState();
}

class _ScanState extends State<Scan> {
  String? classificationResult;
  Duration? classificationInferenceTime;
  File? _image;
  ClassificationModel? _imageModel;
  bool _isLoading = false;

  @override
  void initState() {
    super.initState();
    loadModel();
  }

  Future<void> loadModel() async {
    const pathImageModel = "assets/models/best_model.pt";
    try {
      _imageModel = await PytorchLite.loadClassificationModel(
        pathImageModel, 224, 224, 6,
        labelPath: "assets/labels/model.txt",
      );
    } catch (e) {
      print("Error loading model: $e");
    }
  }

  Future<void> runModels(ImageSource source) async {
    setState(() => _isLoading = true);

    final ImagePicker picker = ImagePicker();
    final XFile? pickedImage = await picker.pickImage(source: source);
    if (pickedImage == null) {
      setState(() => _isLoading = false);
      return;
    }

    final File image = File(pickedImage.path);
    final Uint8List imageBytes = await image.readAsBytes();

    try {
      final stopwatch = Stopwatch()..start();
      List<double?>? predictionList = await _imageModel!.getImagePredictionList(
        imageBytes,
      );
      classificationInferenceTime = stopwatch.elapsed;
      print(predictionList);
      // Argmax over the raw scores (null-safe).
      int maxIndex = 0;
      for (int i = 1; i < predictionList!.length; i++) {
        if ((predictionList[i] ?? double.negativeInfinity) >
            (predictionList[maxIndex] ?? double.negativeInfinity)) {
          maxIndex = i;
        }
      }
      classificationResult = predictionList.isNotEmpty ? "Class $maxIndex" : "N/A";

      setState(() {
        _image = image;
      });
    } catch (e) {
      print("Error during classification: $e");
    } finally {
      setState(() => _isLoading = false);
    }
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: const Text('Run Models')),
      body: Center(
        child: _isLoading
            ? const CircularProgressIndicator()
            : SingleChildScrollView(
          child: Column(
            mainAxisAlignment: MainAxisAlignment.center,
            children: [
              if (_image != null) ...[
                const SizedBox(height: 20),
                Image.file(_image!),
                const SizedBox(height: 20),
                Text(
                  "Classification Result: ${classificationResult ?? "N/A"}",
                  style: const TextStyle(fontSize: 16),
                ),
                Text(
                  "Classification Time: ${classificationInferenceTime?.inMilliseconds ?? "N/A"} ms",
                  style: const TextStyle(fontSize: 16),
                ),
                const SizedBox(height: 20),
              ],
              ElevatedButton(
                onPressed: () => runModels(ImageSource.camera),
                child: const Text('Take Photo & Run Models'),
              ),
              ElevatedButton(
                onPressed: () => runModels(ImageSource.gallery),
                child: const Text('Pick from Gallery & Run Models'),
              ),
            ],
          ),
        ),
      ),
    );
  }
}
ZaZra03 commented 2 hours ago

And this is the result (screenshot attached): `I/flutter (28163): [0.2844568192958832, 0.2844568192958832, -0.1422284096479416, 0.7111420631408691, -0.8533704280853271, 0.2844568192958832]`
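For what it's worth, that printed vector is the raw per-class logits. A quick check in Python shows index 3 wins, which matches the "class 3 / paper" result reported above:

```python
# Interpret the raw logits the app printed: take the argmax and, for a
# probability view, a softmax. The values are copied from the log above.
import math

logits = [0.2844568192958832, 0.2844568192958832, -0.1422284096479416,
          0.7111420631408691, -0.8533704280853271, 0.2844568192958832]

# Index of the largest logit = the predicted class.
predicted = max(range(len(logits)), key=lambda i: logits[i])
print(predicted)  # -> 3

# Softmax turns the logits into class probabilities.
exps = [math.exp(x) for x in logits]
probs = [e / sum(exps) for e in exps]
print(probs[predicted])  # confidence of the winning class
```

Note the distribution is fairly flat (three classes share the same logit), so the model is not very confident; that can be a symptom of preprocessing that differs from training.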

This is my Colab notebook where I tested the model with similar images: https://colab.research.google.com/drive/1dQQMVErSO9wU2JhkVJXGvX4_TzNjIY9y?usp=sharing

ZaZra03 commented 2 hours ago

I tried it multiple times, but I keep getting paper as the result.

abdelaziz-mahdy commented 2 hours ago

Did you test with my example? It works correctly for me.

So either the model is not read correctly by PyTorch for a reason outside my control, or the model itself is not correct, or the image preprocessing differs between Python and the package (which doesn't look likely, since both use the same mean and std).

Without the model I have no way of pinpointing where the problem is. Keep in mind that every model is different, so the same fix doesn't apply to everything.

Try running the same mobile model in Python with the same image from the phone and check its result.