ramintoosi / cnn-rust

Train a CNN in Rust.

Questions about accuracy? #1

Open tharyckgusmao opened 16 hours ago

tharyckgusmao commented 16 hours ago

Hello, my friend, how are you?

I came across your project while browsing on GitHub and found it incredible. I've been studying Rust and machine learning, and I'm a novice in the world of artificial intelligence. I can see that you have extensive experience in this field. I ran the CNN model on an ants vs bees dataset, the hymenoptera_data:

https://www.kaggle.com/datasets/ajayrana/hymenoptera-data

My question is not about how the code works—I was able to understand your approach quite well. However, I can’t figure out why I’m not getting high accuracy on validation. Do you have any ideas on how I can achieve better accuracy?

Thank you so much in advance!

Screenshot from 2024-09-21 19-18-05

ramintoosi commented 7 hours ago

Hello! Thanks for reaching out and for your kind words about my project. I'm happy to hear that you're diving into Rust and machine learning—it's an exciting journey!

Regarding your question, if your training accuracy is hitting 100% but your validation accuracy is stuck at around 52%, it sounds like your model might be overfitting. Overfitting happens when a model learns the training data too well (including noise and minor details), but struggles to generalize to new, unseen data like your validation set.

Here are a few ideas to help improve your validation accuracy:

  1. Regularization: Try adding dropout layers in your model, which helps prevent overfitting by randomly dropping units during training. You can also consider L2 regularization (weight decay) to penalize large weights in your network.

  2. Data Augmentation: Since your dataset is small (as I checked on Kaggle), the model could be memorizing the training images rather than learning general features. Apply augmentation techniques like random rotations, flips, or crops to artificially enlarge your dataset and introduce more variability (check out this code).

  3. Simplify the Model: A complex model with too many parameters might overfit the small dataset. Try reducing the number of layers or parameters, and see if it generalizes better.

  4. Learning Rate: A learning rate that is too high can overshoot the optimal point, while one that is too low can leave the model stuck. Experiment with adjusting the learning rate to find a better fit.

  5. Pretrained Models: Since the dataset is relatively small, you might want to use transfer learning by fine-tuning a pretrained model like ResNet or VGG. These models have already learned features from large-scale datasets (like ImageNet), which can be beneficial for similar tasks. Check out this link.
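To make point 1 concrete, here is a minimal, framework-free sketch of inverted dropout and an L2 penalty in plain Rust. The function names and the tiny xorshift PRNG are my own illustrative choices, not part of this repo; in the actual training code you would use your framework's dropout layer and the optimizer's weight-decay option instead.

```rust
// Tiny xorshift PRNG so the example has no external dependencies.
fn xorshift(state: &mut u64) -> f64 {
    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    (*state >> 11) as f64 / (1u64 << 53) as f64
}

/// Inverted dropout: during training, zero each activation with
/// probability `p` and scale the survivors by 1/(1-p) so the layer's
/// expected output stays the same at inference time (when `train`
/// is false, the input passes through unchanged).
fn dropout(xs: &[f64], p: f64, train: bool, rng: &mut u64) -> Vec<f64> {
    if !train {
        return xs.to_vec();
    }
    xs.iter()
        .map(|&x| if xorshift(rng) < p { 0.0 } else { x / (1.0 - p) })
        .collect()
}

/// L2 penalty: lambda * sum(w^2), added to the loss so the optimizer
/// is pushed toward smaller weights (equivalent to weight decay for
/// plain SGD).
fn l2_penalty(weights: &[f64], lambda: f64) -> f64 {
    lambda * weights.iter().map(|w| w * w).sum::<f64>()
}

fn main() {
    let mut rng = 42u64;
    let acts = vec![1.0, 2.0, 3.0, 4.0];
    println!("train: {:?}", dropout(&acts, 0.5, true, &mut rng));
    println!("eval:  {:?}", dropout(&acts, 0.5, false, &mut rng));
    println!("l2:    {}", l2_penalty(&[0.5, -1.0], 1e-2)); // 0.0125
}
```

The key detail people miss is the 1/(1-p) rescaling: without it, the network sees systematically smaller activations at eval time than during training.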
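And here is point 2 as a framework-free sketch: a horizontal flip and a pad-then-crop on an image stored row-major as a flat `Vec<f64>` (H x W, single channel for simplicity). These helper names are hypothetical; with a tensor library you would do the same thing with built-in ops, and the training loop would pick the crop offsets at random per sample.

```rust
/// Mirror an H x W image left-to-right.
fn hflip(img: &[f64], h: usize, w: usize) -> Vec<f64> {
    let mut out = vec![0.0; h * w];
    for y in 0..h {
        for x in 0..w {
            out[y * w + x] = img[y * w + (w - 1 - x)];
        }
    }
    out
}

/// Zero-pad by `pad` pixels on every side, then cut an H x W window
/// whose top-left corner is at (dy, dx) in the padded image.
/// (dy, dx) each range over 0..=2*pad; training code samples them
/// randomly so the object shifts a little on every epoch.
fn pad_crop(img: &[f64], h: usize, w: usize, pad: usize, dy: usize, dx: usize) -> Vec<f64> {
    let pw = w + 2 * pad;
    let ph = h + 2 * pad;
    let mut padded = vec![0.0; ph * pw];
    for y in 0..h {
        for x in 0..w {
            padded[(y + pad) * pw + (x + pad)] = img[y * w + x];
        }
    }
    let mut out = vec![0.0; h * w];
    for y in 0..h {
        for x in 0..w {
            out[y * w + x] = padded[(y + dy) * pw + (x + dx)];
        }
    }
    out
}

fn main() {
    let img = vec![1.0, 2.0, 3.0, 4.0]; // a 2x2 image
    println!("{:?}", hflip(&img, 2, 2));             // [2.0, 1.0, 4.0, 3.0]
    println!("{:?}", pad_crop(&img, 2, 2, 1, 0, 0)); // [0.0, 0.0, 0.0, 1.0]
}
```

Flips and shifts are label-preserving for ants vs. bees, which is why they effectively multiply your ~250 training images without collecting new data.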
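On point 4, one common fix is not a single "better" constant, but a schedule. Here is a hypothetical step-decay helper (again plain Rust, no framework): the learning rate is multiplied by `gamma` every `step_size` epochs, so training starts aggressively and anneals as it converges.

```rust
/// Step decay: lr = base_lr * gamma^(epoch / step_size),
/// using integer division for the exponent.
fn step_lr(base_lr: f64, gamma: f64, step_size: usize, epoch: usize) -> f64 {
    base_lr * gamma.powi((epoch / step_size) as i32)
}

fn main() {
    // Base LR 0.1, decayed by 10x every 10 epochs.
    for &epoch in &[0, 9, 10, 25] {
        println!("epoch {}: lr = {}", epoch, step_lr(0.1, 0.1, 10, epoch));
    }
}
```

You would call this at the top of each epoch and hand the result to your optimizer; most frameworks ship an equivalent scheduler so you rarely need to write it yourself.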

Hope this helps! Feel free to share any more details or updates, and I’d be happy to dive deeper if needed.

Best of luck, and keep me posted on your progress!