# segment

**Repository Path**: damone/segment

## Basic Information

- **Project Name**: segment
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: Apache-2.0
- **Default Branch**: main
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2025-10-10
- **Last Updated**: 2025-10-11

## Categories & Tags

**Categories**: Uncategorized
**Tags**: None

## README

# ADE20K Semantic Segmentation Experiment

A semantic segmentation experiment based on segmentation-models-pytorch and the PyTorch Lightning framework, using the ADEChallengeData2016 dataset.

## Project Structure

```
segment/
├── config/                  # Configuration files for different models
│   ├── dataset.py           # Dataset configuration
│   ├── segformer.py         # Segformer model configuration
│   ├── unet.py              # Unet model configuration
│   ├── fpn.py               # FPN model configuration
│   ├── linknet.py           # Linknet model configuration
│   └── pspnet.py            # PSPNet model configuration
├── config_loader.py         # Configuration loader
├── dataset.py               # Dataset and data loader
├── model.py                 # PyTorch Lightning model definition
├── train.py                 # Training script with PyTorch Lightning
├── evaluate.py              # Evaluation script
├── predict.py               # Prediction script
├── requirements.txt         # Dependencies
├── checkpoints_segformer/   # Model checkpoints
├── logs_segformer/          # Training logs
└── README.md                # Project documentation
```

## Environment Requirements

- Python 3.8+
- PyTorch 1.9+
- PyTorch Lightning 2.0+
- CUDA (recommended for GPU acceleration)
- segmentation-models-pytorch

## Install Dependencies

```bash
pip install -r requirements.txt
```

## Dataset Preparation

### Available Datasets

#### 1. ADE20K Dataset (Large Scale)

The ADEChallengeData2016 dataset is downloaded and prepared automatically. It contains:

- 20,210 training images
- 2,000 validation images
- 150 semantic classes

#### 2. CamVid Dataset (Small Real Dataset - Perfect for Validation)

CamVid is a small real dataset with 701 images, ideal for validating the complete workflow:

- 701 images in total (road-scene segmentation)
- 32 semantic classes
- Small file size and fast processing
- Perfect for testing and validation workflows

To use the CamVid dataset:

```bash
# Download and prepare the CamVid dataset
python download_camvid.py

# Validate the dataset setup
python test_camvid.py
```

## Usage

### 1. Train Model

**Basic training:**

```bash
python train.py --config config/segformer.py
```

**Training with a specific number of epochs:**

```bash
python train.py --config config/segformer.py --epochs 50
```

**Fast development run (for testing):**

```bash
python train.py --config config/segformer.py --fast_dev_run
```

**Resume training from a checkpoint:**

```bash
python train.py --config config/segformer.py --resume checkpoints_segformer/last.ckpt
```

### 2. Evaluate Model

```bash
python evaluate.py --config config/segformer.py
```

Evaluation metrics include (see the sketch after this list):

- Pixel Accuracy
- Mean IoU (Intersection over Union)
- Per-class IoU scores
- Confusion matrix
- Classification report
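For reference, pixel accuracy, per-class IoU, and mean IoU can all be derived from an accumulated confusion matrix. The sketch below is illustrative only and is not the project's `evaluate.py`; the class count of 150 matches the ADE20K class count listed above, while the ignore index of 255 is an assumption.

```python
import numpy as np

def confusion_matrix(pred, target, num_classes=150, ignore_index=255):
    """Accumulate a num_classes x num_classes confusion matrix from label maps."""
    mask = target != ignore_index                      # drop ignored pixels
    idx = num_classes * target[mask].astype(int) + pred[mask].astype(int)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def segmentation_metrics(cm):
    """Pixel accuracy, per-class IoU, and mean IoU from a confusion matrix."""
    tp = np.diag(cm)
    pixel_acc = tp.sum() / cm.sum()
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    iou = np.where(union > 0, tp / np.maximum(union, 1), np.nan)  # NaN for absent classes
    return pixel_acc, iou, np.nanmean(iou)
```

Classes that never appear in the evaluation set are masked with NaN so they do not drag down the mean IoU.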
### 3. Model Prediction

**Single image prediction:**

```bash
python predict.py --config config/segformer.py --image path/to/image.jpg --output predictions/
```

**Batch image prediction:**

```bash
python predict.py --config config/segformer.py --image_dir path/to/images --output predictions/
```

## Configuration Options

### Available Model Configurations

- `config/segformer.py` - Segformer transformer-based model
- `config/unet.py` - Classic U-Net architecture
- `config/fpn.py` - Feature Pyramid Network
- `config/linknet.py` - Lightweight Linknet
- `config/pspnet.py` - Pyramid Scene Parsing Network
- `config/camvid.py` - **CamVid dataset configurations** (UNet and Segformer optimized for the small dataset)

### Training Configuration (in config files)

```python
MODEL_NAME = "Segformer"     # Model architecture
ENCODER_NAME = "mit_b0"      # Encoder
NUM_CLASSES = 150            # Number of classes
BATCH_SIZE = 8               # Batch size
NUM_EPOCHS = 100             # Number of training epochs
LEARNING_RATE = 0.001        # Learning rate
IMAGE_SIZE = (512, 512)      # Input image size
```

### Data Path Configuration

```python
DATA_DIR = "/mnt/usb-data/project/segment/data"
CHECKPOINT_DIR = "/mnt/usb-data/project/segment/checkpoints_segformer"
LOG_DIR = "/mnt/usb-data/project/segment/logs_segformer"
```

## Supported Model Architectures

- **Segformer**: Transformer-based architecture with an efficient design
- **Unet**: Classic encoder-decoder architecture
- **Linknet**: Lightweight architecture with high computational efficiency
- **FPN**: Feature Pyramid Network
- **PSPNet**: Pyramid Scene Parsing Network

## Supported Encoders

- **Segformer encoders**: mit_b0, mit_b1, mit_b2, mit_b3, mit_b4, mit_b5
- **ResNet**: resnet18, resnet34, resnet50, resnet101, resnet152
- **EfficientNet**: efficientnet-b0 to efficientnet-b7
- More encoders can be found in the segmentation-models-pytorch documentation

## PyTorch Lightning Features

This project now uses PyTorch Lightning for training, providing:

### Enhanced Training Progress

- Real-time loss monitoring every 10 batches
- Detailed epoch summaries with training/validation loss
- Loss ratio tracking to detect overfitting
- Progress bar with frequent updates

### Automatic Features

- Mixed precision training support
- Multi-GPU training
- Automatic checkpointing
- Early stopping
- Learning rate scheduling
- Gradient accumulation

### Logging and Monitoring

- CSV logging of training metrics
- Model checkpointing with best-model selection
- Validation interval control
- Training visualization

## Experimental Results

After training completes, the following files are generated:

### Checkpoints Directory

- Best model checkpoint (`epoch=XX-val_loss=X.XXXX.ckpt`)
- Last model checkpoint (`last.ckpt`)
- Training progress checkpoints

### Logs Directory

- Training metrics in CSV format
- Training loss plots (`training_plot.png`)
- Detailed training statistics

### Output Directory (Predictions)

- Predicted masks for input images
- Visualization images comparing original and predicted masks

## File Description

### config_loader.py

Configuration loader that manages the different model configurations and creates the necessary directories.

### dataset.py

Defines the ADE20K dataset class and data loaders, including data preprocessing and data augmentation; a sketch of such a dataset class follows below.
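As an illustration only (not the project's actual `dataset.py`), a minimal ADE20K-style dataset class might look like the following. The `images/<split>` and `annotations/<split>` layout matches the standard ADEChallengeData2016 archive; the label shift and the ignore index of 255 are common conventions and are assumptions here, as is the fixed resize in place of the project's augmentation pipeline.

```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class ADE20KDataset(Dataset):
    """Minimal ADEChallengeData2016-style dataset: RGB images + integer label masks."""

    def __init__(self, root, split="training", image_size=(512, 512)):
        self.image_dir = os.path.join(root, "images", split)
        self.mask_dir = os.path.join(root, "annotations", split)
        self.files = sorted(f for f in os.listdir(self.image_dir) if f.endswith(".jpg"))
        self.image_size = image_size

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]
        image = Image.open(os.path.join(self.image_dir, name)).convert("RGB")
        mask = Image.open(os.path.join(self.mask_dir, name.replace(".jpg", ".png")))

        image = image.resize(self.image_size, Image.BILINEAR)
        mask = mask.resize(self.image_size, Image.NEAREST)   # nearest keeps labels intact

        image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
        # ADE20K stores labels 1..150 with 0 = "unlabeled"; shift to 0..149, 255 = ignore
        mask = np.array(mask, dtype=np.int64) - 1
        mask[mask < 0] = 255
        return image, torch.from_numpy(mask)
```

Wrapping this class in a `torch.utils.data.DataLoader` with the configured `BATCH_SIZE` yields batches ready for the Lightning model described next.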
### model.py

**PyTorch Lightning model** - defines the `SegmentationLightningModel` class with:

- Forward pass implementation
- Training, validation, and test steps
- Loss calculation and metric computation
- Optimizer and scheduler configuration

### train.py

**Enhanced training script** with:

- PyTorch Lightning trainer setup
- Custom loss-progress callback for detailed monitoring
- Model checkpointing and early stopping
- Training result visualization

### evaluate.py

Model evaluation script with:

- Automatic checkpoint loading
- Detailed metric calculation (accuracy, IoU, confusion matrix)
- Prediction visualization
- Performance analysis

### predict.py

Prediction script supporting:

- Single-image and batch prediction
- Automatic image preprocessing
- Prediction visualization and saving

## New Features

### Real-time Loss Monitoring

During training you will see detailed progress information:

```
=== Epoch 1/100 ===
Batch 0/50 - Loss: 2.3456
Batch 10/50 - Loss: 1.2345
Training Loss: 1.1234
Validation Loss: 1.2345

Epoch Summary:
  Train Loss: 1.1234
  Val Loss: 1.2345
  Loss Ratio (Val/Train): 1.10
  Improvement: -0.0123 (better)
```

### Enhanced Checkpoint Management

- Automatic best-model selection
- Resume training from any checkpoint (see the sketch at the end of this README)
- Multiple checkpoint-saving strategies

## Notes

1. **Dataset**: the ADE20K dataset (approximately 1.2 GB) is downloaded automatically on first run
2. **Storage**: ensure sufficient disk space for the dataset, checkpoints, and logs
3. **GPU**: recommended for training to achieve better performance
4. **Memory**: adjust `BATCH_SIZE` or `IMAGE_SIZE` to fit memory constraints

## Troubleshooting

### Common Issues

1. **Insufficient memory**: reduce `BATCH_SIZE` or `IMAGE_SIZE`
2. **Dataset download failure**: check the network connection
3. **CUDA memory error**: reduce the batch size or train on CPU
4. **Checkpoint not found**: ensure training has completed successfully

### Manual Dataset Download

If the automatic download fails:

1. Download http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
2. Extract it to the `data/` directory
3. Reorganize the file structure to match the configured paths

## License

This project is released under the MIT License. The ADE20K dataset follows its original license.

## Recent Updates

- **Migrated to PyTorch Lightning** for a more robust training framework
- **Added real-time loss monitoring** with a custom callback
- **Enhanced evaluation and prediction** scripts
- **Improved configuration management** with multi-model support
- **Better checkpoint handling** and resume-training capability
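As a reference for the resume capability mentioned above, resuming with PyTorch Lightning 2.x generally works by passing `ckpt_path` to `Trainer.fit`, which is presumably what `train.py --resume` does under the hood. The sketch below is a minimal, hedged example: the model and data loaders are taken as parameters because the project's actual constructor and loader helpers are not shown here, and the trainer options mirror the config values above rather than the project's real settings.

```python
import pytorch_lightning as pl

def fit_with_resume(model: pl.LightningModule, train_loader, val_loader,
                    ckpt_path="checkpoints_segformer/last.ckpt"):
    """Resume a Lightning run: ckpt_path restores weights, optimizer state, and the epoch counter."""
    trainer = pl.Trainer(
        max_epochs=100,                  # mirrors NUM_EPOCHS in the config example above
        accelerator="auto",              # GPU if available, otherwise CPU
        devices=1,
        default_root_dir="logs_segformer",
    )
    trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
```

Passing the path of a best-model checkpoint (e.g. `epoch=XX-val_loss=X.XXXX.ckpt`) instead of `last.ckpt` works the same way.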