| from dotenv import load_dotenv | |
| import os | |
| from sentence_transformers import SentenceTransformer | |
| import gradio as gr | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from groq import Groq | |
# Let python-dotenv populate the process environment before the key is read.
load_dotenv()

# Groq API key used by generate_output(); None when "groq_api_key" is unset.
api = os.environ.get("groq_api_key")
# Natural-language descriptions of the three tables. These texts are what gets
# embedded for table selection and what is later sent to the LLM as metadata.
_STUDENT_METADATA = """
Table: student
Columns:
- student_id: an integer representing the unique ID of a student.
- first_name: a string containing the first name of the student.
- last_name: a string containing the last name of the student.
- date_of_birth: a date representing the student's birthdate.
- email: a string for the student's email address.
- phone_number: a string for the student's contact number.
- major: a string representing the student's major field of study.
- year_of_enrollment: an integer for the year the student enrolled.
"""

_EMPLOYEE_METADATA = """
Table: employee
Columns:
- employee_id: an integer representing the unique ID of an employee.
- first_name: a string containing the first name of the employee.
- last_name: a string containing the last name of the employee.
- email: a string for the employee's email address.
- department: a string for the department the employee works in.
- position: a string representing the employee's job title.
- salary: a float representing the employee's salary.
- date_of_joining: a date for when the employee joined the college.
"""

_COURSE_METADATA = """
Table: course_info
Columns:
- course_id: an integer representing the unique ID of the course.
- course_name: a string containing the course's name.
- course_code: a string for the course's unique code.
- instructor_id: an integer for the ID of the instructor teaching the course.
- department: a string for the department offering the course.
- credits: an integer representing the course credits.
- semester: a string for the semester when the course is offered.
"""

# Filled on the first call to create_metadata_embeddings() and reused afterwards.
_METADATA_CACHE = None


def create_metadata_embeddings():
    """Return the table-description embeddings together with their inputs.

    The sentence-transformer model is instantiated and the three table
    descriptions are encoded only on the first call; the result is kept in a
    module-level cache because neither the model nor the descriptions change
    between requests. Every later call returns the same cached tuple.

    Returns:
        A 5-tuple ``(embeddings, model, student, employee, course)`` where
        ``embeddings`` is the result of encoding the descriptions in the order
        student, employee, course; ``model`` is the ``SentenceTransformer``
        instance that produced them; and the remaining three items are the
        description strings themselves.
    """
    global _METADATA_CACHE
    if _METADATA_CACHE is None:
        model = SentenceTransformer('all-MiniLM-L6-v2')
        # The order here must match the index checks in find_best_fit().
        embeddings = model.encode(
            [_STUDENT_METADATA, _EMPLOYEE_METADATA, _COURSE_METADATA]
        )
        _METADATA_CACHE = (
            embeddings,
            model,
            _STUDENT_METADATA,
            _EMPLOYEE_METADATA,
            _COURSE_METADATA,
        )
    return _METADATA_CACHE
def find_best_fit(embeddings, model, user_query, student, employee, course):
    """Pick the table description whose embedding best matches the query.

    Args:
        embeddings: Embeddings of the table descriptions, as produced by
            ``model.encode`` (index 0 is treated as student, 1 as employee).
        model: Object exposing ``encode``; used to embed ``user_query``.
        user_query: The natural-language request to match.
        student: Description returned when index 0 scores highest.
        employee: Description returned when index 1 scores highest.
        course: Description returned for any other winning index.

    Returns:
        One of ``student``, ``employee`` or ``course``.
    """
    query_vector = model.encode([user_query])
    scores = cosine_similarity(query_vector, embeddings)
    # Only one query row is scored, so the flat argmax is the winning table index.
    winner = scores.argmax()
    if winner == 0:
        return student
    if winner == 1:
        return employee
    return course
# Fixed instructions sent as the system message of every request.
_SYSTEM_PROMPT = """
You are a SQL query generator specialized in generating SQL queries for a single table at a time. Your task is to accurately convert natural language queries into SQL statements based on the user's intent and the provided table metadata.
Rules:
Single Table Only: Assume all queries are related to a single table provided in the metadata. Ignore any references to other tables.
Metadata-Based Validation: Always ensure the generated query matches the table name, columns, and data types provided in the metadata.
User Intent: Accurately capture the user's requirements, such as filters, sorting, or aggregations, as expressed in natural language.
SQL Syntax: Use standard SQL syntax that is compatible with most relational database systems.
Input Format:
User Query: The user's natural language request.
Table Metadata: The structure of the relevant table, including the table name, column names, and data types.
Output Format:
SQL Query: A valid SQL query formatted for readability.
Do not output anything else except the SQL query.Not even a single word extra.Ouput the whole query in a single line only.
You are ready to generate SQL queries based on the user input and table metadata.
"""

# Per-request user message; the two named fields are filled by create_prompt().
_USER_PROMPT_TEMPLATE = """
User Query: {user_query}
Table Metadata: {table_metadata}
"""


def create_prompt(user_query, table_metadata):
    """Build the system and user prompts for the SQL-generation request.

    Args:
        user_query: The user's natural-language request; inserted verbatim.
        table_metadata: Description of the selected table; inserted verbatim.

    Returns:
        A ``(system_prompt, user_prompt)`` tuple of strings.
    """
    # str.format only interprets braces in the template, never in the values,
    # so queries that contain "{" or "}" pass through unchanged.
    filled_user_prompt = _USER_PROMPT_TEMPLATE.format(
        user_query=user_query,
        table_metadata=table_metadata,
    )
    return _SYSTEM_PROMPT, filled_user_prompt
def generate_output(system_prompt, user_prompt):
    """Ask the Groq chat model for a SQL query and validate the reply.

    Args:
        system_prompt: Instructions sent as the ``system`` message.
        user_prompt: The user's request and table metadata, sent as the
            ``user`` message.

    Returns:
        The model's reply with surrounding whitespace removed when it starts
        with ``select`` (case-insensitive); otherwise the fixed message
        ``"Can't perform the task at the moment."``.

    Exceptions raised by the Groq client (for example a missing API key or a
    failed request) are not caught here and propagate to the caller.
    """
    client = Groq(api_key=api)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        model="llama3-70b-8192",
    )
    # Treat a missing reply as empty text, and drop leading/trailing
    # whitespace so a reply such as "\nSELECT ..." is still recognised.
    reply = (chat_completion.choices[0].message.content or "").strip()
    # Only statements that begin with SELECT are passed through to the user.
    if reply.lower().startswith("select"):
        return reply
    return "Can't perform the task at the moment."
def response(user_query):
    """Turn a natural-language request into the text shown in the UI.

    Runs the full pipeline: obtain the table embeddings, select the table
    description closest to ``user_query``, build the prompts, and return
    whatever ``generate_output`` produces for them.

    Args:
        user_query: The natural-language request entered by the user.

    Returns:
        The value returned by ``generate_output``.
    """
    vectors, encoder, student_doc, employee_doc, course_doc = (
        create_metadata_embeddings()
    )
    chosen_metadata = find_best_fit(
        vectors, encoder, user_query, student_doc, employee_doc, course_doc
    )
    system_text, user_text = create_prompt(user_query, chosen_metadata)
    return generate_output(system_text, user_text)
# Text shown under the title so users know which tables can be queried.
desc = """
There are three tables in the database:
Student Table:
The table contains the student's unique ID, first name, last name, date of birth, email address, phone number, major field of study, and year of enrollment.
Employee Table:
The table includes the employee's unique ID, first name, last name, email address, department, job position, salary, and date of joining.
Course Info Table:
The table holds information about the course's unique ID, name, course code, instructor ID, department offering the course, number of credits, and the semester in which the course is offered.
"""

# Single text-in / text-out interface wired to the response() pipeline.
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query"),
    title="SQL Query generator",
    description=desc,
)

# `share` is a boolean flag. The previous string "True" only worked because
# any non-empty string is truthy (so "False" would have enabled sharing too).
demo.launch(share=True)