File size: 2,584 Bytes
3560106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Multi-Domain Classifier - Inference Example
Repository: https://huggingface.co/ovinduG/multi-domain-classifier-phi3
"""

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import json

class MultiDomainClassifier:
    """Classify free-text queries into domains with a LoRA-adapted causal LM.

    A base causal LM is loaded and the PEFT/LoRA adapter ``model_id`` is
    applied on top. ``predict`` prompts the model for a JSON classification
    and parses the first JSON object found in the generated continuation.
    """

    def __init__(
        self,
        model_id="ovinduG/multi-domain-classifier-phi3",
        base_model_id="microsoft/Phi-3-mini-4k-instruct",
    ):
        """Load the base model, apply the adapter, and load the tokenizer.

        Args:
            model_id: Hub id (or local path) of the LoRA adapter. The tokenizer
                is loaded from this id as well.
            base_model_id: Hub id (or local path) of the base causal LM the
                adapter was trained on.
        """
        print("Loading model...")

        # Load base model
        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Load LoRA adapter on top of the base weights.
        self.model = PeftModel.from_pretrained(self.base_model, model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model.eval()

        print("✅ Model loaded!")

    @staticmethod
    def _build_prompt(query: str) -> str:
        """Return the classification prompt for *query*.

        The JSON skeleton is literal text. Its braces are doubled so the
        f-string does not treat them as replacement fields.
        """
        return f"""Classify this query: {query}

Output JSON format:
{{
  "primary_domain": "domain_name",
  "primary_confidence": 0.95,
  "is_multi_domain": true/false,
  "secondary_domains": [{{"domain": "name", "confidence": 0.85}}]
}}"""

    @staticmethod
    def _extract_json(text: str) -> dict:
        """Return the first decodable JSON object embedded in *text*.

        Each ``{`` is tried in turn as the start of an object, so leading prose
        or stray braces before the real answer are skipped. Text after the
        object is ignored.

        Raises:
            ValueError: if no JSON object can be decoded from *text*.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                start = text.find("{", start + 1)
            else:
                return obj
        raise ValueError("no JSON object found in model output")

    def predict(self, query: str, max_new_tokens: int = 200) -> dict:
        """Classify *query* into domains.

        Args:
            query: The text to classify.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The parsed JSON object produced by the model. If no JSON object can
            be parsed, returns ``{"error": "Failed to parse response",
            "raw": <generated text>}`` instead.
        """
        prompt = self._build_prompt(query)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding; temperature would be ignored
                # NOTE(review): kept from the original. Presumably works around
                # a KV-cache incompatibility; it makes generation much slower.
                use_cache=False,
            )

        # outputs[0] holds prompt + continuation. Decode only the continuation
        # so the JSON template inside the prompt is not mistaken for the answer.
        response = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )

        try:
            return self._extract_json(response)
        except ValueError:
            return {"error": "Failed to parse response", "raw": response}


# Example usage: classify a few sample queries and pretty-print the results.
if __name__ == "__main__":
    # Load the classifier once; model loading dominates the run time.
    clf = MultiDomainClassifier()

    sample_queries = (
        "Write a Python function to calculate factorial",
        "Build ML model to analyze sales data and create API endpoints",
        "What is quantum entanglement?",
        "Create a REST API for healthcare diabetes prediction",
    )

    banner = "=" * 80
    print(f"\n{banner}")
    print("CLASSIFICATION EXAMPLES")
    print(banner)

    for text in sample_queries:
        print(f"\nQuery: {text}")
        prediction = clf.predict(text)
        print(f"Result: {json.dumps(prediction, indent=2)}")
        print("-" * 80)