# VAE-Pytorch

**Repository Path**: chich2007/VAE-Pytorch

## Basic Information

- **Project Name**: VAE-Pytorch
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: MIT
- **Default Branch**: main
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2024-09-08
- **Last Updated**: 2025-02-08

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

VAE Implementation in PyTorch with Visualizations
========

This repository implements a simple VAE that can be trained on a CPU on the MNIST dataset. It provides the ability to visualize the latent space and the entire manifold, as well as how digits interpolate between each other.

The purpose of this project is to build a better understanding of VAEs by experimenting with the different parameters and visualizations.

## VAE Tutorial Videos

VAE Understanding

Implementing VAE

## Architecture

# Quickstart

* Create a new conda environment with Python 3.8, then run the commands below
* ```git clone https://github.com/explainingai-code/Pytorch-VAE.git```
* ```cd Pytorch-VAE```
* ```pip install -r requirements.txt```
* To run a simple fully connected (fc) layer VAE with latent dimension 2, run ```python run_simple_vae.py```
* To experiment with the VAE and run the visualizations, change the config argument in tools/train_vae.py and tools/inference.py to the desired config file, or pass it to the commands below
* ```python -m tools.train_vae```
* ```python -m tools.inference```

## Configurations

* ```config/vae_nokl.yaml``` - VAE with only reconstruction loss
* ```config/vae_kl.yaml``` - VAE with reconstruction and KL loss
* ```config/vae_kl_latent4.yaml``` - VAE with reconstruction and KL loss, with latent dimension 4 (instead of 2)
* ```config/vae_kl_latent4_enc_channel_dec_fc_condition.yaml``` - Conditional VAE with reconstruction and KL loss, with latent dimension 4

## Data preparation

We don't use the torchvision MNIST dataset, so that it can be replaced with any other image dataset.

For setting up the dataset:
* Create ```data/train/images``` and ```data/test/images``` folders
* Download the MNIST CSV files (https://www.kaggle.com/datasets/oddrationale/mnist-in-csv) and save them under the ```data``` directory
* Run ```python utils/extract_mnist_images.py```

Verify that the data directory has the following structure:
```
Pytorch-VAE/data/train/images/{0/1/.../9}
    *.png
Pytorch-VAE/data/test/images/{0/1/.../9}
    *.png
```

## Output

Outputs are saved according to the configuration in the yaml files. For every run, a folder named after the ```task_name``` key in the config is created, and ```output_train_dir``` is created inside it.

During training, the following outputs are saved:
* Best model checkpoints in the ```task_name``` directory
* PCA information in a pickle file in the ```task_name``` directory
* A 2D latent space plot of the test set images for each epoch in the ```task_name/output_train_dir``` directory

During inference, the following outputs are saved:
* Reconstructions for a sample of the test set in ```task_name/output_train_dir/reconstruction.png```
* Decoder outputs for a set of points evenly spaced across the 2D projection of the latent space in ```task_name/output_train_dir/manifold.png```
* Interpolation between two randomly sampled points in the ```task_name/output_train_dir/interp``` directory

## Sample Output for VAE

Latent Visualization

Manifold

Reconstruction Images (reconstruction in black font and original in white font)

## Sample Output for Conditional VAE

Because the label is also passed to the decoder, the model learns to generate ALL digits from every point in the latent space. Points in the latent space are instead distinguished by style, for example whether the digit should be tilted left or right, or how thick its stroke should be. The images below visualize these patterns when we attempt to generate all digits from all points.

Reconstruction Images (reconstruction in black font and original in white font)
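
As a rough illustration of the label conditioning described for the conditional VAE, here is a minimal NumPy sketch of a conditional decoder. It assumes the digit label is one-hot encoded and concatenated to the latent vector before decoding. This is one common way to condition a VAE, but it is not taken from this repository, whose conditional config may feed the label in differently. The single-layer decoder weights (`W_dec`, `b_dec`) and all helper names are hypothetical and untrained, and the encoder is omitted. `reparameterize` and `kl_divergence` show the standard sampling step and the KL term that accompany the reconstruction loss.

```python
import numpy as np

NUM_CLASSES = 10        # MNIST digits 0-9
LATENT_DIM = 4          # same latent dimension as the conditional config
IMG_PIXELS = 28 * 28    # flattened MNIST image

rng = np.random.default_rng(0)

# Hypothetical, untrained single-layer decoder: input is [z, one_hot(label)].
W_dec = rng.normal(scale=0.1, size=(LATENT_DIM + NUM_CLASSES, IMG_PIXELS))
b_dec = np.zeros(IMG_PIXELS)


def one_hot(labels, num_classes=NUM_CLASSES):
    """Map integer labels of shape (N,) to one-hot rows of shape (N, num_classes)."""
    return np.eye(num_classes)[labels]


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I), where sigma = exp(log_var / 2)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps


def kl_divergence(mu, log_var):
    """KL(N(mu, sigma^2) || N(0, I)) per sample, summed over latent dimensions."""
    return -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var), axis=1)


def decode(z, labels):
    """Assumption: condition the decoder by concatenating the one-hot label to z."""
    h = np.concatenate([z, one_hot(labels)], axis=1)
    logits = h @ W_dec + b_dec
    return 1.0 / (1.0 + np.exp(-logits))  # sigmoid -> pixel intensities in (0, 1)


# Generate every digit from the same latent point: only the label changes.
z = np.zeros((NUM_CLASSES, LATENT_DIM))   # one latent point, repeated 10 times
digits = np.arange(NUM_CLASSES)
images = decode(z, digits).reshape(NUM_CLASSES, 28, 28)
```

Because the label is supplied separately, the same `z` produces a different output for each label. In a trained model, moving `z` while keeping the label fixed would change style attributes such as tilt and stroke thickness. With the random weights above the outputs are noise, so the sketch only demonstrates the data flow and shapes.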