This repository contains a from-scratch PyTorch implementation of the Vision Transformer (ViT) described in the research paper "An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale".
The project is organized into separate modules for better readability and maintainability, following best practices.
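With the ViT-Base-style settings used in the example below (a 224x224 RGB input and 16x16 patches), the image is split into (224 / 16)^2 = 196 patches, and each flattened patch holds 16 * 16 * 3 = 768 values before being projected to the embedding dimension.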
vision_transformer_project/
│
├── vision_transformer/
│ ├── __init__.py
│ ├── vision_transformer.py
│ ├── mlp_head.py
│ ├── transformer_encoder.py
│ ├── layer_norm.py
│ └── utils.py
│
├── main.py
└── requirements.txt
The package is split into the following modules:
- Vision Transformer: The main Vision Transformer class.
- MLP Head: The Multi-Layer Perceptron head for classification.
- Transformer Encoder: The Transformer Encoder layer.
- Layer Normalization: The layer normalization used inside the Transformer blocks.
- Utils: Utility functions for image processing and patch embedding.
- Clone the repository:

      git clone https://github.com/SYED-M-HUSSAIN/Implement-ViT-from-Scratch.git
      cd Implement-ViT-from-Scratch

- Create a virtual environment and activate it:

      python -m venv venv
      source venv/bin/activate  # On Windows, use `venv\Scripts\activate`

- Install the required packages:

      pip install -r requirements.txt

- Place your image file (e.g., `Image.png`) in the project directory.

- Run the main script:

      python main.py
The Vision Transformer class is defined in `vision_transformer/vision_transformer.py`. It includes methods for patch embedding, positional encoding, and the transformer encoder layers.
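For orientation, a minimal sketch of that forward-pass structure (prepend a CLS token, add positional embeddings, run the encoder stack, classify from the CLS token) built from stock PyTorch layers might look like the following. The class name `MiniViT` and its constructor arguments are illustrative only and do not mirror the repository's API:

    import torch
    import torch.nn as nn

    class MiniViT(nn.Module):
        """Illustrative ViT skeleton; not the repository's VisionTransformer class."""
        def __init__(self, num_patches=196, embedding_dim=768, num_layers=12,
                     num_heads=12, hidden_dim=3072, num_classes=10, dropout_prob=0.1):
            super().__init__()
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
            self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embedding_dim))
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=embedding_dim, nhead=num_heads, dim_feedforward=hidden_dim,
                dropout=dropout_prob, batch_first=True)
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.head = nn.Linear(embedding_dim, num_classes)

        def forward(self, patch_embeddings):                # (batch, num_patches, embedding_dim)
            batch = patch_embeddings.shape[0]
            cls = self.cls_token.expand(batch, -1, -1)      # one CLS token per image
            x = torch.cat([cls, patch_embeddings], dim=1)   # prepend the CLS token
            x = x + self.pos_embedding                      # add learned positional encoding
            x = self.encoder(x)                             # run the transformer encoder stack
            return self.head(x[:, 0])                       # classify from the CLS token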
The MLP Head class is defined in `vision_transformer/mlp_head.py`. It is used for the final classification task.
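The exact head architecture lives in the repository; as an assumed sketch, a classification head of this kind often amounts to a layer norm followed by a linear projection from the CLS embedding to class logits:

    import torch.nn as nn

    class MLPHeadSketch(nn.Module):
        """Illustrative classification head; the repository's mlp_head.py may differ."""
        def __init__(self, embedding_dim=768, num_classes=10):
            super().__init__()
            self.norm = nn.LayerNorm(embedding_dim)
            self.fc = nn.Linear(embedding_dim, num_classes)

        def forward(self, cls_embedding):       # (batch, embedding_dim), taken from the CLS token
            return self.fc(self.norm(cls_embedding))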
The Transformer Encoder class is defined in `vision_transformer/transformer_encoder.py`. It includes multi-head attention and feed-forward layers.
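As an illustration of what such a block typically contains, here is a sketch of a pre-norm encoder block using PyTorch's `nn.MultiheadAttention`; the repository's own encoder may order the normalization and residual connections differently:

    import torch.nn as nn

    class EncoderBlockSketch(nn.Module):
        """Illustrative pre-norm encoder block: self-attention + MLP, each with a residual."""
        def __init__(self, embedding_dim=768, num_heads=12, hidden_dim=3072, dropout_prob=0.1):
            super().__init__()
            self.norm1 = nn.LayerNorm(embedding_dim)
            self.attn = nn.MultiheadAttention(embedding_dim, num_heads,
                                              dropout=dropout_prob, batch_first=True)
            self.norm2 = nn.LayerNorm(embedding_dim)
            self.mlp = nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout_prob),
                nn.Linear(hidden_dim, embedding_dim),
                nn.Dropout(dropout_prob),
            )

        def forward(self, x):                               # (batch, tokens, embedding_dim)
            h = self.norm1(x)
            attn_out, _ = self.attn(h, h, h, need_weights=False)
            x = x + attn_out                                # residual around attention
            x = x + self.mlp(self.norm2(x))                 # residual around the feed-forward MLP
            return x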
The Transformer layer normalization class is defined in `vision_transformer/layer_norm.py`.
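For reference, layer normalization can be written from scratch in a few lines: normalize each token's feature vector to zero mean and unit variance, then apply learnable scale and shift parameters. The sketch below is illustrative and not necessarily how `layer_norm.py` is written:

    import torch
    import torch.nn as nn

    class LayerNormSketch(nn.Module):
        """Illustrative layer normalization over the last (feature) dimension."""
        def __init__(self, embedding_dim=768, eps=1e-6):
            super().__init__()
            self.gamma = nn.Parameter(torch.ones(embedding_dim))   # learnable scale
            self.beta = nn.Parameter(torch.zeros(embedding_dim))   # learnable shift
            self.eps = eps

        def forward(self, x):                                      # (..., embedding_dim)
            mean = x.mean(dim=-1, keepdim=True)
            var = x.var(dim=-1, keepdim=True, unbiased=False)
            return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta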
Utility functions for image processing and patch embedding are defined in `vision_transformer/utils.py`.
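The signatures of `image_to_patches` and `get_patch_embeddings` in the usage example suggest a two-step pipeline: cut the image into flattened patches, then linearly project them to the embedding dimension. A sketch of that idea (with hypothetical `_sketch` helper names, not the repository's functions) could look like this:

    import torch.nn as nn
    import torchvision.transforms as T
    from PIL import Image

    def image_to_patches_sketch(image_path, patch_size=16, image_size=224):
        """Illustrative patching: resize the image and cut it into flattened patches."""
        transform = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])
        image = transform(Image.open(image_path).convert("RGB"))         # (C, H, W)
        channels = image.shape[0]
        patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
        patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, channels * patch_size * patch_size)
        return patches.unsqueeze(0)                                      # (1, num_patches, patch_dim)

    def get_patch_embeddings_sketch(patches, embedding_dim=768):
        """Illustrative embedding: linearly project flattened patches to the model dimension.
        In a real model this projection is a learned layer inside the network."""
        projection = nn.Linear(patches.shape[-1], embedding_dim)
        return projection(patches)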
Here is an example of how to use the Vision Transformer:
    import torch
    import torch.nn.functional as F

    from vision_transformer import VisionTransformer, image_to_patches, get_patch_embeddings


    def main():
        # Model and input hyperparameters (ViT-Base configuration)
        IMAGE_SIZE = 224
        CHANNEL_SIZE = 3
        NUM_CLASSES = 10
        DROPOUT_PROB = 0.1
        NUM_LAYERS = 12
        EMBEDDING_DIM = 768
        NUM_HEADS = 12
        HIDDEN_DIM = 3072
        PATCH_SIZE = 16
        IMAGE_PATH = 'Image.png'

        # Split the image into patches and project them to the embedding dimension
        image_patches = image_to_patches(IMAGE_PATH, PATCH_SIZE)
        patch_embeddings = get_patch_embeddings(image_patches, EMBEDDING_DIM)

        # Build the Vision Transformer
        vision_transformer = VisionTransformer(
            patch_size=PATCH_SIZE,
            image_size=IMAGE_SIZE,
            channel_size=CHANNEL_SIZE,
            num_layers=NUM_LAYERS,
            embedding_dim=EMBEDDING_DIM,
            num_heads=NUM_HEADS,
            hidden_dim=HIDDEN_DIM,
            dropout_prob=DROPOUT_PROB,
            num_classes=NUM_CLASSES
        )

        # Forward pass: produces one logit per class
        vit_output = vision_transformer(patch_embeddings)
        print(vit_output.shape)

        # Convert logits to class probabilities (they sum to 1)
        probabilities = F.softmax(vit_output[0], dim=0)
        print(probabilities)
        print(torch.sum(probabilities))


    if __name__ == "__main__":
        main()
Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.
- The implementation is inspired by the Vision Transformer (ViT) paper by Dosovitskiy et al., published as a conference paper at ICLR 2021.