import gradio as gr
import torch
import numpy as np
import pandas as pd
import os
import tempfile
import json
import base64
from io import BytesIO
from pathlib import Path
import pickle
import joblib
from collections import defaultdict
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
import plotly.graph_objects as go
import plotly.express as px
from sklearn.decomposition import PCA
from huggingface_hub import hf_hub_download

# Import torch_molecule models
try:
    from torch_molecule import GREAMolecularPredictor, GNNMolecularPredictor
    import torch_molecule
    TORCH_MOLECULE_AVAILABLE = True
    print('torch_molecule version: ', torch_molecule.__version__)
except ImportError:
    TORCH_MOLECULE_AVAILABLE = False
    print("Warning: torch_molecule not available. Some models may not work.")

# ============= PROPERTY CONFIGURATION =============
# Load property mapping
with open('property_mapping.json', 'r') as f:
    PROPERTY_MAPPING = json.load(f)
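
# This script only reads the keys 'category', 'full_name', and 'unit' from each
# entry, so property_mapping.json is assumed to look roughly like the sketch
# below (illustrative only — the real file shipped with the app is the source of
# truth, and this abbreviation/full-name pairing is a guess):
#     "CO2": {
#         "full_name": "CO2 Permeability",
#         "unit": "Barrer",
#         "category": "Permeability Properties"
#     }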

# Filter out Gas Transport Properties
ALL_PROPERTIES = {}
PROPERTY_CATEGORIES = defaultdict(list)
for prop_abbr, prop_info in PROPERTY_MAPPING.items():
    if prop_info['category'] != 'Gas Transport Properties':
        ALL_PROPERTIES[prop_abbr] = prop_info
        PROPERTY_CATEGORIES[prop_info['category']].append(prop_abbr)

# Get just permeability properties for plotting
PERMEABILITY_PROPERTIES = [p for p in ALL_PROPERTIES.keys() if ALL_PROPERTIES[p]['category'] == 'Permeability Properties']

print(f"Loaded {len(ALL_PROPERTIES)} properties (excluding Gas Transport Properties)")
print(f"Categories: {list(PROPERTY_CATEGORIES.keys())}")

all_model_names = ['GREA', 'GCN', 'GIN', 'RandomForest', 'GaussianProcess']

# Training configuration
TRAIN_IN_LOG = True
HF_REPO_ID = "liuganghuggingface/polymer-prediction-gas-models"

# Default SMILES for testing
DEFAULT_SMILES = """*c1cc2c(cc1*)C1(C(C)C)c3ccccc3C2(C(C)C)c2cc3c(cc21)Oc1cc2nc(*)c(*)nc2cc1O3
*CN1CN(*)Cc2cc3c(cc21)C1c2ccccc2C3c2cc(*)c(*)cc21
*C(=C(*)c1ccc2c(c1)C(C)(C)C(C)(C)C2(C)C)c1ccccc1"""

# Selectivity boundary parameters
SELECTIVITY_BOUNDS = {
    'CO2/CH4': {
        'x': [1.00E+05, 1.00E-02],
        'y': [1.00E+05 / 2.21E+04, 1.00E-02 / 4.88E-06],
        'gases': ('CO2', 'CH4')
    },
    'H2/CH4': {
        'x': [5.00E+04, 2.50E+00],
        'y': [5.00E+04 / 8.67E+04, 2.50E+00 / 5.64E-04],
        'gases': ('H2', 'CH4')
    },
    'O2/N2': {
        'x': [5.00E+04, 1.00E-03],
        'y': [5.00E+04 / 2.78E+04, 1.00E-03 / 2.43E-05],
        'gases': ('O2', 'N2')
    },
    'H2/N2': {
        'x': [1.00E+05, 1.00E-01],
        'y': [1.00E+05 / 1.02E+05, 1.00E-01 / 9.21E-06],
        'gases': ('H2', 'N2')
    },
    'CO2/N2': {
        'x': [1.00E+06, 1.00E-04],
        'y': [1.00E+06 / 3.05E+05, 1.00E-04 / 1.05E-08],
        'gases': ('CO2', 'N2')
    }
}

# ============= UTILITY FUNCTIONS =============
def smiles_to_mol_image(smiles, img_size=(200, 200)):
    """Convert SMILES to molecule image as base64 string."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    img = Draw.MolToImage(mol, size=img_size)
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"

def validate_smiles(smiles_list):
    """Validate a list of SMILES strings."""
    valid_smiles = []
    invalid_smiles = []
    for idx, smiles in enumerate(smiles_list):
        smiles = smiles.strip()
        if not smiles:
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            standardized = Chem.MolToSmiles(mol, isomericSmiles=True)
            valid_smiles.append((idx, smiles, standardized))
        else:
            invalid_smiles.append((idx, smiles))

    report = f"✅ Valid SMILES: {len(valid_smiles)}\n"
    report += f"❌ Invalid SMILES: {len(invalid_smiles)}\n"
    if invalid_smiles:
        report += "\n**Invalid SMILES detected:**\n"
        for idx, smiles in invalid_smiles:
            report += f" - Line {idx + 1}: `{smiles}`\n"
        report += "\n⚠️ **Please remove or correct the invalid SMILES before proceeding.**"
    return valid_smiles, invalid_smiles, report

def smiles_to_fingerprint(smiles_list, n_bits=2048):
    """Convert SMILES to Morgan fingerprints."""
    fingerprints = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=n_bits)
            fingerprints.append(np.array(fp))
        else:
            fingerprints.append(np.zeros(n_bits))
    return np.array(fingerprints)

# ============= MODEL LOADING =============
def load_all_models():
    """Load all available models from HuggingFace Hub."""
    print("Loading models from HuggingFace Hub (CPU only)...")
    loaded_models = {}
    device = torch.device('cpu')

    # Load models for ALL properties (not just permeability)
    for model_name in all_model_names:
        loaded_models[model_name] = {}
        # Iterate through all properties except Gas Transport Properties
        for prop_abbr in ALL_PROPERTIES.keys():
            model_filename = f"{model_name.lower()}_{prop_abbr.lower()}"
            try:
                if model_name in ['GREA', 'GCN', 'GIN']:
                    filename = f"{model_filename}.pt"
                    if not TORCH_MOLECULE_AVAILABLE:
                        continue
                    print(f" Downloading {filename}...")
                    model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
                    if model_name == 'GREA':
                        model = GREAMolecularPredictor(device='cpu')
                    elif model_name == 'GCN':
                        model = GNNMolecularPredictor(gnn_type='gcn-virtual', device='cpu')
                    elif model_name == 'GIN':
                        model = GNNMolecularPredictor(gnn_type='gin-virtual', device='cpu')
                    model.load_from_local(model_path)
                    loaded_models[model_name][prop_abbr] = (model, 'torch_molecule')
                    print(f" ✓ Loaded {model_name} for {prop_abbr}")
                else:
                    # sklearn models
                    filename = f"{model_filename}.pkl"
                    print(f" Downloading {filename}...")
                    model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
                    model = joblib.load(model_path)
                    loaded_models[model_name][prop_abbr] = (model, 'sklearn')
                    print(f" ✓ Loaded {model_name} for {prop_abbr}")
            except Exception as e:
                print(f" ❌ Error loading {model_name} for {prop_abbr}: {e}")

    print("Model loading complete!")
    return loaded_models

PRELOADED_MODELS = load_all_models()
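
# Checkpoint files in HF_REPO_ID are expected to follow the naming scheme built
# in load_all_models(): "<model>_<property>.pt" for the torch_molecule models and
# "<model>_<property>.pkl" for the sklearn models, all lower-case — e.g.
# "grea_co2.pt" or "randomforest_co2.pkl" (illustrative names; the actual
# property abbreviations come from property_mapping.json).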

# ============= PREDICTION FUNCTIONS =============
def predict_properties(smiles_list, selected_models, progress=gr.Progress()):
    """Predict properties for a list of SMILES."""
    if not selected_models:
        return None, "❌ Please select at least one model."

    progress(0.1, desc="Validating SMILES...")
    valid_smiles, invalid_smiles, validation_report = validate_smiles(smiles_list)
    if invalid_smiles:
        return None, validation_report
    if not valid_smiles:
        return None, "❌ No valid SMILES provided."

    indices, original_smiles, standardized_smiles = zip(*valid_smiles)
    all_predictions = {
        'original_smiles': list(original_smiles),
        'standardized_smiles': list(standardized_smiles),
        'predictions': {},
        'predictions_log': {}
    }

    # Prepare fingerprints for sklearn models
    X_fp = None
    needs_fingerprints = any(model in selected_models for model in ['RandomForest', 'GaussianProcess'])
    if needs_fingerprints:
        progress(0.2, desc="Computing fingerprints...")
        X_fp = smiles_to_fingerprint(standardized_smiles)

    model_errors = []

    # Predict ALL properties (not just permeability)
    available_props = list(ALL_PROPERTIES.keys())
    total_predictions = len(available_props) * len(selected_models)
    pred_count = 0

    for model_name in selected_models:
        all_predictions['predictions'][model_name] = {}
        all_predictions['predictions_log'][model_name] = {}
        for prop in available_props:
            progress(0.2 + 0.7 * pred_count / total_predictions, desc=f"Predicting {prop} with {model_name}...")
            if model_name not in PRELOADED_MODELS or prop not in PRELOADED_MODELS[model_name]:
                model_errors.append(f"{model_name} for {prop}")
                pred_count += 1
                continue

            model, model_type = PRELOADED_MODELS[model_name][prop]
            try:
                if model_type == 'torch_molecule':
                    with torch.no_grad():
                        predictions_dict = model.predict(list(standardized_smiles))
                        predictions = predictions_dict['prediction']
                else:
                    predictions = model.predict(X_fp)

                if isinstance(predictions, np.ndarray) and predictions.ndim > 1:
                    predictions = predictions.flatten()

                # Check if this property was trained in log scale (only for certain properties)
                prop_category = ALL_PROPERTIES[prop]['category']
                if prop_category == 'Permeability Properties' and TRAIN_IN_LOG:
                    predictions_original = 10**predictions
                    all_predictions['predictions'][model_name][prop] = predictions_original
                    all_predictions['predictions_log'][model_name][prop] = predictions
                else:
                    # For other properties, store as-is
                    all_predictions['predictions'][model_name][prop] = predictions
                    # Log transform for potential log-scale visualizations
                    all_predictions['predictions_log'][model_name][prop] = np.log10(np.maximum(np.abs(predictions), 1e-10))
            except Exception as e:
                print(f"Error predicting {model_name} for {prop}: {e}")
                model_errors.append(f"{model_name} for {prop}")
            pred_count += 1

    # Calculate averages
    progress(0.9, desc="Computing averages...")
    all_predictions['predictions']['Average'] = {}
    all_predictions['predictions_log']['Average'] = {}
    for prop in available_props:
        prop_predictions = []
        prop_predictions_log = []
        for model_name in selected_models:
            if model_name in all_predictions['predictions'] and prop in all_predictions['predictions'][model_name]:
                prop_predictions.append(all_predictions['predictions'][model_name][prop])
                prop_predictions_log.append(all_predictions['predictions_log'][model_name][prop])
        if prop_predictions:
            if len(prop_predictions) > 1:
                all_predictions['predictions']['Average'][prop] = np.mean(np.array(prop_predictions), axis=0)
                all_predictions['predictions_log']['Average'][prop] = np.mean(np.array(prop_predictions_log), axis=0)
            else:
                all_predictions['predictions']['Average'][prop] = prop_predictions[0]
                all_predictions['predictions_log']['Average'][prop] = prop_predictions_log[0]

    report = validation_report + "\n"
    if model_errors:
        report += f"\n⚠️ Some models were not available: {len(set(model_errors))} property-model combinations\n"
    report += f"\n✅ Successfully predicted properties for {len(valid_smiles)} molecules using {len(selected_models)} model(s)."
    report += f"\n📊 Attempted to predict {len(available_props)} properties across all categories."
    report += "\n💻 All predictions run on CPU."

    progress(1.0, desc="Done!")
    return all_predictions, report
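
# Shape of the dict returned by predict_properties() (consumed by every
# visualization function below):
#     {
#         'original_smiles':     [smiles, ...],
#         'standardized_smiles': [canonical smiles, ...],
#         'predictions':     {model name or 'Average': {prop abbr: np.ndarray}},
#         'predictions_log': {model name or 'Average': {prop abbr: np.ndarray}},
#     }
# For permeability properties with TRAIN_IN_LOG=True, 'predictions' holds values
# in Barrer (10**model output) and 'predictions_log' holds the raw log10 output;
# for all other properties, 'predictions' is the raw model output and
# 'predictions_log' is log10 of its clipped absolute value.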
report += f"\n📊 Attempted to predict {len(available_props)} properties across all categories." report += f"\n💻 All predictions run on CPU." progress(1.0, desc="Done!") return all_predictions, report # ============= VISUALIZATION FUNCTIONS ============= def create_results_gallery_html(all_predictions, selected_view='Average', selected_category='All', max_display=30): """ Create an HTML gallery showing top molecules with their structures grouped by category. """ if all_predictions is None: return "

No data available.

" num_molecules = len(all_predictions['original_smiles']) display_count = min(num_molecules, max_display) # Get predictions for selected view if selected_view not in all_predictions['predictions']: return "

No predictions available for selected view.

" predictions = all_predictions['predictions'][selected_view] # Filter properties by category if selected_category == 'All': display_categories = PROPERTY_CATEGORIES else: display_categories = {selected_category: PROPERTY_CATEGORIES[selected_category]} html = f"""

Prediction Results - {selected_category}

Showing top {display_count} of {num_molecules} molecules.

{f"

⬇️ Download the CSV file below to see all {num_molecules} results.

" if num_molecules > max_display else ""} """ # Display each molecule for idx in range(display_count): smiles = all_predictions['original_smiles'][idx] mol_img = smiles_to_mol_image(smiles, img_size=(250, 250)) html += f"""
{"" if mol_img else "

Invalid structure

"}

Molecule {idx + 1}

SMILES: {smiles[:100]}{"..." if len(smiles) > 100 else ""}

""" # Display properties by category - always show all properties for category, props in display_categories.items(): category_props = [p for p in props if p in predictions] if not category_props: continue html += f"""
{category}: """ # Display all properties for this category for prop in category_props: if prop in predictions: value = predictions[prop][idx] prop_info = ALL_PROPERTIES[prop] html += f"""
• {prop_info['full_name']} ({prop_info['unit']}): {value:.3f}
""" html += "
" html += """
""" html += "
" return html def format_full_predictions_csv(all_predictions, selected_view='Average'): """Format all predictions into CSV for download.""" if all_predictions is None: return None df = pd.DataFrame({ 'SMILES': all_predictions['original_smiles'] }) predictions = all_predictions['predictions'][selected_view] # Add predictions grouped by category for category, props in PROPERTY_CATEGORIES.items(): for prop in props: if prop in predictions: prop_info = ALL_PROPERTIES[prop] col_name = f"{prop_info['full_name']} ({prop_info['unit']})" df[col_name] = predictions[prop] return df def create_selectivity_plot_with_images(all_predictions, selected_view='Average', selectivity_pair='CO2/CH4', max_display=30): """Create selectivity plot with molecule images on hover.""" if all_predictions is None or selectivity_pair not in SELECTIVITY_BOUNDS: return None bounds = SELECTIVITY_BOUNDS[selectivity_pair] gas1, gas2 = bounds['gases'] if selected_view not in all_predictions['predictions_log']: return None if gas1 not in all_predictions['predictions_log'][selected_view] or gas2 not in all_predictions['predictions_log'][selected_view]: return None # Get predictions (limit to max_display) num_molecules = len(all_predictions['original_smiles']) display_count = min(num_molecules, max_display) gas1_perm_log = all_predictions['predictions_log'][selected_view][gas1][:display_count] gas2_perm_log = all_predictions['predictions_log'][selected_view][gas2][:display_count] gas1_perm = 10**gas1_perm_log gas2_perm = 10**gas2_perm_log gas1_perm = np.maximum(gas1_perm, 1e-10) gas2_perm = np.maximum(gas2_perm, 1e-10) selectivity = gas1_perm / gas2_perm # Create boundary line x1, x2 = bounds['x'] y1, y2 = bounds['y'] fig = go.Figure() # Add 2008 upper bound fig.add_trace(go.Scatter( x=[x1, x2], y=[y1, y2], mode='lines', name='2008 Upper Bound', line=dict(color='red', width=3, dash='dash'), hoverinfo='name' )) # Determine above/below bound x_log = np.log10(gas1_perm) y_log = np.log10(selectivity) x1_log, x2_log = np.log10(x1), np.log10(x2) y1_log, y2_log = np.log10(y1), np.log10(y2) a = (y1_log - y2_log) / (x1_log - x2_log) b = y1_log - a * x1_log y_bound = a * x_log + b above_bound = y_log > y_bound # Create hover texts with molecule info smiles_list = all_predictions['original_smiles'][:display_count] hover_texts = [] for i, smiles in enumerate(smiles_list): truncated = smiles if len(smiles) <= 80 else smiles[:77] + '...' status = "Above Bound" if above_bound[i] else "Below Bound" hover_text = (f"SMILES: {truncated}
" f"{gas1}: {gas1_perm[i]:.3f} Barrer
" f"{gas2}: {gas2_perm[i]:.3f} Barrer
" f"Selectivity: {selectivity[i]:.3f}
" f"Status: {status}") hover_texts.append(hover_text) # Add points if np.any(above_bound): fig.add_trace(go.Scatter( x=gas1_perm[above_bound], y=selectivity[above_bound], mode='markers', name='Above Bound', marker=dict(color='green', size=10, symbol='circle'), text=[hover_texts[i] for i in range(len(hover_texts)) if above_bound[i]], hovertemplate='%{text}' )) if np.any(~above_bound): fig.add_trace(go.Scatter( x=gas1_perm[~above_bound], y=selectivity[~above_bound], mode='markers', name='Below Bound', marker=dict(color='blue', size=8, symbol='circle'), text=[hover_texts[i] for i in range(len(hover_texts)) if not above_bound[i]], hovertemplate='%{text}' )) fig.update_xaxes( title=f"{gas1} Permeability (Barrer)", type="log", gridcolor='lightgray' ) fig.update_yaxes( title=f"{gas1}/{gas2} Selectivity", type="log", gridcolor='lightgray' ) fig.update_layout( title=f"{gas1}/{gas2} Selectivity Plot (Top {display_count} molecules)", hovermode='closest', showlegend=True, plot_bgcolor='white', height=600 ) return fig def generate_all_selectivity_plots(all_predictions, selected_view='Average', max_display=30): """Generate all selectivity plots at once.""" if all_predictions is None: return {} all_plots = {} for selectivity_pair in SELECTIVITY_BOUNDS.keys(): plot = create_selectivity_plot_with_images(all_predictions, selected_view, selectivity_pair, max_display) if plot is not None: all_plots[selectivity_pair] = plot return all_plots def create_pca_plot(all_predictions, selected_view='Average', max_display=30): """Create PCA plot with all properties on hover.""" if all_predictions is None: return None num_molecules = len(all_predictions['original_smiles']) display_count = min(num_molecules, max_display) smiles_list = all_predictions['standardized_smiles'][:display_count] # Compute fingerprints X_fp = smiles_to_fingerprint(smiles_list) # Perform PCA pca = PCA(n_components=2) X_pca = pca.fit_transform(X_fp) # Create hover texts with all properties predictions = all_predictions['predictions'][selected_view] hover_texts = [] for idx in range(display_count): smiles = all_predictions['original_smiles'][idx] truncated = smiles if len(smiles) <= 60 else smiles[:57] + '...' hover_text = f"SMILES: {truncated}
" for category, props in PROPERTY_CATEGORIES.items(): hover_text += f"
{category}:
" for prop in props: if prop in predictions: value = predictions[prop][idx] prop_info = ALL_PROPERTIES[prop] hover_text += f" {prop_info['full_name']}: {value:.3f} {prop_info['unit']}
" hover_texts.append(hover_text) # Create plot fig = go.Figure() fig.add_trace(go.Scatter( x=X_pca[:, 0], y=X_pca[:, 1], mode='markers', marker=dict( size=10, color=np.arange(display_count), colorscale='Viridis', showscale=True, colorbar=dict(title="Molecule #") ), text=hover_texts, hovertemplate='%{text}' )) fig.update_layout( title=f"PCA Visualization of Molecular Structures (Top {display_count} molecules)", xaxis_title=f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)", yaxis_title=f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)", hovermode='closest', plot_bgcolor='white', height=600 ) return fig # ============= GRADIO INTERFACE ============= available_models = [m for m in all_model_names if m in PRELOADED_MODELS and PRELOADED_MODELS[m]] if not available_models: print("⚠️ WARNING: No models loaded!") available_models = all_model_names with gr.Blocks(title="Polymer Property Prediction", theme=gr.themes.Soft()) as iface: gr.Markdown("""


# ============= GRADIO INTERFACE =============
available_models = [m for m in all_model_names if m in PRELOADED_MODELS and PRELOADED_MODELS[m]]
if not available_models:
    print("⚠️ WARNING: No models loaded!")
    available_models = all_model_names

with gr.Blocks(title="Polymer Property Prediction", theme=gr.themes.Soft()) as iface:
    gr.Markdown("""
    <div style="text-align: center;">
        <h1>🔬 Polymer Property Prediction</h1>
        <p>Predict electronic, dielectric & optical, thermal, physical & thermodynamic, and gas permeability properties</p>
        <p>💻 Powered by torch-molecule & sklearn</p>
    </div>
    """)

    # Input section - more compact layout
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 📝 Input SMILES")
            smiles_text = gr.Textbox(
                label="Enter SMILES (one per line)",
                placeholder="Enter polymer SMILES strings, one per line...",
                lines=8,
                value=DEFAULT_SMILES
            )
            smiles_file = gr.File(
                label="Or upload a file (.txt, .csv, .smi)",
                file_types=[".txt", ".csv", ".smi"]
            )
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Model Selection")
            model_selector = gr.CheckboxGroup(
                choices=available_models,
                label="Select Models for Ensemble Prediction",
                value=[available_models[0]] if available_models else [],
                info="Choose one or more models. Multiple models will be averaged for robust predictions."
            )
            predict_btn = gr.Button("🔮 Predict Properties", variant="primary", size="lg")
            prediction_status = gr.Textbox(label="📊 Status", lines=4, show_label=True)

    # Results section
    gr.Markdown("""
    <div style="text-align: center;">
        <h2>📊 Prediction Results</h2>
        <p>Results include: Electronic (bandgap, ionization energy), Dielectric & Optical (refractive index), Thermal (Tg, Tm, conductivity), Physical (density, FFV, radius of gyration), and Gas Permeability properties</p>
    </div>
    """)
    with gr.Row():
        view_selector = gr.Radio(
            choices=['Average'],
            label="Select Model View",
            value='Average',
            visible=False
        )
        category_selector = gr.Radio(
            choices=['All'] + list(PROPERTY_CATEGORIES.keys()),
            label="Select Property Category",
            value='All',
            info="Choose a category to view specific properties"
        )
    results_html = gr.HTML(label="Results with Molecular Structures")
    download_btn = gr.DownloadButton(
        label="📥 Download All Results (CSV)",
        visible=False
    )

    # Visualization section
    gr.Markdown("""
    <div style="text-align: center;">
        <h2>📈 Interactive Visualizations</h2>
        <p>Limited to top 30 molecules for performance</p>
    </div>
    """)
    with gr.Tabs():
        with gr.TabItem("🎯 Gas Selectivity Analysis"):
            gr.Markdown("""
            <p style="text-align: center;">Analyze gas separation performance against the 2008 Robeson upper bounds</p>
            """)
            plot_selectivity_btn = gr.Button("📊 Generate All Selectivity Plots", variant="secondary", size="lg")
            with gr.Row():
                selectivity_pair_selector = gr.Radio(
                    choices=list(SELECTIVITY_BOUNDS.keys()),
                    label="Select Gas Pair to View",
                    value='CO2/CH4',
                    visible=False
                )
            selectivity_plot = gr.Plot(label="Selectivity Plot")
        with gr.TabItem("🗺️ PCA Visualization"):
            gr.Markdown("""
            <p style="text-align: center;">Explore the chemical space using PCA dimensionality reduction on molecular fingerprints</p>
            """)
            plot_pca_btn = gr.Button("📊 Generate PCA Plot", variant="secondary")
            pca_plot = gr.Plot(label="PCA Plot")

    # Hidden state
    all_predictions_state = gr.State(None)

    # Event handlers
    def on_predict(text_input, file_input, selected_models):
        # Process input
        smiles_list = []
        if text_input and text_input.strip():
            smiles_list.extend([line.strip() for line in text_input.strip().split('\n') if line.strip()])
        if file_input is not None:
            try:
                file_path = file_input if isinstance(file_input, str) else file_input.name
                if file_path.endswith('.csv'):
                    df = pd.read_csv(file_path)
                    if 'SMILES' in df.columns:
                        smiles_list.extend(df['SMILES'].dropna().astype(str).tolist())
                else:
                    if isinstance(file_input, str):
                        with open(file_input, 'r') as f:
                            lines = f.readlines()
                    else:
                        content = file_input.read()
                        if isinstance(content, bytes):
                            content = content.decode('utf-8')
                        lines = content.strip().split('\n')
                    smiles_list.extend([line.strip() for line in lines if line.strip()])
            except Exception as e:
                return None, "", f"❌ Error reading file: {str(e)}", gr.Radio(visible=False), gr.Radio(visible=True, value='All'), gr.DownloadButton(visible=False)

        if not smiles_list:
            return None, "", "❌ Please provide SMILES strings.", gr.Radio(visible=False), gr.Radio(visible=True, value='All'), gr.DownloadButton(visible=False)

        # Remove duplicates while preserving order
        unique_smiles = list(dict.fromkeys(smiles_list))

        # Make predictions
        all_predictions, report = predict_properties(unique_smiles, selected_models)
        if all_predictions is None:
            return None, "", report, gr.Radio(visible=False), gr.Radio(visible=True, value='All'), gr.DownloadButton(visible=False)

        # Create results gallery with default "All" category
        results_gallery = create_results_gallery_html(all_predictions, 'Average', 'All', max_display=30)

        # Prepare CSV download
        df_full = format_full_predictions_csv(all_predictions, 'Average')
        temp_csv = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv')
        df_full.to_csv(temp_csv.name, index=False)
        temp_csv.close()

        # Update view selector
        view_options = ['Average'] + [m for m in selected_models if m in all_predictions['predictions']]
        view_selector_update = gr.Radio(choices=view_options, value='Average', visible=True)
        category_selector_update = gr.Radio(choices=['All'] + list(PROPERTY_CATEGORIES.keys()), value='All', visible=True)

        return (
            all_predictions,
            results_gallery,
            report,
            view_selector_update,
            category_selector_update,
            gr.DownloadButton(
                label="📥 Download All Results (CSV)",
                value=temp_csv.name,
                visible=True
            )
        )

    def on_view_or_category_change(all_predictions, selected_view, selected_category):
        if all_predictions is None:
            return "", gr.DownloadButton(visible=False)
        results_gallery = create_results_gallery_html(all_predictions, selected_view, selected_category, max_display=30)
        df_full = format_full_predictions_csv(all_predictions, selected_view)
        temp_csv = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv')
        df_full.to_csv(temp_csv.name, index=False)
        temp_csv.close()
        return results_gallery, gr.DownloadButton(
            label=f"📥 Download {selected_view} Results (CSV)",
            value=temp_csv.name,
            visible=True
        )

    def on_plot_selectivity(all_predictions, selected_view):
        """Generate all selectivity plots at once."""
        if all_predictions is None:
            return {}, gr.Radio(visible=False), None
        # Generate all plots
        all_plots = generate_all_selectivity_plots(all_predictions, selected_view, max_display=30)
        if not all_plots:
            return {}, gr.Radio(visible=False), None
        # Show the first plot (CO2/CH4 if available)
        first_pair = 'CO2/CH4' if 'CO2/CH4' in all_plots else list(all_plots.keys())[0]
        return (
            all_plots,
            gr.Radio(
                choices=list(all_plots.keys()),
                value=first_pair,
                visible=True
            ),
            all_plots[first_pair]
        )

    def on_selectivity_pair_change(all_plots_dict, selected_pair):
        """Switch between pre-generated selectivity plots."""
        if not all_plots_dict or selected_pair not in all_plots_dict:
            return None
        return all_plots_dict[selected_pair]

    def on_plot_pca(all_predictions, selected_view):
        if all_predictions is None:
            return None
        return create_pca_plot(all_predictions, selected_view, max_display=30)

    # Hidden state for storing all selectivity plots
    all_selectivity_plots_state = gr.State({})

    # Connect events
    predict_btn.click(
        on_predict,
        inputs=[smiles_text, smiles_file, model_selector],
        outputs=[all_predictions_state, results_html, prediction_status, view_selector, category_selector, download_btn]
    )
    view_selector.change(
        on_view_or_category_change,
        inputs=[all_predictions_state, view_selector, category_selector],
        outputs=[results_html, download_btn]
    )
    category_selector.change(
        on_view_or_category_change,
        inputs=[all_predictions_state, view_selector, category_selector],
        outputs=[results_html, download_btn]
    )
    plot_selectivity_btn.click(
        on_plot_selectivity,
        inputs=[all_predictions_state, view_selector],
        outputs=[all_selectivity_plots_state, selectivity_pair_selector, selectivity_plot]
    )
    selectivity_pair_selector.change(
        on_selectivity_pair_change,
        inputs=[all_selectivity_plots_state, selectivity_pair_selector],
        outputs=[selectivity_plot]
    )
    plot_pca_btn.click(
        on_plot_pca,
        inputs=[all_predictions_state, view_selector],
        outputs=[pca_plot]
    )

if __name__ == "__main__":
    iface.launch(share=True)