ClaCe committed on
Commit f0c6b6b · verified · 1 Parent(s): afb1abf

Upload 4 files

Files changed (4):
  1. README.md +87 -6
  2. app.py +1262 -0
  3. chroma_db_complete.tar.gz +3 -0
  4. requirements.txt +11 -0
README.md CHANGED
@@ -1,12 +1,93 @@
  ---
- title: FindHugForPMwithKey
- emoji: 📚
- colorFrom: green
- colorTo: green
  sdk: gradio
- sdk_version: 5.44.0
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: ML Use Cases RAG Assistant (BYOK)
+ emoji: 🧠
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
+ sdk_version: 4.0.0
  app_file: app.py
  pinned: false
+ license: mit
  ---

+ # ML/AI Use Cases RAG Assistant (Bring Your Own Key)
+
+ An AI-powered assistant that provides business advice based on real ML/AI implementations from 310+ companies. This app uses Retrieval-Augmented Generation (RAG) to find relevant company examples and provides actionable recommendations.
+
+ **🔑 Bring Your Own Key:** This version requires users to provide their own HuggingFace API key, ensuring zero cost to the space owner while maintaining full functionality.
+
+ ## Features
+
+ - **🔑 BYOK (Bring Your Own Key)**: Use your own HuggingFace API key for secure, cost-effective access
+ - **🔍 Semantic Search**: Find relevant ML/AI use cases from a comprehensive database
+ - **🤖 AI-Powered Advice**: Get personalized recommendations using the HuggingFace Inference API
+ - **📊 Model Recommendations**: Discover fine-tuned and foundation models for your specific use case
+ - **🏢 Real Company Examples**: Learn from actual implementations across various industries
+ - **🔒 Privacy-First**: Only embeddings are used - no raw company data is exposed
+ - **💰 Zero Cost to Owner**: No API costs for the space owner - users bring their own keys
+
+ ## How It Works
+
+ 1. **🔑 API Key Setup**: Provide your HuggingFace API key for secure access
+ 2. **📝 Query Processing**: Your business problem is analyzed and converted to embeddings
+ 3. **🔍 Semantic Search**: The system searches through 310+ pre-processed ML use cases
+ 4. **📚 Context Building**: Relevant company examples are selected as context
+ 5. **🤖 AI Generation**: Your API key powers the language model to generate tailored advice
+ 6. **📊 Model Matching**: The HuggingFace API provides relevant model recommendations using your key
+
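The ranking in steps 3 and 4 combines cosine similarity with keyword boosts, mirroring the scoring used in app.py's search routine. A minimal, self-contained sketch of that logic (the helper name `boosted_score` is illustrative):

```python
def boosted_score(similarity: float, query: str, metadata: dict) -> float:
    """Combine a semantic similarity score with per-field keyword boosts.

    The field weights mirror those used in app.py's search ranking.
    """
    weights = {"title": 0.3, "description": 0.2, "keywords": 0.4, "industry": 0.1}
    boost = 0.0
    for word in query.lower().split():
        for field, weight in weights.items():
            if word in metadata.get(field, "").lower():
                boost += weight
    # Scores are capped at 1.0, as in the app
    return min(similarity + boost, 1.0)

meta = {"title": "Customer churn prediction", "industry": "SaaS"}
print(round(boosted_score(0.62, "churn", meta), 2))  # → 0.92 (0.62 + 0.3 title boost)
```

Exact matches in the `keywords` metadata field carry the largest weight (0.4), so metadata curation directly shapes the final ranking.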
+ ## Technology Stack
+
+ - **Backend**: FastAPI with async support and BYOK architecture
+ - **Vector Database**: ChromaDB for semantic search
+ - **Embeddings**: Sentence Transformers (all-MiniLM-L6-v2)
+ - **Language Model**: HuggingFace Inference API (Gemma 2 2B with fallbacks)
+ - **Frontend**: Modern HTML/CSS/JavaScript with Tailwind CSS
+ - **Security**: User API keys are never stored server-side, only used for requests
+
+ ## Security & Privacy
+
+ - **🔐 API Key Security**: Your API key is never stored permanently on the server, only used for requests
+ - **📊 No Raw Data**: Only vector embeddings and metadata are stored
+ - **🏢 Company Privacy**: Original datasets remain private
+ - **🛡️ Secure Processing**: All processing happens within the secure HuggingFace environment
+ - **💾 Local Storage**: API keys are kept in your browser's local storage for convenience
+
+ ## Getting Started
+
+ ### 1. Get Your HuggingFace API Key
+ 1. Visit [HuggingFace Settings](https://huggingface.co/settings/tokens)
+ 2. Click "Create new token"
+ 3. Select "Read" access (sufficient for this app)
+ 4. Copy your token (starts with `hf_`)
+
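You can sanity-check a token before pasting it in; the app performs the same two checks server-side in its `/validate-key` endpoint (a prefix check, then a GET to `huggingface.co/api/whoami`). A stdlib-only sketch — the helper names are illustrative:

```python
import json
import urllib.request

def hf_token_format_ok(token: str) -> bool:
    """Cheap format check: HuggingFace user tokens start with 'hf_'."""
    return token.strip().startswith("hf_")

def whoami(token: str) -> dict:
    """Ask HuggingFace who owns the token (raises HTTPError on a bad key)."""
    req = urllib.request.Request(
        "https://huggingface.co/api/whoami",
        headers={"Authorization": f"Bearer {token}"},
    )
    with urllib.request.urlopen(req, timeout=10) as resp:
        return json.load(resp)

print(hf_token_format_ok("hf_example"))  # → True (format only, not validity)
```

The format check catches typos instantly; only the `whoami` call confirms the key is actually valid.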
+ ### 2. Use the Assistant
+ 1. Enter your API key in the secure input field
+ 2. Describe your business problem in natural language:
+    - "I want to reduce customer churn in my SaaS business"
+    - "How can I implement fraud detection for my e-commerce platform"
+    - "What ML approach works best for demand forecasting in retail"
+
+ ### 3. Get AI-Powered Results
+ - **Solution Approach**: Detailed technical recommendations
+ - **Company Examples**: Real implementations from similar businesses
+ - **Model Recommendations**: Specific HuggingFace models for your use case
+
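Under the hood, the frontend POSTs your question to the `/chat` endpoint with the key in an `X-HF-API-Key` header (the header alias declared in app.py). A minimal client sketch — `base_url` and the example key are placeholders:

```python
import json
import urllib.request

def build_chat_request(base_url: str, api_key: str, query: str) -> urllib.request.Request:
    """Build the POST /chat request the frontend sends."""
    return urllib.request.Request(
        f"{base_url}/chat",
        data=json.dumps({"query": query}).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "X-HF-API-Key": api_key,  # must start with hf_; validated server-side
        },
        method="POST",
    )

req = build_chat_request("http://localhost:7860", "hf_example",
                         "I want to reduce customer churn in my SaaS business")
# urllib.request.urlopen(req, timeout=60) would return the JSON response,
# including solution_approach, company_examples, and recommended_models.
print(req.get_method(), req.full_url)  # → POST http://localhost:7860/chat
```

The same request also works against a deployed Space URL; the header name is matched case-insensitively by the server.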
+ ## Model Information
+
+ This space uses pre-computed ChromaDB embeddings generated from a curated dataset of ML/AI use cases. The language model runs efficiently on CPU with fallback options for reliability.
+
+ ## Requirements & Limitations
+
+ ### Requirements
+ - Valid HuggingFace API key (free to obtain)
+ - Internet connection for API calls
+
+ ### Limitations
+ - Responses are generated based on training data patterns
+ - Model recommendations are sourced from the HuggingFace Hub API
+ - Processing time may vary based on query complexity and API response times
+ - API rate limits apply based on your HuggingFace account tier
+
+ ---
+
+ *Built with ❤️ using HuggingFace Spaces*
app.py ADDED
@@ -0,0 +1,1262 @@
+ from fastapi import FastAPI, HTTPException, Header
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.responses import HTMLResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ import chromadb
+ from sentence_transformers import SentenceTransformer
+ from transformers import pipeline
+ from huggingface_hub import login
+ import requests
+ import json
+ from typing import List, Dict, Any
+ import os
+ import sys
+ import torch
+ import tarfile
+
+ app = FastAPI(title="ML Use Cases RAG System")
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Global list holding log messages destined for the UI
+ current_logs = []
+
+ def log_to_ui(message):
+     """Add a log message that will be sent to the UI"""
+     current_logs.append(message)
+     print(message)  # Still print to console
+
+ # Initialize embedding model
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+
+ # Initialize a server-side Gemma 2 2B pipeline as a fallback; note that
+ # transformers' pipeline() downloads and runs the model locally, not remotely.
+ HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
+
+ try:
+     if HUGGINGFACE_API_KEY:
+         print("🔐 Logging in to HuggingFace for gated model access...")
+         login(token=HUGGINGFACE_API_KEY)
+         print("✅ Successfully logged in to HuggingFace")
+
+         print("Initializing Gemma 2 2B via transformers pipeline...")
+         generator = pipeline(
+             "text-generation",
+             model="google/gemma-2-2b-it",
+             token=HUGGINGFACE_API_KEY  # 'token' replaces the deprecated 'use_auth_token'
+         )
+         print("✅ Gemma 2 2B model initialized successfully")
+         llm_available = True
+     else:
+         print("No HuggingFace API key found - will use template responses")
+         generator = None
+         llm_available = False
+ except Exception as e:
+     print(f"Error initializing Gemma 2 2B: {e}")
+     print("Falling back to template responses")
+     generator = None
+     llm_available = False
+
+ # Auto-extract ChromaDB if the archive exists and the directory is missing/empty
+ def setup_chromadb():
+     """Set up ChromaDB by extracting the archive if needed"""
+     if os.path.exists("chroma_db_complete.tar.gz"):
+         # Check if the chroma_db directory exists and has content
+         needs_extraction = False
+
+         if not os.path.exists("chroma_db"):
+             print("📦 ChromaDB directory not found, extracting archive...")
+             needs_extraction = True
+         else:
+             # Check if the directory is empty or missing key files
+             try:
+                 if not os.path.exists("chroma_db/chroma.sqlite3"):
+                     print("📦 ChromaDB missing database file, extracting archive...")
+                     needs_extraction = True
+                 else:
+                     # Quick check: try to list collections
+                     temp_client = chromadb.PersistentClient(path="./chroma_db")
+                     collections = temp_client.list_collections()
+                     if len(collections) == 0:
+                         print("📦 ChromaDB has no collections, extracting archive...")
+                         needs_extraction = True
+                     else:
+                         print(f"✅ ChromaDB already set up with {len(collections)} collections")
+             except Exception as e:
+                 print(f"📦 ChromaDB check failed ({e}), extracting archive...")
+                 needs_extraction = True
+
+         if needs_extraction:
+             try:
+                 print("🔧 Extracting ChromaDB archive...")
+                 with tarfile.open("chroma_db_complete.tar.gz", "r:gz") as tar:
+                     tar.extractall()
+                 print("✅ ChromaDB extracted successfully")
+
+                 # Verify extraction
+                 if os.path.exists("chroma_db/chroma.sqlite3"):
+                     print("✅ Database file found after extraction")
+                 else:
+                     print("❌ Database file missing after extraction")
+
+             except Exception as e:
+                 print(f"❌ Failed to extract ChromaDB: {e}")
+     else:
+         print("📋 No ChromaDB archive found, using existing directory")
+
+ # Set up ChromaDB before initializing the client
+ setup_chromadb()
+
+ # Initialize ChromaDB
+ chroma_client = chromadb.PersistentClient(path="./chroma_db")
+ collection = None
+
+ class ChatRequest(BaseModel):
+     query: str
+
+ class ApiKeyRequest(BaseModel):
+     api_key: str
+
+ class SearchResult(BaseModel):
+     company: str
+     industry: str
+     year: int
+     description: str
+     summary: str
+     similarity_score: float
+     url: str
+
+ class RecommendedModels(BaseModel):
+     fine_tuned: List[Dict[str, Any]]
+     general: List[Dict[str, Any]]
+
+ class ChatResponse(BaseModel):
+     solution_approach: str
+     company_examples: List[SearchResult]
+     recommended_models: RecommendedModels
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     return {"status": "healthy"}
+
+ @app.post("/validate-key")
+ async def validate_api_key(request: ApiKeyRequest):
+     """Validate the user's HuggingFace API key"""
+     api_key = request.api_key.strip()
+
+     if not api_key or not api_key.startswith('hf_'):
+         return {"valid": False, "error": "Invalid API key format. Must start with 'hf_'"}
+
+     try:
+         # Test the API key with a simple request to the HuggingFace API
+         test_response = requests.get(
+             "https://huggingface.co/api/whoami",
+             headers={"Authorization": f"Bearer {api_key}"},
+             timeout=10
+         )
+
+         if test_response.status_code == 200:
+             user_info = test_response.json()
+             return {"valid": True, "user": user_info.get("name", "Unknown")}
+         else:
+             return {"valid": False, "error": "Invalid API key or insufficient permissions"}
+
+     except requests.RequestException:
+         return {"valid": False, "error": "Failed to validate API key. Please check your connection."}
+     except Exception:
+         return {"valid": False, "error": "Validation failed. Please try again."}
+
+ @app.get("/logs")
+ async def get_logs():
+     """Get current log messages for the UI"""
+     try:
+         logs_copy = current_logs.copy()
+         current_logs.clear()
+         return {"logs": logs_copy}
+     except Exception as e:
+         return {"logs": [], "error": str(e)}
+
+ @app.get("/test-logs")
+ async def test_logs():
+     """Test endpoint to verify logging works"""
+     log_to_ui("🧪 Test log message 1")
+     log_to_ui("🧪 Test log message 2")
+     log_to_ui("🧪 Test log message 3")
+     return {"message": "Test logs added"}
+
+ def initialize_collection():
+     """Initialize the ChromaDB collection with debug logging"""
+     global collection
+
+     # Debug: check the file system
+     print(f"🔍 Current working directory: {os.getcwd()}")
+     print(f"🔍 ChromaDB path exists: {os.path.exists('./chroma_db')}")
+
+     if os.path.exists('./chroma_db'):
+         try:
+             chroma_files = os.listdir('./chroma_db')
+             print(f"🔍 ChromaDB directory contents: {chroma_files}")
+
+             # Check for the main database file
+             if 'chroma.sqlite3' in chroma_files:
+                 print("✅ Found chroma.sqlite3")
+             else:
+                 print("❌ chroma.sqlite3 NOT found")
+
+             # Check for UUID directories
+             uuid_dirs = [f for f in chroma_files if len(f) == 36 and '-' in f]  # UUID format
+             if uuid_dirs:
+                 print(f"✅ Found UUID directories: {uuid_dirs}")
+                 for uuid_dir in uuid_dirs:
+                     uuid_path = os.path.join('./chroma_db', uuid_dir)
+                     if os.path.isdir(uuid_path):
+                         uuid_files = os.listdir(uuid_path)
+                         print(f"🔍 {uuid_dir} contents: {uuid_files}")
+             else:
+                 print("❌ No UUID directories found")
+
+         except Exception as e:
+             print(f"❌ Error reading chroma_db directory: {e}")
+     else:
+         print("❌ chroma_db directory does not exist")
+
+     # Debug: try to initialize the ChromaDB client
+     try:
+         print("🔍 Attempting to initialize ChromaDB client...")
+         print(f"🔍 ChromaDB version: {chromadb.__version__}")
+
+         # List all collections
+         collections = chroma_client.list_collections()
+         print(f"🔍 Available collections: {[c.name for c in collections]}")
+
+         # Try to get the specific collection
+         collection = chroma_client.get_collection("ml_use_cases")
+         collection_count = collection.count()
+         print(f"✅ Found existing collection 'ml_use_cases' with {collection_count} documents")
+
+     except Exception as e:
+         print(f"❌ Collection initialization error: {type(e).__name__}: {e}")
+         print("📝 Will attempt to create collection during first use")
+         collection = None
+
+ # Initialize collection on import
+ initialize_collection()
+
+ @app.get("/", response_class=HTMLResponse)
+ async def root():
+     """Serve the main frontend"""
+     with open("static/index.html", "r") as f:
+         return HTMLResponse(f.read())
+
+ async def search_use_cases_internal(request: ChatRequest):
+     """Internal search function with detailed logging"""
+     log_to_ui(f"🔍 Search request received: '{request.query}'")
+
+     if not collection:
+         log_to_ui("❌ ChromaDB collection not initialized")
+         raise HTTPException(status_code=500, detail="Database not initialized")
+
+     query = request.query.lower()
+     log_to_ui(f"📝 Normalized query: '{query}'")
+
+     # Generate the query embedding for semantic search
+     log_to_ui("🧠 Generating query embedding...")
+     query_embedding = embedding_model.encode([request.query]).tolist()[0]
+     log_to_ui(f"✅ Embedding generated, dimension: {len(query_embedding)}")
+
+     # Semantic search
+     log_to_ui("🔎 Performing semantic search...")
+     semantic_results = collection.query(
+         query_embeddings=[query_embedding],
+         n_results=15,
+         include=['metadatas', 'documents', 'distances']
+     )
+     log_to_ui(f"📊 Semantic search found {len(semantic_results['ids'][0])} results")
+
+     # Secondary search letting ChromaDB embed the raw query text itself
+     keyword_results = None
+     try:
+         log_to_ui("🔤 Performing keyword search...")
+         keyword_results = collection.query(
+             query_texts=[request.query],
+             n_results=10,
+             include=['metadatas', 'documents', 'distances']
+         )
+         log_to_ui(f"📝 Keyword search found {len(keyword_results['ids'][0])} results")
+     except Exception as e:
+         log_to_ui(f"⚠️ Keyword search failed: {e}")
+
+     # Combine and rank results
+     combined_results = {}
+
+     # Process semantic results
+     for i in range(len(semantic_results['ids'][0])):
+         doc_id = semantic_results['ids'][0][i]
+         metadata = semantic_results['metadatas'][0][i]
+         similarity_score = 1 - semantic_results['distances'][0][i]
+
+         # Boost the score for keyword matches in metadata
+         boost = 0
+         query_words = query.split()
+         for word in query_words:
+             if word in metadata.get('title', '').lower():
+                 boost += 0.3
+             if word in metadata.get('description', '').lower():
+                 boost += 0.2
+             if word in metadata.get('keywords', '').lower():
+                 boost += 0.4
+             if word in metadata.get('industry', '').lower():
+                 boost += 0.1
+
+         final_score = min(similarity_score + boost, 1.0)
+
+         combined_results[doc_id] = {
+             'metadata': metadata,
+             'summary': semantic_results['documents'][0][i],
+             'score': final_score,
+             'source': 'semantic'
+         }
+
+     # Process keyword results if available
+     if keyword_results:
+         for i in range(len(keyword_results['ids'][0])):
+             doc_id = keyword_results['ids'][0][i]
+             if doc_id not in combined_results:
+                 metadata = keyword_results['metadatas'][0][i]
+                 similarity_score = 1 - keyword_results['distances'][0][i]
+
+                 combined_results[doc_id] = {
+                     'metadata': metadata,
+                     'summary': keyword_results['documents'][0][i],
+                     'score': similarity_score + 0.1,  # Small boost for keyword matches
+                     'source': 'keyword'
+                 }
+
+     # Sort by score and take the top results
+     sorted_results = sorted(combined_results.values(), key=lambda x: x['score'], reverse=True)[:10]
+     log_to_ui(f"🎯 Combined and ranked results: {len(sorted_results)} final results")
+
+     search_results = []
+     for i, result in enumerate(sorted_results):
+         metadata = result['metadata']
+         search_results.append(SearchResult(
+             company=metadata.get('company', ''),
+             industry=metadata.get('industry', ''),
+             year=metadata.get('year', 2023),
+             description=metadata.get('description', ''),
+             summary=result['summary'],
+             similarity_score=result['score'],
+             url=metadata.get('url', '')
+         ))
+         log_to_ui(f"   {i+1}. {metadata.get('company', 'Unknown')} - Score: {result['score']:.3f}")
+
+     log_to_ui(f"✅ Search completed, returning {len(search_results)} results")
+     return search_results
+
+ @app.post("/search")
+ async def search_use_cases(request: ChatRequest):
+     """Public search endpoint"""
+     results = await search_use_cases_internal(request)
+     return {"results": results}
+
+ async def generate_response_with_user_key(prompt: str, api_key: str, max_length: int = 500) -> str:
+     """Generate a response using the user's HuggingFace API key via the Inference API"""
+     try:
+         # Use the HuggingFace Inference API with the user's key
+         api_url = "https://api-inference.huggingface.co/models/google/gemma-2-2b-it"
+         headers = {
+             "Authorization": f"Bearer {api_key}",
+             "Content-Type": "application/json"
+         }
+
+         payload = {
+             "inputs": prompt,
+             "parameters": {
+                 "max_new_tokens": max_length,
+                 "temperature": 0.7,
+                 "do_sample": True,
+                 "return_full_text": False
+             }
+         }
+
+         response = requests.post(api_url, headers=headers, json=payload, timeout=30)
+
+         if response.status_code == 200:
+             result = response.json()
+             if isinstance(result, list) and len(result) > 0:
+                 generated_text = result[0].get('generated_text', '')
+                 return generated_text.strip()
+             else:
+                 return "Unable to generate response. Please try again."
+         elif response.status_code == 503:
+             # Model is loading, try a fallback
+             return await try_fallback_model(prompt, api_key, max_length)
+         else:
+             raise Exception(f"API request failed with status {response.status_code}")
+
+     except Exception as e:
+         print(f"Error generating response with user API key: {e}")
+         return generate_template_response(prompt)
+
+ async def try_fallback_model(prompt: str, api_key: str, max_length: int = 500) -> str:
+     """Try fallback models when the primary model is unavailable"""
+     try:
+         # Try more readily available models as fallbacks
+         fallback_models = [
+             "microsoft/DialoGPT-medium",
+             "microsoft/DialoGPT-small",
+             "gpt2"
+         ]
+
+         for model_name in fallback_models:
+             try:
+                 api_url = f"https://api-inference.huggingface.co/models/{model_name}"
+                 headers = {
+                     "Authorization": f"Bearer {api_key}",
+                     "Content-Type": "application/json"
+                 }
+
+                 payload = {
+                     "inputs": prompt,
+                     "parameters": {
+                         "max_new_tokens": max_length,
+                         "temperature": 0.7,
+                         "do_sample": True,
+                         "return_full_text": False
+                     }
+                 }
+
+                 response = requests.post(api_url, headers=headers, json=payload, timeout=20)
+
+                 if response.status_code == 200:
+                     result = response.json()
+                     if isinstance(result, list) and len(result) > 0:
+                         generated_text = result[0].get('generated_text', '')
+                         return generated_text.strip()
+
+             except Exception:
+                 continue
+
+         # If all models fail, return the template
+         return generate_template_response(prompt)
+
+     except Exception:
+         return generate_template_response(prompt)
+
+ def generate_template_response(prompt: str) -> str:
+     """Generate a template response when AI models are not available"""
+     return """Based on the analysis of similar ML/AI implementations from companies in our database, here are key recommendations for your problem:
+
+ **Technical Approach:**
+ - Consider machine learning classification or prediction models
+ - Leverage data preprocessing and feature engineering
+ - Implement proper model validation and testing
+
+ **Implementation Strategy:**
+ - Start with a minimum viable model using existing data
+ - Iterate based on performance metrics
+ - Consider scalability and real-time requirements
+
+ **Key Considerations:**
+ - Data quality and availability
+ - Business metrics alignment
+ - Technical infrastructure requirements
+
+ This analysis is based on patterns from 310+ real-world ML implementations across various industries."""
+
+ # Note: no response_model here, because the payload includes a "logs" key
+ # that ChatResponse does not declare and FastAPI would otherwise strip.
+ @app.post("/chat")
+ async def chat_with_rag(request: ChatRequest, x_hf_api_key: str = Header(None, alias="X-HF-API-Key")):
+     """Main RAG endpoint using the user's API key"""
+     # Validate the user API key
+     if not x_hf_api_key or not x_hf_api_key.startswith('hf_'):
+         raise HTTPException(status_code=400, detail="Valid HuggingFace API key required")
+
+     # Clear previous logs and start fresh
+     current_logs.clear()
+
+     log_to_ui(f"🤖 Chat request received: '{request.query}'")
+
+     # First search for relevant use cases
+     log_to_ui("🔍 Getting relevant use cases...")
+     relevant_cases = await search_use_cases_internal(request)
+     top_cases = relevant_cases[:5]  # Top 5 results
+     log_to_ui(f"📚 Using top {len(top_cases)} cases for context")
+
+     # Prepare context for the LLM
+     log_to_ui("📝 Preparing context for LLM...")
+     context = "Here are relevant real-world ML/AI implementations:\n\n"
+     for i, case in enumerate(top_cases, 1):
+         context += f"Company: {case.company} ({case.industry}, {case.year})\n"
+         context += f"Description: {case.description}\n"
+         context += f"Implementation: {case.summary[:500]}...\n\n"
+         log_to_ui(f"   {i}. {case.company} - {case.description}")
+
+     log_to_ui(f"📊 Context length: {len(context)} characters")
+
+     # Create the prompt for the language model
+     prompt = f"""Based on the following real ML/AI implementations from companies, provide advice for this business problem:
+
+ {context}
+
+ User Problem: {request.query}
+
+ Please provide a comprehensive solution approach that considers what has worked for these companies. Focus on:
+ 1. What type of ML/AI solution would address this problem
+ 2. Key technical approaches that have proven successful
+ 3. Implementation considerations based on the examples
+
+ Be specific and reference the examples when relevant.
+
+ Response:"""
+
+     log_to_ui(f"💭 Full prompt length: {len(prompt)} characters")
+
+     # Generate a response using the user's HuggingFace API key
+     log_to_ui("🚀 Generating AI response with user API key...")
+     try:
+         llm_response = await generate_response_with_user_key(prompt, x_hf_api_key, max_length=400)
+         log_to_ui(f"✅ AI response generated, length: {len(llm_response)} characters")
+     except Exception as e:
+         llm_response = f"Error generating AI response: {str(e)}"
+         log_to_ui(f"❌ AI response error: {e}")
+
+     # Get HuggingFace model recommendations using the user's API key
+     log_to_ui("🤗 Getting HuggingFace model recommendations...")
+     recommended_models = await get_huggingface_models(request.query, top_cases, x_hf_api_key)
+     total_models = len(recommended_models.get("fine_tuned", [])) + len(recommended_models.get("general", []))
+     log_to_ui(f"🏷️ Found {total_models} recommended models")
+
+     log_to_ui("✅ Chat response complete!")
+
+     # Return the response with logs included
+     return {
+         "solution_approach": llm_response,
+         "company_examples": [
+             {
+                 "company": case.company,
+                 "industry": case.industry,
+                 "year": case.year,
+                 "description": case.description,
+                 "summary": case.summary,
+                 "similarity_score": case.similarity_score,
+                 "url": case.url
+             }
+             for case in top_cases
+         ],
+         "recommended_models": {
+             "fine_tuned": recommended_models.get("fine_tuned", []),
+             "general": recommended_models.get("general", [])
+         },
+         "logs": current_logs.copy()  # Include all logs in the response
+     }
+
+ async def get_huggingface_models(query: str, relevant_cases: List = None, api_key: str = None) -> Dict[str, List[Dict[str, Any]]]:
563
+ """Get relevant ML models from HuggingFace based on query and similar use cases"""
564
+ log_to_ui(f"🔍 Analyzing query for ML task mapping: '{query}'")
565
+
566
+ try:
567
+ # Enhanced multi-dimensional classification system
568
+ business_domains = {
569
+ # Financial Services
570
+ "finance": ["fraud detection", "risk assessment", "algorithmic trading", "credit scoring"],
571
+ "banking": ["fraud detection", "credit scoring", "customer segmentation", "loan approval"],
572
+ "fintech": ["payment processing", "robo-advisor", "fraud detection", "credit scoring"],
573
+ "insurance": ["risk assessment", "claim processing", "fraud detection", "pricing optimization"],
574
+
575
+ # E-commerce & Retail
576
+ "ecommerce": ["recommendation systems", "demand forecasting", "price optimization", "customer segmentation"],
577
+ "retail": ["inventory management", "demand forecasting", "customer analytics", "supply chain"],
578
+ "marketplace": ["search ranking", "recommendation systems", "fraud detection", "seller analytics"],
579
+
580
+ # Healthcare & Life Sciences
581
+ "healthcare": ["medical imaging", "drug discovery", "patient risk prediction", "clinical decision support"],
582
+ "medical": ["diagnostic imaging", "treatment optimization", "patient monitoring", "clinical trials"],
583
+ "pharma": ["drug discovery", "clinical trials", "adverse event detection", "molecular analysis"],
584
+
585
+ # Technology & Media
586
+ "tech": ["user behavior analysis", "system optimization", "content moderation", "search ranking"],
587
+ "media": ["content recommendation", "audience analytics", "content generation", "sentiment analysis"],
+ "gaming": ["player behavior prediction", "game optimization", "content generation", "cheat detection"],
+
+ # Marketing & Advertising
+ "marketing": ["customer segmentation", "campaign optimization", "lead scoring", "attribution modeling"],
+ "advertising": ["ad targeting", "bid optimization", "creative optimization", "audience analytics"],
+ "social": ["sentiment analysis", "trend prediction", "content moderation", "influence analysis"]
+ }
+
+ problem_types = {
+ # Customer Analytics
+ "churn": {
+ "domain": "customer_analytics",
+ "task_type": "binary_classification",
+ "data_types": ["tabular", "behavioral"],
+ "complexity": "intermediate",
+ "models": ["xgboost", "lightgbm", "catboost", "random_forest"],
+ "hf_tasks": ["tabular-classification"],
+ "keywords": ["retention", "attrition", "leave", "cancel", "subscription"]
+ },
+ "segmentation": {
+ "domain": "customer_analytics",
+ "task_type": "clustering",
+ "data_types": ["tabular", "behavioral"],
+ "complexity": "intermediate",
+ "models": ["kmeans", "dbscan", "hierarchical", "gaussian_mixture"],
+ "hf_tasks": ["tabular-classification"],
+ "keywords": ["segment", "group", "persona", "cluster", "behavior"]
+ },
+
+ # Risk & Fraud
+ "fraud": {
+ "domain": "risk_management",
+ "task_type": "anomaly_detection",
+ "data_types": ["tabular", "graph", "time_series"],
+ "complexity": "advanced",
+ "models": ["isolation_forest", "autoencoder", "one_class_svm", "gnn"],
+ "hf_tasks": ["tabular-classification"],
+ "keywords": ["suspicious", "anomaly", "unusual", "scam", "fake"]
+ },
+ "risk": {
+ "domain": "risk_management",
+ "task_type": "regression",
+ "data_types": ["tabular", "time_series"],
+ "complexity": "advanced",
+ "models": ["ensemble", "deep_learning", "survival_analysis"],
+ "hf_tasks": ["tabular-regression"],
+ "keywords": ["probability", "likelihood", "exposure", "default", "loss"]
+ },
+
+ # Demand & Forecasting
+ "forecast": {
+ "domain": "demand_planning",
+ "task_type": "time_series_forecasting",
+ "data_types": ["time_series", "tabular"],
+ "complexity": "advanced",
+ "models": ["prophet", "lstm", "transformer", "arima"],
+ "hf_tasks": ["time-series-forecasting"],
+ "keywords": ["predict", "future", "trend", "seasonal", "demand", "sales"]
+ },
+ "demand": {
+ "domain": "demand_planning",
+ "task_type": "regression",
+ "data_types": ["time_series", "tabular"],
+ "complexity": "intermediate",
+ "models": ["xgboost", "lstm", "prophet"],
+ "hf_tasks": ["tabular-regression", "time-series-forecasting"],
+ "keywords": ["inventory", "supply", "planning", "optimization"]
+ },
+
+ # Content & NLP
+ "sentiment": {
+ "domain": "nlp",
+ "task_type": "text_classification",
+ "data_types": ["text"],
+ "complexity": "beginner",
+ "models": ["bert", "roberta", "distilbert"],
+ "hf_tasks": ["text-classification"],
+ "keywords": ["opinion", "emotion", "feeling", "review", "feedback"]
+ },
+ "recommendation": {
+ "domain": "personalization",
+ "task_type": "recommendation",
+ "data_types": ["tabular", "behavioral", "content"],
+ "complexity": "advanced",
+ "models": ["collaborative_filtering", "content_based", "deep_learning"],
+ "hf_tasks": ["tabular-regression"],
+ "keywords": ["suggest", "personalize", "similar", "like", "prefer"]
+ },
+
+ # Pricing & Optimization
+ "pricing": {
+ "domain": "revenue_optimization",
+ "task_type": "regression",
+ "data_types": ["tabular", "time_series"],
+ "complexity": "advanced",
+ "models": ["ensemble", "reinforcement_learning", "optimization"],
+ "hf_tasks": ["tabular-regression"],
+ "keywords": ["price", "cost", "revenue", "profit", "optimize"]
+ }
+ }
+
+ # Advanced query analysis
+ def analyze_query_intent(query_text, cases=None):
+ """Analyze query to extract business domain, problem type, and complexity"""
+ query_lower = query_text.lower()
+
+ # Extract business domain
+ detected_domain = None
+ domain_confidence = 0
+ for domain, use_cases in business_domains.items():
+ if domain in query_lower:
+ detected_domain = domain
+ domain_confidence = 0.9
+ break
+ # Check use case matches
+ for use_case in use_cases:
+ if use_case.lower() in query_lower:
+ detected_domain = domain
+ domain_confidence = 0.7
+ break
+ if detected_domain:
+ break
+
+ # Extract problem type with scoring
+ problem_matches = []
+ for problem_name, problem_info in problem_types.items():
+ score = 0
+
+ # Direct problem name match
+ if problem_name in query_lower:
+ score += 50
+
+ # Keyword matches
+ for keyword in problem_info["keywords"]:
+ if keyword in query_lower:
+ score += 10
+
+ # Context from relevant cases
+ if cases:
+ case_text = " ".join([f"{case.description} {case.summary[:300]}" for case in cases]).lower()
+ if problem_name in case_text:
+ score += 20
+ for keyword in problem_info["keywords"]:
+ if keyword in case_text:
+ score += 5
+
+ if score > 0:
+ problem_matches.append((problem_name, problem_info, score))
+
+ # Sort by score and get best matches
+ problem_matches.sort(key=lambda x: x[2], reverse=True)
+
+ return detected_domain, problem_matches[:3], domain_confidence
+
+ # Analyze the query
+ query_lower = query.lower()
+ detected_domain, top_problems, domain_confidence = analyze_query_intent(query, relevant_cases)
+
+ # Determine primary task and approach
+ if top_problems:
+ primary_problem = top_problems[0]
+ problem_info = primary_problem[1]
+ primary_task = problem_info["hf_tasks"][0] if problem_info["hf_tasks"] else "tabular-classification"
+ complexity = problem_info["complexity"]
+ preferred_models = problem_info["models"]
+
+ log_to_ui(f"🎯 Detected problem: '{primary_problem[0]}' (score: {primary_problem[2]})")
+ log_to_ui(f"📊 Domain: {detected_domain or 'general'} | Complexity: {complexity}")
+ log_to_ui(f"🔧 Preferred models: {', '.join(preferred_models[:3])}")
+ else:
+ # Fallback to simple keyword matching
+ primary_task = "tabular-classification"
+ complexity = "intermediate"
+ preferred_models = ["xgboost", "lightgbm"]
+ log_to_ui(f"📊 Using fallback classification | Task: {primary_task}")
+
+ matched_keywords = [p[0] for p in top_problems]
+
+ log_to_ui(f"📊 Primary task: '{primary_task}' | Keywords: {matched_keywords}")
+
+ # Search for models with multiple strategies
+ all_models = []
+
+ # Strategy 1: Search by primary task
+ models_primary = await search_models_by_task(primary_task, api_key)
+ all_models.extend(models_primary)
+
+ # Strategy 2: Search by specific keywords for better matches
+ if matched_keywords:
+ for keyword in matched_keywords[:2]: # Top 2 keywords
+ keyword_models = await search_models_by_keyword(keyword, api_key)
+ all_models.extend(keyword_models)
+
+ # Strategy 3: Search for domain-specific models
+ domain_searches = []
+ if "churn" in query_lower or "retention" in query_lower:
+ domain_searches.append("customer-analytics")
+ if "fraud" in query_lower:
+ domain_searches.append("anomaly-detection")
+ if "recommend" in query_lower:
+ domain_searches.append("recommendation")
+
+ for domain in domain_searches:
+ domain_models = await search_models_by_keyword(domain, api_key)
+ all_models.extend(domain_models)
+
+ # Remove duplicates and rank by relevance
+ seen_models = set()
+ unique_models = []
+
+ for model in all_models:
+ model_id = model.get("id") or model.get("name")
+ if model_id and model_id not in seen_models:
+ seen_models.add(model_id)
+ unique_models.append(model)
+
+ # Score models based on enhanced relevance criteria
+ scored_models = []
+ for model in unique_models:
+ score = calculate_model_relevance(
+ model, query_lower, matched_keywords,
+ complexity, preferred_models if 'preferred_models' in locals() else None
+ )
+ scored_models.append((model, score))
+
+ # Separate models into fine-tuned/specific vs general base models
+ fine_tuned_models = []
+ general_models = []
+
+ for model, score in scored_models:
+ if is_fine_tuned_model(model, matched_keywords):
+ fine_tuned_models.append((model, score))
+ elif is_general_suitable_model(model, primary_task):
+ general_models.append((model, score))
+
+ # Sort and take top 3 of each type
+ fine_tuned_models.sort(key=lambda x: x[1], reverse=True)
+ general_models.sort(key=lambda x: x[1], reverse=True)
+
+ top_fine_tuned = [model for model, score in fine_tuned_models[:3]]
+ top_general = [model for model, score in general_models[:3]]
+
+ # Add curated high-quality models for specific use cases
+ def get_curated_models(problem_type: str, complexity_level: str) -> List[Dict]:
+ """Get curated high-quality models for specific use cases"""
+ curated = {
+ "churn": {
+ "beginner": [
+ {"id": "scikit-learn/RandomForestClassifier", "task": "tabular-classification"},
+ {"id": "xgboost/XGBClassifier", "task": "tabular-classification"}
+ ],
+ "intermediate": [
+ {"id": "microsoft/TabNet", "task": "tabular-classification"},
+ {"id": "AutoML/AutoGluon-Tabular", "task": "tabular-classification"}
+ ],
+ "advanced": [
+ {"id": "microsoft/LightGBM", "task": "tabular-classification"},
+ {"id": "dmlc/xgboost", "task": "tabular-classification"}
+ ]
+ },
+ "sentiment": {
+ "beginner": [
+ {"id": "cardiffnlp/twitter-roberta-base-sentiment-latest", "task": "text-classification"},
+ {"id": "distilbert-base-uncased-finetuned-sst-2-english", "task": "text-classification"}
+ ],
+ "intermediate": [
+ {"id": "nlptown/bert-base-multilingual-uncased-sentiment", "task": "text-classification"},
+ {"id": "microsoft/DialoGPT-medium", "task": "text-classification"}
+ ],
+ "advanced": [
+ {"id": "roberta-large-mnli", "task": "text-classification"},
+ {"id": "microsoft/deberta-v3-large", "task": "text-classification"}
+ ]
+ },
+ "fraud": {
+ "intermediate": [
+ {"id": "microsoft/TabNet", "task": "tabular-classification"},
+ {"id": "IsolationForest/AnomalyDetection", "task": "tabular-classification"}
+ ],
+ "advanced": [
+ {"id": "pyod/AutoEncoder", "task": "tabular-classification"},
+ {"id": "GraphNeuralNetworks/FraudDetection", "task": "tabular-classification"}
+ ]
+ },
+ "forecast": {
+ "intermediate": [
+ {"id": "facebook/prophet", "task": "time-series-forecasting"},
+ {"id": "statsmodels/ARIMA", "task": "time-series-forecasting"}
+ ],
+ "advanced": [
+ {"id": "microsoft/DeepAR", "task": "time-series-forecasting"},
+ {"id": "google/temporal-fusion-transformer", "task": "time-series-forecasting"}
+ ]
+ }
+ }
+
+ # Get curated models for the specific problem and complexity
+ if problem_type in curated and complexity_level in curated[problem_type]:
+ return curated[problem_type][complexity_level]
+ elif problem_type in curated:
+ # Fallback to any complexity level available
+ for level in ["beginner", "intermediate", "advanced"]:
+ if level in curated[problem_type]:
+ return curated[problem_type][level]
+
+ return []
+
+ # Add curated models
+ if top_problems:
+ primary_problem_name = top_problems[0][0]
+ curated_models = get_curated_models(primary_problem_name, complexity)
+ for curated_model in curated_models:
+ if len(top_general) < 3:
+ # Format as HuggingFace model dict
+ formatted_model = {
+ "id": curated_model["id"],
+ "pipeline_tag": curated_model["task"],
+ "downloads": 50000, # Reasonable default
+ "tags": ["curated", "production-ready"]
+ }
+ top_general.append(formatted_model)
+
+ # Add general foundation models if we still don't have enough
+ if len(top_general) < 3:
+ foundation_models = await get_foundation_models(primary_task, matched_keywords, api_key)
+ top_general.extend(foundation_models[:3-len(top_general)])
+
+ # Format response with categories
+ model_response = {
+ "fine_tuned": [],
+ "general": []
+ }
+
+ # Enhanced model descriptions based on detected problem type
+ def get_enhanced_model_description(model: Dict, model_type: str, problem_context: str = None) -> str:
+ """Generate context-aware model descriptions"""
+ model_name = model.get("id", "").lower()
+
+ if model_type == "fine-tuned":
+ if problem_context == "churn":
+ return "Pre-trained model optimized for customer retention prediction"
+ elif problem_context == "fraud":
+ return "Specialized anomaly detection model for fraud identification"
+ elif problem_context == "sentiment":
+ return "Fine-tuned sentiment analysis model for text classification"
+ elif problem_context == "forecast":
+ return "Time series forecasting model for demand prediction"
+ else:
+ return f"Specialized model fine-tuned for {get_model_specialty(model, matched_keywords)}"
+ else: # general
+ if "curated" in str(model.get("tags", [])):
+ return "Production-ready model recommended for business use cases"
+ elif any(term in model_name for term in ["bert", "roberta", "distilbert"]):
+ return "Transformer-based foundation model for fine-tuning"
+ elif any(term in model_name for term in ["xgboost", "lightgbm", "catboost"]):
+ return "Gradient boosting model excellent for tabular data"
+ elif "prophet" in model_name:
+ return "Facebook's time series forecasting framework"
+ else:
+ return f"Foundation model suitable for {primary_task.replace('-', ' ')}"
+
+ # Format fine-tuned models with enhanced descriptions
+ primary_problem_name = top_problems[0][0] if top_problems else None
+
+ for model in top_fine_tuned:
+ model_info = {
+ "name": model.get("id", model.get("name", "Unknown")),
+ "downloads": model.get("downloads", 0),
+ "task": model.get("pipeline_tag", primary_task),
+ "url": f"https://huggingface.co/{model.get('id', '')}",
+ "type": "fine-tuned",
+ "description": get_enhanced_model_description(model, "fine-tuned", primary_problem_name)
+ }
+ model_response["fine_tuned"].append(model_info)
+
+ # Format general models with enhanced descriptions
+ for model in top_general:
+ model_info = {
+ "name": model.get("id", model.get("name", "Unknown")),
+ "downloads": model.get("downloads", 0),
+ "task": model.get("pipeline_tag", primary_task),
+ "url": f"https://huggingface.co/{model.get('id', '')}",
+ "type": "general",
+ "description": get_enhanced_model_description(model, "general", primary_problem_name)
+ }
+ model_response["general"].append(model_info)
+
+ total_models = len(model_response["fine_tuned"]) + len(model_response["general"])
+ log_to_ui(f"📦 Found {len(model_response['fine_tuned'])} fine-tuned + {len(model_response['general'])} general models")
+
+ # Log details
+ if model_response["fine_tuned"]:
+ log_to_ui("🎯 Fine-tuned/Specialized models:")
+ for i, model in enumerate(model_response["fine_tuned"], 1):
+ log_to_ui(f" {i}. {model['name']} - {model['downloads']:,} downloads")
+
+ if model_response["general"]:
+ log_to_ui("🔧 General/Foundation models:")
+ for i, model in enumerate(model_response["general"], 1):
+ log_to_ui(f" {i}. {model['name']} - {model['downloads']:,} downloads")
+
+ return model_response
+
+ except Exception as e:
+ log_to_ui(f"❌ Error fetching HuggingFace models: {e}")
+ return {"fine_tuned": [], "general": []}
+
+ async def search_models_by_task(task: str, api_key: str = None) -> List[Dict]:
+ """Search models by specific task"""
+ try:
+ headers = {}
+ if api_key:
+ headers["Authorization"] = f"Bearer {api_key}"
+
+ response = requests.get(
+ f"https://huggingface.co/api/models?pipeline_tag={task}&sort=downloads&limit=10",
+ headers=headers,
+ timeout=10
+ )
+ if response.status_code == 200:
+ return response.json()
+ except Exception:
+ pass
+ return []
+
+ async def search_models_by_keyword(keyword: str, api_key: str = None) -> List[Dict]:
+ """Search models by keyword in name/description"""
+ try:
+ headers = {}
+ if api_key:
+ headers["Authorization"] = f"Bearer {api_key}"
+
+ response = requests.get(
+ f"https://huggingface.co/api/models?search={keyword}&sort=downloads&limit=10",
+ headers=headers,
+ timeout=10
+ )
+ if response.status_code == 200:
+ return response.json()
+ except Exception:
+ pass
+ return []
+
+ def calculate_model_relevance(model: Dict, query: str, keywords: List[str],
+ complexity: str = "intermediate", preferred_models: List[str] = None) -> float:
+ """Enhanced multi-criteria model relevance scoring"""
+ score = 0
+ model_name = model.get("id", "").lower()
+ model_task = model.get("pipeline_tag", "").lower()
+ downloads = model.get("downloads", 0)
+
+ # 1. Base popularity score (0-15 points)
+ if downloads > 10000000: # 10M+
+ score += 15
+ elif downloads > 1000000: # 1M+
+ score += 12
+ elif downloads > 100000: # 100K+
+ score += 8
+ elif downloads > 10000: # 10K+
+ score += 5
+ elif downloads > 1000: # 1K+
+ score += 2
+
+ # 2. Direct keyword relevance (0-30 points)
+ for keyword in keywords:
+ if keyword in model_name:
+ score += 25
+ # Check in model description/tags if available
+ model_tags = model.get("tags", [])
+ if any(keyword in str(tag).lower() for tag in model_tags):
+ score += 15
+
+ # 3. Query term matches (0-20 points)
+ query_words = [w for w in query.lower().split() if len(w) > 3]
+ for word in query_words:
+ if word in model_name:
+ score += 8
+ if word in str(model.get("tags", [])).lower():
+ score += 5
+
+ # 4. Preferred model architecture bonus (0-25 points)
+ if preferred_models:
+ for preferred in preferred_models:
+ if preferred.lower() in model_name:
+ score += 20
+ break
+ # Partial matches
+ for preferred in preferred_models:
+ if any(part in model_name for part in preferred.lower().split('_')):
+ score += 10
+ break
+
+ # 5. Task alignment (0-20 points)
+ relevant_tasks = ["tabular-classification", "tabular-regression", "text-classification",
+ "time-series-forecasting", "question-answering"]
+ if model_task in relevant_tasks:
+ score += 15
+
+ # 6. Complexity alignment (0-15 points)
+ complexity_indicators = {
+ "beginner": ["base", "simple", "basic", "distil", "small", "mini"],
+ "intermediate": ["medium", "standard", "v2", "improved"],
+ "advanced": ["large", "xl", "xxl", "advanced", "complex", "ensemble"]
+ }
+
+ if complexity in complexity_indicators:
+ for indicator in complexity_indicators[complexity]:
+ if indicator in model_name:
+ score += 12
+ break
+
+ # 7. Production readiness indicators (0-10 points)
+ production_terms = ["production", "optimized", "efficient", "fast", "deployment"]
+ for term in production_terms:
+ if term in model_name:
+ score += 8
+ break
+
+ # 8. Penalties for problematic models (negative points)
+ penalty_terms = ["nsfw", "adult", "sexual", "violence", "toxic", "unsafe", "experimental"]
+ for term in penalty_terms:
+ if term in model_name:
+ score -= 30
+
+ # Generic model penalty
+ generic_terms = ["general", "random", "test", "example", "demo"]
+ for term in generic_terms:
+ if term in model_name:
+ score -= 10
+
+ # 9. Model quality indicators (0-10 points)
+ quality_terms = ["sota", "benchmark", "award", "winner", "best", "top"]
+ for term in quality_terms:
+ if term in model_name or term in str(model.get("tags", [])).lower():
+ score += 8
+ break
+
+ # 10. Recency bonus (0-5 points) - prefer newer models
+ # This would require model creation date, approximating with model name patterns
+ recent_indicators = ["2024", "2023", "v3", "v4", "v5", "latest", "new"]
+ for indicator in recent_indicators:
+ if indicator in model_name:
+ score += 3
+ break
+
+ return max(score, 0)
+
+ def is_fine_tuned_model(model: Dict, keywords: List[str]) -> bool:
+ """Determine if a model is fine-tuned/specialized for the specific task"""
+ model_name = model.get("id", "").lower()
+
+ # Models with specific task keywords in name are likely fine-tuned
+ for keyword in keywords:
+ if keyword in model_name:
+ return True
+
+ # Models with specific fine-tuning indicators
+ fine_tuned_indicators = [
+ "fine-tuned", "ft", "finetuned", "specialized", "custom",
+ "churn", "fraud", "sentiment", "classification", "detection",
+ "prediction", "analytics", "recommendation", "recommender"
+ ]
+
+ for indicator in fine_tuned_indicators:
+ if indicator in model_name:
+ return True
+
+ # Models from specific companies/domains (often specialized)
+ domain_indicators = ["customer", "business", "financial", "ecommerce", "retail"]
+ for domain in domain_indicators:
+ if domain in model_name:
+ return True
+
+ return False
+
+ def is_general_suitable_model(model: Dict, primary_task: str) -> bool:
+ """Determine if a model is a general foundation model suitable for the task"""
+ model_name = model.get("id", "").lower()
+ model_task = model.get("pipeline_tag", "").lower()
+
+ # Check if model task matches primary task
+ if model_task == primary_task:
+ return True
+
+ # General foundation models (base models good for fine-tuning)
+ foundation_indicators = [
+ "base", "large", "xlarge", "bert", "roberta", "distilbert",
+ "electra", "albert", "transformer", "gpt", "t5", "bart",
+ "deberta", "xlnet", "longformer"
+ ]
+
+ for indicator in foundation_indicators:
+ if indicator in model_name and not any(x in model_name for x in ["nsfw", "safety", "toxicity"]):
+ return True
+
+ return False
+
+ async def get_foundation_models(primary_task: str, keywords: List[str], api_key: str = None) -> List[Dict]:
+ """Get well-known foundation models suitable for the task"""
+ foundation_searches = []
+
+ if primary_task in ["text-classification", "token-classification"]:
+ foundation_searches = [
+ "bert-base-uncased", "roberta-base", "distilbert-base-uncased",
+ "microsoft/deberta-v3-base", "google/electra-base-discriminator"
+ ]
+ elif primary_task in ["tabular-classification", "tabular-regression"]:
+ foundation_searches = [
+ "scikit-learn", "xgboost", "lightgbm", "catboost", "pytorch-tabular"
+ ]
+ elif primary_task in ["text-generation", "conversational"]:
+ foundation_searches = [
+ "gpt2", "microsoft/DialoGPT-medium", "facebook/blenderbot"
+ ]
+ elif primary_task in ["question-answering"]:
+ foundation_searches = [
+ "bert-base-uncased", "distilbert-base-uncased", "roberta-base"
+ ]
+
+ models = []
+ for search_term in foundation_searches[:3]: # Top 3 foundation searches
+ try:
+ headers = {}
+ if api_key:
+ headers["Authorization"] = f"Bearer {api_key}"
+
+ response = requests.get(
+ f"https://huggingface.co/api/models?search={search_term}&sort=downloads&limit=3",
+ headers=headers,
+ timeout=10
+ )
+ if response.status_code == 200:
+ models.extend(response.json())
+ except Exception:
+ continue
+
+ return models[:3] # Return top 3
+
+ def get_model_specialty(model: Dict, keywords: List[str]) -> str:
+ """Get human-readable specialty description for a model"""
+ model_name = model.get("id", "").lower()
+
+ # Map keywords to descriptions
+ specialty_map = {
+ "churn": "customer churn prediction",
+ "fraud": "fraud detection",
+ "sentiment": "sentiment analysis",
+ "recommendation": "recommendation systems",
+ "classification": "classification tasks",
+ "detection": "detection tasks",
+ "prediction": "prediction tasks"
+ }
+
+ for keyword in keywords:
+ if keyword in specialty_map:
+ return specialty_map[keyword]
+
+ # Fallback: try to infer from model name
+ if "churn" in model_name:
+ return "customer churn prediction"
+ elif "fraud" in model_name:
+ return "fraud detection"
+ elif "sentiment" in model_name:
+ return "sentiment analysis"
+ elif "recommend" in model_name:
+ return "recommendation systems"
+ else:
+ return "specialized ML tasks"
+
+ # Serve static files
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ if __name__ == "__main__":
+ import uvicorn
+ uvicorn.run(app, host="0.0.0.0", port=7860) # HF Spaces uses port 7860
chroma_db_complete.tar.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e9ae444eee4049218ca44a3490cfaa8d3d5c80d2453cf24904bf8cf8ec0bacf
+ size 2617874
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ fastapi==0.104.1
+ uvicorn==0.24.0
+ chromadb>=1.0.15
+ sentence-transformers==2.7.0
+ transformers==4.55.4
+ torch==2.1.2
+ huggingface-hub>=0.34.0
+ pandas==2.1.4
+ requests==2.31.0
+ python-multipart==0.0.6
+ jinja2==3.1.2