Koushim commited on
Commit
8ea8763
·
verified ·
1 Parent(s): 2208b4a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +82 -16
README.md CHANGED
@@ -3,35 +3,101 @@ license: mit
3
  tags:
4
  - medical
5
  ---
6
- # ViT Brain Tumor Classifier 🧠
7
 
8
- A ViT model trained on 75k brain MRI images using PyTorch + TIMM.
9
 
10
- ## Labels
11
- - Glioma
12
- - Meningioma
13
- - Tumor (general)
14
 
15
- ## Dataset
16
- [Brain Cancer MRI Dataset (2024)](https://www.kaggle.com/datasets/shuvokumarbasakbd/brain-cancer-mri-colorized-dataset)
17
 
18
- ## How to Use
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  ```python
21
- from PIL import Image
22
- from torchvision import transforms
23
  import torch
24
- import timm
 
25
 
26
- model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=3)
 
27
  model.load_state_dict(torch.load("pytorch_model.bin"))
28
  model.eval()
29
 
30
- img = Image.open("your_image.jpg")
31
  transform = transforms.Compose([
32
  transforms.Resize((224, 224)),
33
  transforms.ToTensor(),
34
  transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
35
  ])
36
- input_tensor = transform(img).unsqueeze(0)
37
- pred = model(input_tensor).argmax(dim=1)
 
 
 
 
 
 
3
  tags:
4
  - medical
5
  ---
6
+ # 🧠 Brain Tumor Classification Using Vision Transformer (ViT)
7
 
8
+ This repository contains a fine-tuned **Vision Transformer (ViT)** model trained on a large collection of MRI scans for brain tumor classification. The model classifies MRI images into one of three categories:
9
 
10
+ - **Glioma**
11
+ - **Meningioma**
12
+ - **Tumor (General)**
 
13
 
14
+ The dataset used includes over **75,000 color-enhanced MRI images**, making this model highly capable for research and educational applications in brain tumor detection.
 
15
 
16
+ ---
17
+
18
+ ## 📊 Dataset Information
19
+
20
+ - **Original Dataset Name**: Brain Cancer - MRI dataset
21
+ - **Author**: Rahman, Md Mizanur (2024)
22
+ - **Hosted on**: [Mendeley Data](https://data.mendeley.com/datasets/mk56jw9rns/1)
23
+ - **DOI**: [10.17632/mk56jw9rns.1](https://doi.org/10.17632/mk56jw9rns.1)
24
+ - **Kaggle Rehost (Colorized)**: [Shuvo Kumar Basak on Kaggle](https://www.kaggle.com/datasets/shuvokumarbasakbd/brain-cancer-mri-colorized-dataset)
25
+
26
+ > **Note:** This dataset is publicly available for non-commercial research use. The model does not include the dataset itself.
27
+
28
+ ---
29
+
30
+ ## 🧠 Model Architecture
31
+
32
+ - **Model Type**: Vision Transformer (ViT-B/16)
33
+ - **Framework**: PyTorch + [timm](https://github.com/huggingface/pytorch-image-models)
34
+ - **Input Shape**: 224x224 RGB
35
+ - **Number of Classes**: 3
36
+ - **Loss Function**: CrossEntropyLoss
37
+ - **Optimizer**: AdamW
38
+
39
+ ---
40
+
41
+ ## 🏁 Training Pipeline Summary
42
+
43
+ 1. **Image Preprocessing**:
44
+ - Resize to 224x224
45
+ - Normalization using ImageNet stats
46
+ - Augmentations: Horizontal/Vertical Flip, ShiftScaleRotate, BrightnessContrast, etc.
47
+
48
+ 2. **DataLoader**:
49
+ - Stratified Split (Train/Val/Test)
50
+ - PyTorch `Dataset` and `DataLoader` classes
51
+
52
+ 3. **Model**:
53
+ - Loaded ViT using `timm.create_model('vit_base_patch16_224', pretrained=True)`
54
+ - Modified the classifier head to match 3 output classes
55
+
56
+ 4. **Training**:
57
+ - Trained using mixed precision (`torch.cuda.amp`)
58
+ - Tracked using `tqdm`
59
+
60
+ 5. **Saving**:
61
+ - Model saved as `pytorch_model.bin`
62
+ - Configuration saved as `config.json`
63
+
64
+ ---
65
+
66
+ ## 🔍 Intended Use
67
+
68
+ This model is designed for:
69
+
70
+ - Educational purposes (deep learning and medical imaging)
71
+ - Research in brain tumor classification using transformers
72
+ - Demonstrating the power of ViT on colorized medical datasets
73
+
74
+ ⚠️ **Not intended for clinical use** or deployment without regulatory approval and further validation.
75
+
76
+ ---
77
+
78
+ ## 🚀 Inference Example (Python)
79
 
80
  ```python
81
+ from timm import create_model
 
82
  import torch
83
+ from torchvision import transforms
84
+ from PIL import Image
85
 
86
+ # Load model
87
+ model = create_model('vit_base_patch16_224', pretrained=False, num_classes=3)
88
  model.load_state_dict(torch.load("pytorch_model.bin"))
89
  model.eval()
90
 
91
+ # Transform
92
  transform = transforms.Compose([
93
  transforms.Resize((224, 224)),
94
  transforms.ToTensor(),
95
  transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
96
  ])
97
+
98
+ # Inference
99
+ image = Image.open("example_mri.jpg").convert("RGB")
100
+ tensor = transform(image).unsqueeze(0)
101
+ output = model(tensor)
102
+ pred = torch.argmax(output, dim=1)
103
+ print("Predicted class:", pred.item())