Spaces:
Running
Running
| import os | |
| import json | |
| import duckdb | |
| import gradio as gr | |
| import pandas as pd | |
| import pandera as pa | |
| from pandera import Column | |
| import ydata_profiling as pp | |
| from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace | |
| from langsmith import traceable | |
| from langchain import hub | |
| import warnings | |
| import dlt | |
| warnings.filterwarnings("ignore", category=DeprecationWarning) | |
# Height of the Tabs Text Area
TAB_LINES = 8

#----------CONNECT TO DATABASE----------
# The MotherDuck token is read from the environment so it never lives in the source.
md_token = os.getenv('MD_TOKEN')
if not md_token:
    # Without this guard the f-string below would embed the literal text "None"
    # as the token, and the problem would only surface as a connection failure.
    raise RuntimeError("MD_TOKEN environment variable is not set; cannot connect to MotherDuck.")

# Opened read-only: the queries issued through `conn` in this file are all SELECTs.
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
#---------------------------------------
#-------LOAD HUGGINGFACE-------
# Candidate models, tried in order; the first one whose endpoint answers is used.
models = ["Qwen/Qwen2.5-72B-Instruct", "meta-llama/Meta-Llama-3-70B-Instruct",
          "meta-llama/Llama-3.1-70B-Instruct"]
model_loaded = False
endpoint = None
for model in models:
    try:
        endpoint = HuggingFaceEndpoint(repo_id=model, max_new_tokens=8192)
        # Presumably a liveness probe: the result is not used, only whether the call raises.
        info = endpoint.client.get_endpoint_info()
        model_loaded = True
        break
    except Exception as e:
        print(f"Error for model {model}: {e}")

if not model_loaded:
    # Previously this case fell through: `endpoint` was either unbound (NameError)
    # or still pointed at an endpoint whose probe had just failed.
    raise RuntimeError(f"None of the configured HuggingFace models could be loaded: {models}")

llm = ChatHuggingFace(llm=endpoint).bind(max_tokens=8192)
#---------------------------------------
#-----LOAD PROMPTS FROM LANGCHAIN HUB-----
# Both prompts are pulled once at import time, in the same order as before:
# first the auto-generation prompt, then the user-driven one.
prompt_autogenerate, prompt_user_input = (
    hub.pull(prompt_name)
    for prompt_name in (
        "autogenerate-rules-testworkflow",
        "usergenerate-rules-testworkflow",
    )
)
| #--------------ALL UTILS---------------- | |
# Get Databases
def get_schemas():
    """Return the distinct schema names visible through `conn`.

    The `information_schema` and `pg_catalog` system schemas are filtered out
    by the query itself.
    """
    query = """
    SELECT DISTINCT schema_name
    FROM information_schema.schemata
    WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """
    rows = conn.execute(query).fetchall()
    # Each row is a one-element tuple holding the schema name.
    return [row[0] for row in rows]
# Get Tables
def get_tables_names(schema_name):
    """Return the names of all tables that belong to `schema_name`.

    Args:
        schema_name: Schema to look up in `information_schema.tables`.

    Returns:
        A list of table names (empty when the schema has no tables or is unknown).
    """
    # The schema name arrives from a UI dropdown, so it is bound as a query
    # parameter instead of being spliced into the SQL text.
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
# Update Tables
def update_table_names(schema_name):
    """Build the Gradio update that sets the table dropdown's choices for `schema_name`."""
    return gr.update(choices=get_tables_names(schema_name))
| # def get_data_df(schema): | |
| # print('Getting Dataframe from the Database') | |
| # return conn.sql(f"SELECT * FROM {schema} LIMIT 1000") | |
def fetch_data(schema):
    """Yield the first 1000 rows of `schema` as a sequence of DataFrame chunks.

    Args:
        schema: Table reference interpolated into the FROM clause.

    Yields:
        Non-empty DataFrames; iteration stops at the first empty or missing chunk.
    """
    relation = conn.sql(f"SELECT * FROM {schema} LIMIT 1000")
    chunk = relation.fetch_df_chunk(2)
    while chunk is not None and len(chunk) > 0:
        yield chunk
        chunk = relation.fetch_df_chunk(2)
def create_pipeline(schema):
    """Copy the rows produced by `fetch_data(schema)` into a local DuckDB file via dlt.

    Args:
        schema: Dot-separated path; the second part is used as the dataset
            name and the third as the table name.

    Returns:
        The local table reference as "<dataset>.<table>".
    """
    path_parts = schema.split('.')
    dataset_name = path_parts[1]
    print("Dataset Name: ", dataset_name)
    table_name = path_parts[2]
    print("Table Name: ", table_name)

    pipeline = dlt.pipeline(
        pipeline_name='duckdb_pipeline',
        destination='duckdb',
        dataset_name=dataset_name,
    )
    # "replace" overwrites whatever a previous run loaded into this table.
    load_info = pipeline.run(
        fetch_data(schema),
        table_name=table_name,
        write_disposition="replace",
    )
    print(load_info)
    return f"{dataset_name}.{table_name}"
def load_pipeline(table_name):
    """Open the local pipeline database and read up to 1000 rows of `table_name`.

    Returns:
        A `(connection, dataframe)` tuple; the caller is responsible for
        closing the returned connection.
    """
    _conn = duckdb.connect("duckdb_pipeline.duckdb")
    sample_df = _conn.sql(f"SELECT * FROM {table_name} LIMIT 1000").df()
    return _conn, sample_df
def df_summary(df):
    """Summarise each numeric or text-like column of `df` in one row.

    Numeric columns get `max`/`min`; categorical, object and string columns
    get `top` (the first mode). Columns of any other dtype are skipped.

    Args:
        df: DataFrame to summarise.

    Returns:
        A DataFrame with the columns `column`, `max`, `min`, `count`,
        `nunique`, `dtype` and `top`; it has those columns even when no row
        is produced.
    """
    summary_columns = ["column", "max", "min", "count", "nunique", "dtype", "top"]
    summary = []
    for column in df.columns:
        series = df[column]
        if pd.api.types.is_numeric_dtype(series):
            highest, lowest, top_value = series.max(), series.min(), None
        elif (
            # isinstance replaces the deprecated pd.api.types.is_categorical_dtype.
            isinstance(series.dtype, pd.CategoricalDtype)
            or pd.api.types.is_object_dtype(series)
            # Dedicated pandas string dtypes are text too and were silently dropped before.
            or pd.api.types.is_string_dtype(series)
        ):
            modes = series.mode()
            highest, lowest = None, None
            # mode() is empty for an all-null column.
            top_value = modes.iloc[0] if not modes.empty else None
        else:
            continue
        summary.append({
            "column": column,
            "max": highest,
            "min": lowest,
            "count": series.count(),
            "nunique": series.nunique(),
            "dtype": str(series.dtype),
            "top": top_value,
        })
    return pd.DataFrame(summary, columns=summary_columns).reset_index(drop=True)
def format_prompt(df):
    """Fill the auto-generation prompt with a data sample and the column summary.

    Both the first rows of `df` and the output of `df_summary(df)` are passed
    to the prompt as JSON strings in "records" orientation.
    """
    column_summary = df_summary(df)
    sample_json = df.head().to_json(orient='records')
    summary_json = column_summary.to_json(orient='records')
    return prompt_autogenerate.format_prompt(data=sample_json, summary=summary_json)
def format_user_prompt(df, user_description=None):
    """Fill the user-driven prompt with a data sample and the user's request.

    Args:
        df: DataFrame whose first rows are sent to the prompt as JSON records.
        user_description: Free-text description of the validation the user
            wants. Optional so existing `format_user_prompt(df)` calls keep
            working; `user_results` passes it by keyword.

    Returns:
        The formatted prompt value produced by `prompt_user_input.format_prompt`.
    """
    prompt_inputs = {"data": df.head().to_json(orient='records')}
    if user_description is not None:
        # NOTE(review): assumes the hub prompt names this variable
        # `user_description`, matching the keyword used by the caller — confirm.
        prompt_inputs["user_description"] = user_description
    return prompt_user_input.format_prompt(**prompt_inputs)
def process_inputs(inputs):
    """Return the second message of `inputs['messages']` under the key `input_query`."""
    prompt_messages = inputs['messages'].to_messages()
    return {'input_query': prompt_messages[1]}
def _extract_json_payload(text):
    """Return the JSON text of an LLM reply, unwrapping a Markdown code fence if present."""
    fence = "```"
    start = text.find(fence)
    if start == -1:
        return text.strip()
    body = text[start + len(fence):]
    end = body.rfind(fence)
    if end != -1:
        body = body[:end]
    body = body.lstrip()
    # Drop only the fence's language tag; "json" inside the payload must survive.
    if body[:4].lower() == "json":
        body = body[4:]
    return body.strip()


def run_llm(messages):
    """Invoke the chat model and parse its reply as JSON.

    Args:
        messages: Prompt passed straight to `llm.invoke`.

    Returns:
        The parsed JSON value on success. On any failure (model call or
        parsing) the exception object is *returned*, not raised; callers
        check for it with `isinstance(result, Exception)`.
    """
    try:
        response = llm.invoke(messages)
        # The old blanket `.replace("json", "")` also mangled payload values
        # such as column names or rules that contain the substring "json".
        payload = _extract_json_payload(response.content)
        print(payload)
        tests = json.loads(payload)
    except Exception as e:
        return e
    return tests
# Get Schema
def get_table_schema(table):
    """Return the fully qualified "<database>.<schema>.<table>" path of `table`.

    Args:
        table: Table name as listed by `duckdb_tables()`.

    Returns:
        The dotted path built from the first matching catalog row.

    Raises:
        ValueError: If no table with that name exists (for example when
            nothing is selected in the UI).
    """
    # Bound as a parameter: the table name comes from a UI dropdown.
    row = conn.execute(
        "SELECT database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table {table!r} was not found in the connected database.")
    parent_database, schema_name = row
    return f"{parent_database}.{schema_name}.{table}"
def describe(df):
    """Return `describe()` summaries for the numeric and the categorical columns.

    Args:
        df: DataFrame to describe.

    Returns:
        A `(numerical_info, categorical_info)` tuple. Each element has one row
        per described column, with the column name in a `column` field, or is
        an empty DataFrame when `df` has no column of that kind.
    """
    numerical_info = pd.DataFrame()
    categorical_info = pd.DataFrame()

    numeric_part = df.select_dtypes(include=['number'])
    if len(numeric_part.columns) >= 1:
        numerical_info = numeric_part.describe().T.reset_index()
        numerical_info.rename(columns={'index': 'column'}, inplace=True)

    # `category` columns are included so this agrees with df_summary, which
    # treats categorical and object columns alike.
    categorical_part = df.select_dtypes(include=['object', 'category'])
    if len(categorical_part.columns) >= 1:
        categorical_info = categorical_part.describe().T.reset_index()
        categorical_info.rename(columns={'index': 'column'}, inplace=True)

    return numerical_info, categorical_info
def validate_pandera(tests, df):
    """Run each generated rule against its column and report pass/fail.

    Args:
        tests: Iterable of mappings with a `column_name` and a `pandera_rule`
            (a Python expression, as text, that evaluates to a callable).
        df: DataFrame the rules are applied to.

    Returns:
        A DataFrame with one row per test and the columns `Columns` and
        `Result`. A rule passes when calling it on the single-column frame
        does not raise; any exception is reported as a failure row.
    """
    validation_results = []
    for test in tests:
        column_name = None
        try:
            # Looked up inside the try so one malformed test yields a failure
            # row instead of aborting the whole validation run.
            column_name = test['column_name']
            # SECURITY NOTE(review): the rule text is generated by the LLM and
            # eval() executes it as arbitrary Python. Left in place because the
            # rules are arbitrary expressions; should be sandboxed or replaced
            # by a restricted rule format.
            rule = eval(test['pandera_rule'])
            rule(df[[column_name]])
        except Exception as e:
            outcome = f"❌ Fail - {str(e)}"
        else:
            outcome = "✅ Pass"
        validation_results.append({"Columns": column_name, "Result": outcome})
    # Explicit columns keep the shape stable when `tests` is empty.
    return pd.DataFrame(validation_results, columns=["Columns", "Result"])
def statistics(df):
    """Profile `df` with ydata-profiling and return overview statistics and alerts.

    Args:
        df: DataFrame to profile.

    Returns:
        A `(df_statistics, df_alerts)` tuple. `df_statistics` has the columns
        `Statistic Description` and `Value`; `df_alerts` has the columns
        `Data Quality Issue` and `Category`.
    """
    profile = pp.ProfileReport(df)
    report_dict = profile.get_description()
    description, alerts = report_dict.table, report_dict.alerts

    # Statistics
    mapping = {
        'n': 'Number of observations',
        'n_var': 'Number of variables',
        'n_cells_missing': 'Number of cells missing',
        'n_vars_with_missing': 'Number of columns with missing data',
        'n_vars_all_missing': 'Columns with all missing data',
        'p_cells_missing': 'Missing cells (%)',
        'n_duplicates': 'Duplicated rows',
        'p_duplicates': 'Duplicated rows (%)',
    }
    type_labels = {
        'Text': 'Number of text columns',
        'Categorical': 'Number of categorical columns',
        'Numeric': 'Number of numeric columns',
        'DateTime': 'Number of datetime columns',
    }

    labels, values = [], []
    for key, value in description.items():
        if key == 'types':
            continue
        if key.startswith('p_'):
            # The `p_*` entries are shown under "(%)" labels but were being cast
            # to int, which truncates any fraction below 1 to 0.
            # NOTE(review): assumes ydata-profiling reports them as fractions
            # in [0, 1] — confirm against the installed version.
            value = round(float(value) * 100, 2)
        else:
            value = int(value)
        labels.append(mapping.get(key, key))
        values.append(value)

    # Add flattened types information
    column_types = description.get('types', {})
    for type_name, label in type_labels.items():
        if type_name in column_types:
            labels.append(label)
            values.append(int(column_types[type_name]))

    # dtype=object keeps counts as ints next to the float percentages
    # instead of upcasting the whole column to float.
    df_statistics = pd.DataFrame({
        'Statistic Description': labels,
        'Value': pd.Series(values, dtype=object),
    })

    # Alerts
    alerts_list = [
        (str(alert).replace('[', '').replace(']', ''), alert.alert_type_name)
        for alert in alerts
    ]
    df_alerts = pd.DataFrame(alerts_list, columns=['Data Quality Issue', 'Category'])
    return df_statistics, df_alerts
| #--------------------------------------- | |
# Main Function
def main(table):
    """Profile `table`, generate validation rules with the LLM and run them.

    Args:
        table: Name of the table selected in the UI.

    Returns:
        A 7-tuple in the order the Gradio outputs expect: first 10 rows,
        statistics, alerts, categorical description, numerical description,
        the generated rules (or a one-row error frame) and the validation
        results (empty when rule generation failed).
    """
    schema = get_table_schema(table)
    # Create dlt pipeline
    table_name = create_pipeline(schema)
    # Load dlt pipeline
    connection, df = load_pipeline(table_name)
    # load_pipeline has already materialised the rows into `df`, so the
    # connection is released right away. It used to stay open on the error
    # return below and whenever a later step raised.
    connection.close()

    df_statistics, df_alerts = statistics(df)
    describe_num, describe_cat = describe(df)
    messages = format_prompt(df=df)
    tests = run_llm(messages)
    if isinstance(tests, Exception):
        tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
        return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests, pd.DataFrame([])

    tests_df = pd.DataFrame(tests)
    # The rules are renamed by position; the key names themselves are not checked.
    tests_df.rename(
        columns={
            tests_df.columns[0]: 'Column',
            tests_df.columns[1]: 'Rule Name',
            tests_df.columns[2]: 'Rules',
        },
        inplace=True,
    )
    pandera_results = validate_pandera(tests, df)
    return df.head(10), df_statistics, df_alerts, describe_cat, describe_num, tests_df, pandera_results
def user_results(table, text_query):
    """Generate validation rules from the user's text request and run them.

    Args:
        table: Name of the table selected in the UI.
        text_query: Free-text description of the validation to generate.

    Returns:
        A `(rules, results)` pair of DataFrames. When rule generation fails,
        `rules` is a one-row error frame and `results` is empty.
    """
    schema = get_table_schema(table)
    # Create dlt pipeline
    table_name = create_pipeline(schema)
    # Load dlt pipeline
    connection, df = load_pipeline(table_name)
    # load_pipeline has already materialised the rows into `df`; closing here
    # releases the connection on every path, including the error return.
    connection.close()

    messages = format_user_prompt(df=df, user_description=text_query)
    # This call was missing, so `tests` was read before assignment
    # and the function always failed with UnboundLocalError.
    tests = run_llm(messages)
    print(f'Generated Tests from user input: {tests}')
    if isinstance(tests, Exception):
        tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
        return tests, pd.DataFrame([])

    tests_df = pd.DataFrame(tests)
    # The rules are renamed by position; the key names themselves are not checked.
    tests_df.rename(
        columns={
            tests_df.columns[0]: 'Column',
            tests_df.columns[1]: 'Rule Name',
            tests_df.columns[2]: 'Rules',
        },
        inplace=True,
    )
    pandera_results = validate_pandera(tests, df)
    print('Validated Tests with Pandera')
    return tests_df, pandera_results
# Custom CSS styling
# A stray Python `print(...)` line used to sit at the top of this string. It
# was never executed, and as text it became part of the first selector, which
# invalidated the `.gradio-container` rule.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
def _readonly_table(label):
    """Create an empty, non-editable `gr.DataFrame` carrying the given label."""
    return gr.DataFrame(label=label, value=[], interactive=False)


# NOTE(review): the nesting below is reconstructed from the component order;
# the source this was recovered from had lost its indentation.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
    <div style='text-align: center;'>
    <strong style='font-size: 36px;'>Dataset Test Workflow</strong>
    <br>
    <span style='font-size: 20px;'>Implement and Automate Data Validation Processes.</span>
    </div>
    """)

    with gr.Row():
        # Left column: schema/table selection and the main action button.
        with gr.Column(scale=1):
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
            with gr.Row():
                generate_result = gr.Button("Validate Data", variant="primary")

        # Right column: one tab per kind of result.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Description"):
                    with gr.Row():
                        with gr.Column():
                            data_description = _readonly_table("Data Description")
                    with gr.Row():
                        with gr.Column():
                            describe_cat = _readonly_table("Categorical Information")
                        with gr.Column():
                            describe_num = _readonly_table("Numerical Information")
                with gr.Tab("Alerts"):
                    data_alerts = _readonly_table("Alerts")
                with gr.Tab("Rules & Validations"):
                    tests_output = _readonly_table("Validation Rules")
                    test_result_output = _readonly_table("Validation Result")
                with gr.Tab("Data"):
                    result_output = _readonly_table("Dataframe (10 Rows)")
                with gr.Tab('Text to Validation'):
                    with gr.Row():
                        query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter Text Query to Generate Validation e.g. Validate that the incident_zip column contains valid 5-digit ZIP codes.")
                    with gr.Row():
                        # Empty column, presumably a spacer next to the button.
                        with gr.Column():
                            pass
                        with gr.Column(scale=1, min_width=50):
                            user_generate_result = gr.Button("Validate Data", variant="primary")
                    with gr.Row():
                        with gr.Column():
                            query_tests = _readonly_table("Validation Rules")
                        with gr.Column():
                            query_result = _readonly_table("Validation Result")

    # Event wiring. The output order must match the tuple order returned by
    # `main` and `user_results`.
    main_outputs = [
        result_output,
        data_description,
        data_alerts,
        describe_cat,
        describe_num,
        tests_output,
        test_result_output,
    ]
    user_outputs = [query_tests, query_result]
    schema_dropdown.change(update_table_names, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_result.click(main, inputs=[tables_dropdown], outputs=main_outputs)
    user_generate_result.click(user_results, inputs=[tables_dropdown, query_input], outputs=user_outputs)

if __name__ == "__main__":
    demo.launch(debug=True)