File size: 2,584 Bytes
3560106 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
"""
Multi-Domain Classifier - Inference Example
Repository: https://huggingface.co/ovinduG/multi-domain-classifier-phi3
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import json
class MultiDomainClassifier:
    """Classify free-text queries into domains with a causal LM plus a LoRA adapter.

    The base model is loaded from ``base_model_id`` and the LoRA adapter and
    tokenizer are loaded from ``model_id``. :meth:`predict` asks the model for a
    JSON answer and returns the first JSON object found in the generated text.
    """

    # Answer-format template appended to every prompt. It is a plain (non-f)
    # string so its JSON braces are literal text, not f-string replacement fields.
    _OUTPUT_TEMPLATE = (
        "Output JSON format:\n"
        "{\n"
        '"primary_domain": "domain_name",\n'
        '"primary_confidence": 0.95,\n'
        '"is_multi_domain": true/false,\n'
        '"secondary_domains": [{"domain": "name", "confidence": 0.85}]\n'
        "}"
    )

    def __init__(self, model_id="ovinduG/multi-domain-classifier-phi3",
                 base_model_id="microsoft/Phi-3-mini-4k-instruct"):
        """Load the base model, attach the LoRA adapter, and load the tokenizer.

        Args:
            model_id: Hub id (or local path) of the LoRA adapter; the tokenizer
                is loaded from the same location.
            base_model_id: Hub id (or local path) of the base causal LM the
                adapter is applied to.
        """
        print("Loading model...")
        # Load base model in bfloat16; device_map="auto" lets transformers
        # place the weights on the available device(s).
        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        # Load LoRA adapter on top of the base model.
        self.model = PeftModel.from_pretrained(self.base_model, model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Inference only: put the model in evaluation mode.
        self.model.eval()
        print("✅ Model loaded!")

    @classmethod
    def _build_prompt(cls, query: str) -> str:
        """Return the classification prompt for ``query``."""
        return f"Classify this query: {query}\n{cls._OUTPUT_TEMPLATE}"

    @staticmethod
    def _extract_json(text: str):
        """Return the first JSON object that can be decoded from ``text``, else ``None``.

        Each ``{`` is tried in turn as the start of an object, so leading
        chatter or stray braces before the real answer are skipped.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                start = text.find("{", start + 1)
            else:
                return obj
        return None

    def predict(self, query: str, max_new_tokens: int = 200) -> dict:
        """Classify ``query`` into domains.

        Args:
            query: Free-text query to classify.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            The JSON object produced by the model, or
            ``{"error": "Failed to parse response", "raw": <generated text>}``
            when no JSON object can be decoded from the generated text.
        """
        prompt = self._build_prompt(query)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                # Greedy decoding; a temperature would be ignored here.
                do_sample=False,
                # NOTE(review): disabling the KV cache slows generation; this is
                # presumably a workaround for a Phi-3 / transformers cache
                # incompatibility — confirm before removing.
                use_cache=False,
            )

        # generate() returns prompt + continuation. Decode only the new tokens,
        # because the prompt's own template ("true/false") is not valid JSON.
        response = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
        result = self._extract_json(response)
        if result is None:
            return {"error": "Failed to parse response", "raw": response}
        return result
# Demo: build the classifier once, then classify a handful of sample queries.
if __name__ == "__main__":
    demo_classifier = MultiDomainClassifier()

    sample_queries = (
        "Write a Python function to calculate factorial",
        "Build ML model to analyze sales data and create API endpoints",
        "What is quantum entanglement?",
        "Create a REST API for healthcare diabetes prediction",
    )

    rule = "=" * 80
    print(f"\n{rule}")
    print("CLASSIFICATION EXAMPLES")
    print(rule)

    for sample in sample_queries:
        print(f"\nQuery: {sample}")
        prediction = demo_classifier.predict(sample)
        print("Result: " + json.dumps(prediction, indent=2))
        print("-" * 80)
|