Green Arabica Coffee Bean Classification (R50-V Hybrid Model)

Model Description

This model is a hybrid deep learning architecture that combines a ResNet-50 CNN with a Vision Transformer (ViT-B/16) to classify green Arabica coffee beans into four quality classes. It achieved 91.88% accuracy on the USK-Coffee dataset, outperforming previous single-model approaches.

Model Type: Hybrid CNN-Transformer
Architecture: ResNet-50 + Vision Transformer (R50-V)
Task: Multi-class Image Classification
Classes: 4 (Peaberry, Longberry, Premium, Defect)

Model Architecture

The hybrid model processes input images through two parallel branches:

  1. CNN Branch (ResNet-50): Extracts local visual features such as texture, edges, and patterns
  2. Transformer Branch (Vision Transformer): Captures global context and long-range dependencies

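The features from both branches are concatenated into a single vector and passed through a classification head: a 512-unit linear layer with ReLU and dropout (0.5), followed by a linear layer with one output per class.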
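
As a rough illustration of the fusion step, the sketch below traces the feature sizes in plain Python, with zero-filled lists standing in for real features. The sizes 2048 (ResNet-50 pooled output) and 768 (ViT-B/16 class token) are the standard torchvision values, and 512 is the hidden width used in the code further down. The helper name `linear_params` is hypothetical and exists only for this example.

```python
# Illustrative only: trace feature sizes through the fusion head.
RESNET50_FEATURES = 2048   # size of ResNet-50's pooled feature vector
VIT_B16_FEATURES = 768     # size of ViT-B/16's class-token embedding
HIDDEN_UNITS = 512         # hidden width of the classification head
NUM_CLASSES = 4


def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer (hypothetical helper)."""
    return in_features * out_features + out_features


# Zero-filled placeholders; list concatenation stands in for torch.cat(..., dim=1).
local_features = [0.0] * RESNET50_FEATURES
global_features = [0.0] * VIT_B16_FEATURES
fused = local_features + global_features

# Trainable parameters in the two linear layers of the head.
head_params = (
    linear_params(len(fused), HIDDEN_UNITS)
    + linear_params(HIDDEN_UNITS, NUM_CLASSES)
)
print(len(fused), head_params)
```

The head is small compared with the two backbones, so most of the model's capacity sits in the feature extractors.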

How to Use

Installation

pip install torch torchvision pillow huggingface_hub

Code

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image  # needed by preprocess_image
from torchvision.models import (
    resnet50, vit_b_16,
    ResNet50_Weights, ViT_B_16_Weights
)
from huggingface_hub import hf_hub_download

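class HybridModel(nn.Module):
    """ResNet-50 and ViT-B/16 feature extractors feeding a shared classification head."""

    def __init__(self, num_classes=4):
        super().__init__()
        # CNN branch: replace the ImageNet classifier so the network returns pooled features
        self.resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        resnet_in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()

        # Transformer branch: replace the ImageNet head so the network returns the class token
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        vit_in_features = self.vit.heads.head.in_features
        self.vit.heads = nn.Identity()

        # Classification head over the concatenated features
        self.fc = nn.Sequential(
            nn.Linear(resnet_in_features + vit_in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # x is a pair of tensors, one preprocessed for each branch
        x_resnet, x_vit = x
        local_features = self.resnet(x_resnet)
        global_features = self.vit(x_vit)
        # Concatenate along the feature dimension, keeping the batch dimension
        combined_features = torch.cat((local_features, global_features), dim=1)
        return self.fc(combined_features)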

def load_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_path = hf_hub_download(
        repo_id="kelvinandreas/Green_Arabica_Coffee_Bean_Classification_R50-V",
        filename="best_model.pth"
    )

    model = HybridModel(num_classes=4)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    return model, device

def preprocess_image(image_path, device):
    image = Image.open(image_path).convert("RGB")

    preprocess_resnet = ResNet50_Weights.IMAGENET1K_V1.transforms()
    preprocess_vit = ViT_B_16_Weights.IMAGENET1K_V1.transforms()

    x_resnet = preprocess_resnet(image).unsqueeze(0).to(device)
    x_vit = preprocess_vit(image).unsqueeze(0).to(device)

    return x_resnet, x_vit

def predict(image_path):
    # Index-to-label mapping for the four outputs (alphabetical order)
    CLASS_LABELS = ['Defect', 'Longberry', 'Peaberry', 'Premium']

    # Loads the model on every call; for many images, call load_model() once and reuse it
    model, device = load_model()
    x_resnet, x_vit = preprocess_image(image_path, device)

    with torch.no_grad():
        logits = model((x_resnet, x_vit))

    # Convert logits to probabilities and pick the most likely class
    probs = F.softmax(logits, dim=1)
    pred_idx = torch.argmax(probs, dim=1).item()
    confidence = probs[0][pred_idx].item()

    return {
        "prediction": CLASS_LABELS[pred_idx],
        "confidence_percent": round(confidence * 100, 2)
    }

image_path = "your_coffee_bean_image.jpg"
result = predict(image_path)
print(result)
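
The last step of `predict` turns raw logits into a label and a confidence value. Below is a minimal plain-Python sketch of that arithmetic. It uses made-up logits rather than real model output, and `decode_logits` is a hypothetical helper, not part of the released code.

```python
import math

# Same label order as CLASS_LABELS in predict()
CLASS_LABELS = ["Defect", "Longberry", "Peaberry", "Premium"]


def decode_logits(logits):
    """Softmax over one row of logits, then pick the most probable class (hypothetical helper)."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    pred_idx = max(range(len(probs)), key=probs.__getitem__)
    return {
        "prediction": CLASS_LABELS[pred_idx],
        "confidence_percent": round(probs[pred_idx] * 100, 2),
    }


# Made-up logits for illustration; not produced by the model.
example = decode_logits([0.1, -1.2, 2.3, 0.4])
print(example)
```

The reported confidence is the softmax probability of the winning class, so it says how strongly the model prefers that class over the other three. It is not a calibrated estimate of correctness.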

Evaluation Results

Metric      Score
Accuracy    91.88%
Precision   92.12%
Recall      91.88%
F1-Score    91.85%
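
The card does not say how precision, recall, and F1 are averaged across the four classes. Recall matching accuracy (both 91.88%) is consistent with support-weighted averaging, because weighted recall always equals accuracy. The sketch below assumes weighted averaging and uses a made-up confusion matrix, not this model's results. `weighted_metrics` is a hypothetical helper.

```python
def weighted_metrics(cm):
    """Accuracy and support-weighted precision/recall/F1 from a confusion matrix.

    cm[i][j] counts samples whose true class is i and predicted class is j.
    Hypothetical helper for illustration only.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    accuracy = sum(cm[i][i] for i in range(n)) / total

    precision = recall = f1 = 0.0
    for c in range(n):
        tp = cm[c][c]
        support = sum(cm[c])                       # true samples of class c
        predicted = sum(cm[r][c] for r in range(n))  # samples predicted as class c
        p = tp / predicted if predicted else 0.0
        r = tp / support if support else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        weight = support / total                   # weight each class by its support
        precision += weight * p
        recall += weight * r
        f1 += weight * f

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Made-up 4-class confusion matrix (rows: true class, columns: predicted class).
toy_cm = [
    [8, 1, 0, 1],
    [0, 9, 1, 0],
    [1, 0, 9, 0],
    [0, 0, 2, 8],
]
metrics = weighted_metrics(toy_cm)
print(metrics)
```

With macro averaging, recall would equal accuracy only when every class has the same number of test samples, so the averaging scheme is worth confirming against the original evaluation script.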