Week 4: Pretraining: March, MedMNISTv2, and everything in between
Hi everyone! Hope you all have had a great start to March! Just as March usually marks the beginning of flowers blooming, for this experiment it marks the start of model training!
To start, I’m excited to introduce you to MedMNISTv2, a groundbreaking development in the field of biomedical image analysis. MedMNISTv2 is a large-scale, MNIST-like collection of standardized biomedical images, offering a diverse range of datasets for both 2D and 3D images. These datasets have been pre-processed into 28×28 (2D) or 28×28×28 (3D) formats, accompanied by corresponding classification labels, making them incredibly accessible even for users with minimal background knowledge. This initiative aims to democratize research in biomedical image analysis by providing a standardized platform for classification tasks on lightweight images across various scales and tasks, including binary/multi-class, ordinal regression, and multi-label classification. With approximately 708K 2D images and 10K 3D images, MedMNISTv2 is poised to support numerous research and educational endeavors in biomedical image analysis, computer vision, and machine learning.
Furthermore, the recent release of MedMNIST+, featuring larger image sizes such as 64×64, 128×128, and 224×224 for 2D, and 64×64×64 for 3D, further enhances its utility as a standardized benchmark for developing medical foundation models. MedMNISTv2 is designed to be user-friendly, standardized, and educational, providing researchers from various disciplines with an accessible and comprehensive resource for advancing their work in biomedical image analysis. If you are curious to learn more about this dataset, check it out here: https://medmnist.com/
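For anyone who wants to poke at the data themselves, here is a rough sketch of what loading a 2D MedMNIST dataset with the official medmnist Python package tends to look like. The pathmnist flag and the normalization constants are just illustrative choices, not part of my actual pipeline:

```python
# pip install medmnist
import medmnist
from medmnist import INFO
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

data_flag = "pathmnist"  # illustrative choice; any 2D MedMNIST flag works
info = INFO[data_flag]   # metadata: task type, channel count, label names, etc.

# Convert images to tensors and normalize them to roughly [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Look up the dataset class (e.g., PathMNIST) and download the 28x28 split
DataClass = getattr(medmnist, info["python_class"])
train_dataset = DataClass(split="train", transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g., torch.Size([128, 3, 28, 28]) for PathMNIST
```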
Now, you might be wondering: Andrew, how is this related to the work you’ve been doing with stable diffusion? Well, this dataset usually serves as a sanity check or a baseline for training different models, especially in the medical field. In other words, this dataset is a starting point for training models, from simple to advanced, to identify medical images.
To make sure this dataset is functional for training models like my stable diffusion model, RoentGen, I used a simple convolutional neural network (CNN) containing 5 convolutional layers and 3 fully connected layers. More specifically, it is structured as follows (with a code sketch after the breakdown):
Convolutional Layers (5 layers):
- Layer 1:
- Convolution: 16 output channels, kernel size 3×3.
- Layer 2:
- Convolution: 16 input, 16 output channels, kernel size 3×3.
- Batch Normalization
- ReLU Activation
- Max Pooling: kernel size 2×2, stride 2.
- Layer 3:
- Convolution: 16 input, 64 output channels, kernel size 3×3.
- Batch Normalization
- ReLU Activation
- Layer 4:
- Convolution: 64 input, 64 output channels, kernel size 3×3.
- Batch Normalization
- ReLU Activation
- Layer 5:
- Convolution: 64 input, 64 output channels, kernel size 3×3, padding 1.
- Batch Normalization
- ReLU Activation
- Max Pooling: kernel size 2×2, stride 2.
Fully Connected Layers (3 layers):
- Linear: input size 64 * 4 * 4, output size 128.
- ReLU Activation
- Linear: input size 128, output size 128.
- ReLU Activation
- Linear: input size 128, output size num_classes.
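To make that layer list concrete, here is a minimal PyTorch sketch of the same architecture. I'm assuming single-channel 28×28 inputs and, purely for illustration, 9 output classes; the breakdown above lists only the convolution for Layer 1, so the sketch does the same:

```python
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    """The 5-conv / 3-linear network described above, assuming 28x28 inputs."""

    def __init__(self, in_channels=1, num_classes=9):
        super().__init__()
        # Layer 1: plain convolution, 16 output channels
        self.layer1 = nn.Conv2d(in_channels, 16, kernel_size=3)
        # Layer 2: conv + batch norm + ReLU + 2x2 max pooling
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3),
            nn.BatchNorm2d(16), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Layer 3: conv + batch norm + ReLU
        self.layer3 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=3),
            nn.BatchNorm2d(64), nn.ReLU(),
        )
        # Layer 4: conv + batch norm + ReLU
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3),
            nn.BatchNorm2d(64), nn.ReLU(),
        )
        # Layer 5: padded conv + batch norm + ReLU + 2x2 max pooling
        self.layer5 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Three fully connected layers; 64 * 4 * 4 is the flattened feature map
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 64 * 4 * 4)
        return self.fc(x)

model = SimpleCNN(in_channels=1, num_classes=9)
print(model(torch.randn(2, 1, 28, 28)).shape)  # torch.Size([2, 9])
```

The 64 * 4 * 4 input to the first linear layer falls out of the geometry: each unpadded 3×3 convolution trims 2 pixels, and each 2×2 pooling halves the feature map, so the 28×28 input has shrunk to 4×4 by the time it reaches the fully connected layers.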
Now, you might be wondering: what does any of that mean? Well, I’ve provided a table below that gives a better understanding of what each of those terms means.
| Term | Explanation |
| --- | --- |
| Convolutional Layer | A layer that learns to recognize patterns in pictures. |
| Batch Normalization | Helps make training faster and more stable. |
| ReLU Activation | Makes sure the network only picks out positive patterns. |
| Max Pooling | Helps shrink down the picture, focusing on important parts. |
| Fully Connected Layers | A layer that takes all the patterns and figures out what they mean together. |
| Bottleneck | A point where the network squeezes down the information. |
| Input Image Resolution | The size of the picture the network looks at. |
With this in mind, I used this CNN model and the MedMNISTv2 dataset to train and then test for accuracy metrics. With this method, I achieved results around 0.80 in AUC (Area Under the Curve, a metric that measures how well a model can distinguish between different categories) and 0.60 in ACC (the proportion of test examples the model classifies correctly), which is not bad for a simple model trained with limited resources and a small number of layers.
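As a rough illustration of where those numbers come from, an evaluation loop along these lines would produce them. This is a sketch rather than my exact script, and it assumes a multi-class task where the model outputs one score per class (for binary tasks you would pass only the positive-class column to roc_auc_score):

```python
import numpy as np
import torch
from sklearn.metrics import accuracy_score, roc_auc_score

@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Compute ACC and one-vs-rest AUC over a test loader."""
    model.eval()
    all_scores, all_labels = [], []
    for images, labels in loader:
        logits = model(images.to(device))
        all_scores.append(torch.softmax(logits, dim=1).cpu().numpy())
        all_labels.append(labels.numpy().flatten())  # MedMNIST labels are (N, 1)
    scores = np.concatenate(all_scores)
    labels = np.concatenate(all_labels)
    acc = accuracy_score(labels, scores.argmax(axis=1))
    auc = roc_auc_score(labels, scores, multi_class="ovr")
    return acc, auc
```

Worth noting: the medmnist package also ships its own standardized Evaluator class that computes these same AUC and ACC metrics, which helps keep results comparable across experiments.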
Next week, I plan to take this model up a notch to a residual network, which is one step closer to running the stable diffusion model on the MedMNISTv2 dataset. I also plan to address any questions if any of y’all are confused about what model training is. Until then, stay tuned!
Best,
Andrew