Flower Species Image Classifier - Image classifier to recognize different species of flowers
This project is a deep learning-based image classifier designed to recognize and classify different species of flowers. It leverages pre-trained convolutional neural networks (CNNs) like DenseNet121 and VGG19, fine-tuned on a flower dataset, to achieve high accuracy in species identification. The project includes both training and prediction scripts, making it suitable for end-to-end machine learning workflows.
- Flexible Model Architecture: Supports both DenseNet121 and VGG19 as backbone models.
- Customizable Training: Adjustable hyperparameters such as learning rate, dropout, hidden units, and epochs.
- GPU Support: Utilizes CUDA for accelerated training and inference if available.
- Data Augmentation and Preprocessing: Applies resizing, cropping, rotation, and normalization to make training more robust.
- Prediction Script: Classifies single images and returns top-K probable classes with confidence scores.
- Model Saving and Loading: Saves trained models as checkpoints for future use.
Before running the project, ensure you have the following installed:
- Python 3.8+
- PyTorch 2.6.0 (with torchvision)
- Pillow (for image processing)
- NumPy
- argparse (Python standard library; parses command-line arguments)
- json (Python standard library; loads the label mapping)
All dependencies are listed in requirements.txt. Install them using:
pip install -r requirements.txt
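After installing, a quick sanity check can confirm the interpreter version and that the third-party packages are importable. This is a minimal stdlib-only sketch; the script and its `check_environment` helper are hypothetical and not part of the repository, and it does not verify the exact PyTorch version.

```python
import importlib.util
import sys

# Import names of the third-party prerequisites; Pillow is imported as "PIL".
REQUIRED_MODULES = ["torch", "torchvision", "PIL", "numpy"]


def check_environment(min_python=(3, 8)):
    """Return (python_ok, missing_modules) for the prerequisites listed above."""
    python_ok = sys.version_info[:2] >= min_python
    # find_spec returns None when a top-level module cannot be located,
    # without actually importing it.
    missing = [name for name in REQUIRED_MODULES
               if importlib.util.find_spec(name) is None]
    return python_ok, missing


if __name__ == "__main__":
    ok, missing = check_environment()
    print("Python version OK" if ok else "Python 3.8+ required")
    print("Missing modules:", ", ".join(missing) if missing else "none")
```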
- Clone the repository:
git clone https://github.com/CarlosYazid/Image-Classifier-Project.git
cd Image-Classifier-Project
- Install dependencies:
pip install -r requirements.txt
- Download the dataset (ensure it follows the structure below):
data_dir/
├── train/
├── valid/
└── test/
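Before training, it can help to confirm that the dataset matches the layout above. The sketch below is hypothetical (not part of the repository) and additionally assumes the common one-subfolder-per-class arrangement inside each split, as used by torchvision's `ImageFolder`; adjust it if the project organizes images differently.

```python
from pathlib import Path

# Split folders expected under data_dir, per the layout above.
EXPECTED_SPLITS = ("train", "valid", "test")


def check_dataset_layout(data_dir):
    """Map each split to its number of class subfolders, or None if the split is missing."""
    root = Path(data_dir)
    report = {}
    for split in EXPECTED_SPLITS:
        split_dir = root / split
        if not split_dir.is_dir():
            report[split] = None
        else:
            # Assumption: one subfolder per class inside each split.
            report[split] = sum(1 for entry in split_dir.iterdir() if entry.is_dir())
    return report
```

For example, `check_dataset_layout("./flowers")` returns a dict such as `{"train": ..., "valid": ..., "test": None}`, where `None` flags a missing split.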
Run the training script with optional arguments:
python train.py --data_dir ./flowers --save_dir ./models --arch densenet121 --gpu --epochs 20
Key Arguments:
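- --data_dir: Path to the dataset directory.
- --save_dir: Directory to save the trained model.
- --arch: Model architecture (densenet121 or vgg19).
- --learning_rate: Learning rate for the optimizer.
- --hidden_units: Number of units in the classifier's hidden layer.
- --gpu: Enable GPU acceleration.
- --epochs: Number of training epochs.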
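The flags above map naturally onto Python's `argparse`. The parser below is a hypothetical sketch of how train.py might declare them; the flag names come from this README, but every default value is a placeholder assumption rather than the script's actual default.

```python
import argparse


def build_train_parser():
    """Hypothetical parser mirroring the train.py flags documented above."""
    parser = argparse.ArgumentParser(description="Train a flower species classifier")
    # Defaults below are illustrative placeholders, not the project's real values.
    parser.add_argument("--data_dir", default="./flowers", help="Path to the dataset directory")
    parser.add_argument("--save_dir", default="./models", help="Directory to save the checkpoint")
    parser.add_argument("--arch", choices=["densenet121", "vgg19"], default="densenet121",
                        help="Backbone architecture")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Optimizer learning rate")
    parser.add_argument("--hidden_units", type=int, default=512, help="Hidden layer size")
    parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
    parser.add_argument("--gpu", action="store_true", help="Use CUDA if available")
    return parser
```

Using `choices` makes argparse reject an unsupported architecture with a usage error before any training starts.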
Use the trained model to classify an image:
python predict.py ./test_image.jpg ./models/checkpoint.pth --category_names cat_to_name.json --top_k 3 --gpu
Key Arguments:
- image_path: Path to the image to classify.
- save_path: Path to the trained model checkpoint.
- --category_names: JSON file mapping class indices to flower names.
- --top_k: Number of top predictions to display.
- --gpu: Use GPU for inference.
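In the usage line above, image_path and save_path are positional while the rest are optional flags. A hypothetical argparse sketch of that interface follows; the `positive_int` validator and the default `--top_k` value of 5 are assumptions for illustration, not taken from predict.py.

```python
import argparse


def positive_int(value):
    """argparse type that rejects zero or negative --top_k values."""
    number = int(value)  # a non-integer string raises ValueError, reported by argparse
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return number


def build_predict_parser():
    """Hypothetical parser mirroring the predict.py arguments documented above."""
    parser = argparse.ArgumentParser(description="Classify a flower image")
    parser.add_argument("image_path", help="Path to the image to classify")
    parser.add_argument("save_path", help="Path to the trained model checkpoint")
    parser.add_argument("--category_names", default=None,
                        help="JSON file mapping class indices to flower names")
    # Default of 5 is a placeholder assumption.
    parser.add_argument("--top_k", type=positive_int, default=5,
                        help="Number of top predictions to display")
    parser.add_argument("--gpu", action="store_true", help="Use GPU for inference")
    return parser
```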
python train.py --data_dir ./flowers --arch vgg19 --learning_rate 0.0005 --hidden_units 512 --epochs 15 --gpu
This trains a VGG19 model on the GPU for 15 epochs, with a learning rate of 0.0005 and a hidden layer of 512 units.
python predict.py ./rose.jpg ./models/checkpoint.pth --category_names cat_to_name.json --top_k 5
Output:
1. rose (98.7%)
2. hibiscus (0.8%)
3. tulip (0.3%)
4. sunflower (0.1%)
5. daisy (0.1%)
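The ranked output above comes from picking the K most probable classes and translating class indices into names via cat_to_name.json. The sketch below reproduces that step without PyTorch: `heapq.nlargest` stands in for what is presumably `torch.topk` in the real script, and the helper names are hypothetical. Note that JSON object keys are always strings, so class indices are looked up as strings.

```python
import heapq
import json


def load_category_names(path):
    """Load the class-index-to-name mapping (keys are strings in JSON)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def top_k_predictions(probabilities, k):
    """Return the k (class_id, probability) pairs with the highest probability."""
    return heapq.nlargest(k, probabilities.items(), key=lambda item: item[1])


def format_predictions(predictions, cat_to_name=None):
    """Render predictions as '1. rose (98.7%)' lines."""
    lines = []
    for rank, (class_id, prob) in enumerate(predictions, start=1):
        # Fall back to the raw class id when no name mapping is available.
        name = cat_to_name.get(class_id, class_id) if cat_to_name else class_id
        lines.append(f"{rank}. {name} ({prob:.1%})")
    return "\n".join(lines)
```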
Image-Classifier-Project/
├── .gitignore
├── predict.py # Script for making predictions
├── train.py # Script for training the model
├── cat_to_name.json # Mapping of class indices to flower names
└── requirements.txt # Python dependencies
train.py:
- train_model: Trains the model using the specified hyperparameters.
- accuracy: Evaluates model accuracy on the test set.
- save_model: Saves the trained model as a checkpoint.
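Test-set accuracy boils down to counting how often the highest-scoring class matches the true label, accumulated across batches. This framework-free sketch illustrates that computation; the real `accuracy` function presumably works on PyTorch tensors (for example via `outputs.argmax(dim=1) == labels`), which is an assumption, and the batch format here is hypothetical.

```python
def argmax(scores):
    """Index of the largest score in a list."""
    return max(range(len(scores)), key=scores.__getitem__)


def accuracy(batches):
    """Fraction of correct predictions over (scores_batch, labels) pairs."""
    correct = total = 0
    for scores_batch, labels in batches:
        for scores, label in zip(scores_batch, labels):
            correct += argmax(scores) == label  # True counts as 1
            total += 1
    return correct / total if total else 0.0
```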
predict.py:
- process_image: Preprocesses an image for the model.
- load_model: Loads a saved model checkpoint.
- predict: Predicts the class of an input image.
- display_prediction: Prints the top-K predictions.
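Preprocessing for torchvision's pretrained backbones typically resizes the shorter side to 256 pixels, center-crops to 224×224, and normalizes each channel with the ImageNet mean and standard deviation. The sketch below computes that geometry and normalization in plain Python. It assumes process_image follows this common recipe; the constants and helper names are assumptions, not read from predict.py.

```python
# Assumed constants: the usual ImageNet values for torchvision pretrained models.
RESIZE_SHORT_SIDE = 256
CROP_SIZE = 224
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def resized_size(width, height, short_side=RESIZE_SHORT_SIDE):
    """Scale so the shorter side equals short_side, keeping the aspect ratio."""
    if width <= height:
        return short_side, round(height * short_side / width)
    return round(width * short_side / height), short_side


def center_crop_box(width, height, crop=CROP_SIZE):
    """(left, upper, right, lower) box, the tuple format PIL's Image.crop expects."""
    left = (width - crop) // 2
    upper = (height - crop) // 2
    return left, upper, left + crop, upper + crop


def normalize_pixel(rgb):
    """Map an 8-bit (R, G, B) pixel to per-channel normalized floats."""
    return tuple((value / 255 - m) / s for value, m, s in zip(rgb, MEAN, STD))
```

With Pillow, this would look like `image = image.resize(resized_size(*image.size))` followed by `image = image.crop(center_crop_box(*image.size))`; PyTorch models also expect channels first, so the normalized array is usually transposed to (C, H, W).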
- Fork the repository.
- Create a new branch for your feature or bugfix.
- Submit a pull request with a clear description of changes.
- Ensure all tests pass and adhere to the project's coding standards.
- CUDA Out of Memory: Reduce batch size or use a smaller model.
- Invalid Image Path: Ensure the path is correct and the image exists.
- JSON Decode Error: Verify that cat_to_name.json is valid JSON.
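To locate a JSON decode error, the standard library's `json.JSONDecodeError` reports the line and column where parsing failed. A minimal hypothetical checker follows; it also flags files that parse but are not a JSON object, since the label file is expected to map class indices to names.

```python
import json


def check_category_file(path):
    """Return None if path holds a valid JSON object, else a readable error message."""
    try:
        with open(path, encoding="utf-8") as f:
            mapping = json.load(f)
    except json.JSONDecodeError as err:
        # lineno/colno point at the first character the parser could not accept.
        return f"{path}: line {err.lineno}, column {err.colno}: {err.msg}"
    if not isinstance(mapping, dict):
        return f"{path}: expected a JSON object mapping class indices to names"
    return None
```

A common culprit is a trailing comma after the last entry, which JSON does not allow.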
- Added training and prediction scripts.
- Supported DenseNet121 and VGG19 architectures.
- Included GPU acceleration.
This project is licensed under the MIT License. See LICENSE for details.
For questions or support, please contact:
- Project Maintainer: Carlos Yazid
- Email: [email protected]
- GitHub Issues: Issues