# Author: liuganghuggingface
# Commit e935b86: "add property mapping"
import gradio as gr
import torch
import numpy as np
import pandas as pd
import os
import tempfile
import json
import base64
from io import BytesIO
from pathlib import Path
import pickle
import joblib
from collections import defaultdict
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
import plotly.graph_objects as go
import plotly.express as px
from sklearn.decomposition import PCA
from huggingface_hub import hf_hub_download
# Import torch_molecule models. The dependency is optional: when it is missing,
# TORCH_MOLECULE_AVAILABLE is False and the graph models are skipped at load time.
try:
    import torch_molecule
    from torch_molecule import GREAMolecularPredictor, GNNMolecularPredictor
except ImportError:
    TORCH_MOLECULE_AVAILABLE = False
    print("Warning: torch_molecule not available. Some models may not work.")
else:
    TORCH_MOLECULE_AVAILABLE = True
    print('torch_molecule version: ', torch_molecule.__version__)
# ============= PROPERTY CONFIGURATION =============
# Load the property mapping: {abbreviation: info}. Each info dict is expected to
# carry at least 'category', 'full_name' and 'unit' (all read elsewhere in this file).
# Explicit UTF-8: JSON text is UTF-8, and the platform default encoding may differ
# (e.g. on Windows), which would garble non-ASCII units or names.
with open('property_mapping.json', 'r', encoding='utf-8') as f:
    PROPERTY_MAPPING = json.load(f)

# Keep every property except Gas Transport Properties, and group the kept
# abbreviations by category (categories and properties keep file order).
ALL_PROPERTIES = {}
PROPERTY_CATEGORIES = {}
for prop_abbr, prop_info in PROPERTY_MAPPING.items():
    if prop_info['category'] == 'Gas Transport Properties':
        continue
    ALL_PROPERTIES[prop_abbr] = prop_info
    PROPERTY_CATEGORIES.setdefault(prop_info['category'], []).append(prop_abbr)
# PROPERTY_CATEGORIES is deliberately a plain dict, not a defaultdict: it is read-only
# after this point, and a defaultdict would silently insert an empty category into
# this shared global whenever it is indexed with an unknown key.

# Abbreviations of the permeability properties.
PERMEABILITY_PROPERTIES = [
    abbr for abbr, info in ALL_PROPERTIES.items()
    if info['category'] == 'Permeability Properties'
]
print(f"Loaded {len(ALL_PROPERTIES)} properties (excluding Gas Transport Properties)")
print(f"Categories: {list(PROPERTY_CATEGORIES.keys())}")
# Model families offered in the UI: the first three are graph models served
# through torch_molecule, the last two are sklearn estimators that consume
# Morgan fingerprints (see load_all_models / predict_properties).
all_model_names = ['GREA', 'GCN', 'GIN'] + ['RandomForest', 'GaussianProcess']

# Training configuration.
# When True, permeability models output log10 values, which predict_properties
# exponentiates (10**y) before display.
TRAIN_IN_LOG = True
# Hugging Face Hub repository that holds every serialized model file.
HF_REPO_ID = "liuganghuggingface/polymer-prediction-gas-models"

# Example polymer SMILES pre-filled in the input box, one per line
# ('*' presumably marks repeat-unit attachment points).
DEFAULT_SMILES = "\n".join((
    "*c1cc2c(cc1*)C1(C(C)C)c3ccccc3C2(C(C)C)c2cc3c(cc21)Oc1cc2nc(*)c(*)nc2cc1O3",
    "*CN1CN(*)Cc2cc3c(cc21)C1c2ccccc2C3c2cc(*)c(*)cc21",
    "*C(=C(*)c1ccc2c(c1)C(C)(C)C(C)(C)C2(C)C)c1ccccc1",
))
# Selectivity boundary parameters (2008 Robeson upper bound, drawn as a straight
# line in log-log space between two endpoints).
# For each gas pair: ((P_gas1, P_gas2) at the high-permeability end,
#                     (P_gas1, P_gas2) at the low-permeability end).
# The plotted y value at an endpoint is the selectivity P_gas1 / P_gas2, matching
# how create_selectivity_plot_with_images computes selectivity for each molecule.
_ROBESON_2008_ENDPOINTS = {
    'CO2/CH4': ((1.00E+05, 2.21E+04), (1.00E-02, 4.88E-06)),
    'H2/CH4': ((5.00E+04, 8.67E+04), (2.50E+00, 5.64E-04)),
    'O2/N2': ((5.00E+04, 2.78E+04), (1.00E-03, 2.43E-05)),
    'H2/N2': ((1.00E+05, 1.02E+05), (1.00E-01, 9.21E-06)),
    'CO2/N2': ((1.00E+06, 3.05E+05), (1.00E-04, 1.05E-08)),
}

# {pair: {'x': [P_gas1 endpoints], 'y': [selectivity endpoints], 'gases': (gas1, gas2)}}
SELECTIVITY_BOUNDS = {
    pair: {
        'x': [high_end[0], low_end[0]],
        'y': [high_end[0] / high_end[1], low_end[0] / low_end[1]],
        'gases': tuple(pair.split('/')),
    }
    for pair, (high_end, low_end) in _ROBESON_2008_ENDPOINTS.items()
}
# ============= UTILITY FUNCTIONS =============
def smiles_to_mol_image(smiles, img_size=(200, 200)):
    """Render a SMILES string as an inline PNG data URI.

    Args:
        smiles: SMILES string to draw.
        img_size: (width, height) of the rendered image in pixels.

    Returns:
        A ``data:image/png;base64,...`` string usable as an ``<img src>``,
        or None when RDKit cannot parse *smiles*.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None
    with BytesIO() as png_buffer:
        Draw.MolToImage(molecule, size=img_size).save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
def validate_smiles(smiles_list):
    """Split SMILES strings into parseable and unparseable ones.

    Blank (whitespace-only) entries are skipped. Positions refer to the index
    in *smiles_list*.

    Returns:
        tuple: ``(valid, invalid, report)`` where ``valid`` holds
        ``(position, stripped_smiles, canonical_isomeric_smiles)`` tuples,
        ``invalid`` holds ``(position, stripped_smiles)`` tuples, and
        ``report`` is a Markdown summary for the status box.
    """
    valid_entries, invalid_entries = [], []
    for position, raw in enumerate(smiles_list):
        candidate = raw.strip()
        if not candidate:
            continue
        molecule = Chem.MolFromSmiles(candidate)
        if molecule is None:
            invalid_entries.append((position, candidate))
        else:
            canonical = Chem.MolToSmiles(molecule, isomericSmiles=True)
            valid_entries.append((position, candidate, canonical))

    report_lines = [
        f"✅ Valid SMILES: {len(valid_entries)}\n",
        f"❌ Invalid SMILES: {len(invalid_entries)}\n",
    ]
    if invalid_entries:
        report_lines.append("\n**Invalid SMILES detected:**\n")
        report_lines.extend(
            f" - Line {position + 1}: `{candidate}`\n"
            for position, candidate in invalid_entries
        )
        report_lines.append("\n⚠️ **Please remove or correct the invalid SMILES before proceeding.**")
    return valid_entries, invalid_entries, "".join(report_lines)
def smiles_to_fingerprint(smiles_list, n_bits=2048, radius=2):
    """Convert SMILES to Morgan fingerprints.

    Args:
        smiles_list: Iterable of SMILES strings.
        n_bits: Fingerprint length (number of columns).
        radius: Morgan radius; defaults to 2, the value previously hard-coded.

    Returns:
        np.ndarray of shape ``(len(smiles_list), n_bits)`` and dtype float64
        holding 0/1 bits. Unparseable SMILES get an all-zero row. Empty input
        still yields a 2-D ``(0, n_bits)`` array, so the result can be fed to
        sklearn estimators / PCA without special-casing.
    """
    smiles_list = list(smiles_list)  # accept any iterable; we need its length up front
    # Preallocate so every row has the same dtype whether or not a SMILES parses.
    fingerprints = np.zeros((len(smiles_list), n_bits))
    for row, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue  # keep the all-zero row for unparseable input
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
        fingerprints[row] = np.array(fp)
    return fingerprints
# ============= MODEL LOADING =============
def load_all_models():
    """Load all available models from HuggingFace Hub (CPU only).

    For every name in ``all_model_names`` and every property in
    ``ALL_PROPERTIES``, fetch ``<model>_<property>.pt`` (graph models, loaded
    through torch_molecule) or ``<model>_<property>.pkl`` (sklearn models,
    loaded with joblib) from ``HF_REPO_ID``. Failures are printed and skipped,
    so the app can start with a partial model set.

    Returns:
        dict: ``{model_name: {prop_abbr: (model, kind)}}`` with ``kind`` either
        ``'torch_molecule'`` or ``'sklearn'``. Every model name gets an entry,
        which may be empty.
    """
    print("Loading models from HuggingFace Hub (CPU only)...")
    # Graph-model constructors. Lambdas defer touching the torch_molecule classes,
    # which are undefined when that import failed.
    gnn_factories = {
        'GREA': lambda: GREAMolecularPredictor(device='cpu'),
        'GCN': lambda: GNNMolecularPredictor(gnn_type='gcn-virtual', device='cpu'),
        'GIN': lambda: GNNMolecularPredictor(gnn_type='gin-virtual', device='cpu'),
    }
    loaded_models = {}
    for model_name in all_model_names:
        loaded_models[model_name] = {}
        is_gnn = model_name in gnn_factories
        if is_gnn and not TORCH_MOLECULE_AVAILABLE:
            # Nothing of this family can be loaded; skip all its properties at once.
            print(f" Skipping {model_name}: torch_molecule is not available.")
            continue
        # Load models for ALL properties (Gas Transport Properties are already excluded).
        for prop_abbr in ALL_PROPERTIES:
            stem = f"{model_name.lower()}_{prop_abbr.lower()}"
            filename = f"{stem}.pt" if is_gnn else f"{stem}.pkl"
            try:
                print(f" Downloading {filename}...")
                model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
                if is_gnn:
                    model = gnn_factories[model_name]()
                    model.load_from_local(model_path)
                    entry = (model, 'torch_molecule')
                else:
                    # NOTE(security): joblib.load unpickles arbitrary objects. This is only
                    # acceptable because HF_REPO_ID should be a trusted repository — never
                    # point it at user-controlled content.
                    entry = (joblib.load(model_path), 'sklearn')
                loaded_models[model_name][prop_abbr] = entry
                print(f" ✓ Loaded {model_name} for {prop_abbr}")
            except Exception as e:
                # Best-effort by design: a missing file or bad checkpoint must not stop startup.
                print(f" ❌ Error loading {model_name} for {prop_abbr}: {e}")
    print("Model loading complete!")
    return loaded_models
# Models are downloaded once, at import time, and reused by every request.
# Shape: {model_name: {prop_abbr: (model, 'torch_molecule' | 'sklearn')}}.
PRELOADED_MODELS = load_all_models()
# ============= PREDICTION FUNCTIONS =============
def _as_prediction_vector(raw, n_molecules):
    """Coerce one model's raw output into a flat float array of length *n_molecules*.

    Accepts arrays of shape (n,) or (n, 1), plain lists, or anything else
    ``np.asarray`` understands. Raises ValueError on a length mismatch, so a
    mis-shaped output is reported as a model error instead of crashing later
    while the results are rendered or exported.
    """
    values = np.asarray(raw, dtype=float).reshape(-1)
    if values.shape[0] != n_molecules:
        raise ValueError(f"expected {n_molecules} predictions, got {values.shape[0]}")
    return values


def predict_properties(smiles_list, selected_models, progress=gr.Progress()):
    """Predict properties for a list of SMILES.

    Runs every selected model on every property in ``ALL_PROPERTIES`` and adds
    an ``'Average'`` pseudo-model across the models that succeeded.

    Args:
        smiles_list: Raw SMILES strings; blank entries are ignored.
        selected_models: Model names chosen in the UI (keys of PRELOADED_MODELS).
        progress: Gradio progress reporter (gr.Progress).

    Returns:
        tuple: ``(all_predictions, report)``. ``all_predictions`` is None when
        nothing can be predicted (no models, invalid or missing SMILES).
        Otherwise it has keys ``original_smiles``, ``standardized_smiles``,
        ``predictions`` and ``predictions_log``. The last two map each model
        name (plus ``'Average'``) to ``{prop_abbr: 1-D float array}`` aligned
        with the SMILES lists.
    """
    if not selected_models:
        return None, "❌ Please select at least one model."
    progress(0.1, desc="Validating SMILES...")
    valid_smiles, invalid_smiles, validation_report = validate_smiles(smiles_list)
    if invalid_smiles:
        return None, validation_report
    if not valid_smiles:
        return None, "❌ No SMILES provided."
    _, original_smiles, standardized_smiles = zip(*valid_smiles)
    n_molecules = len(standardized_smiles)
    all_predictions = {
        'original_smiles': list(original_smiles),
        'standardized_smiles': list(standardized_smiles),
        'predictions': {},
        'predictions_log': {},
    }

    # The sklearn models consume Morgan fingerprints; compute them once for all properties.
    X_fp = None
    if any(model in selected_models for model in ['RandomForest', 'GaussianProcess']):
        progress(0.2, desc="Computing fingerprints...")
        X_fp = smiles_to_fingerprint(standardized_smiles)

    model_errors = []
    # Predict ALL properties (not just permeability)
    available_props = list(ALL_PROPERTIES.keys())
    total_predictions = len(available_props) * len(selected_models)
    pred_count = 0
    for model_name in selected_models:
        model_preds = all_predictions['predictions'][model_name] = {}
        model_preds_log = all_predictions['predictions_log'][model_name] = {}
        loaded = PRELOADED_MODELS.get(model_name, {})
        for prop in available_props:
            progress(0.2 + 0.7 * pred_count / total_predictions,
                     desc=f"Predicting {prop} with {model_name}...")
            pred_count += 1
            if prop not in loaded:
                model_errors.append(f"{model_name} for {prop}")
                continue
            model, model_type = loaded[prop]
            try:
                if model_type == 'torch_molecule':
                    with torch.no_grad():
                        raw = model.predict(list(standardized_smiles))['prediction']
                else:
                    raw = model.predict(X_fp)
                predictions = _as_prediction_vector(raw, n_molecules)
                if ALL_PROPERTIES[prop]['category'] == 'Permeability Properties' and TRAIN_IN_LOG:
                    # The model outputs log10(value): keep it as the log view and
                    # exponentiate for display.
                    linear, logged = 10 ** predictions, predictions
                else:
                    # Stored as-is. The log view is log10(|value|) floored at 1e-10,
                    # so the sign is lost; it is only meant for log-scale plots.
                    linear = predictions
                    logged = np.log10(np.maximum(np.abs(predictions), 1e-10))
            except Exception as e:
                print(f"Error predicting {model_name} for {prop}: {e}")
                model_errors.append(f"{model_name} for {prop}")
                continue
            model_preds[prop] = linear
            model_preds_log[prop] = logged

    progress(0.9, desc="Computing averages...")
    # NOTE(review): for permeability, the linear 'Average' is the arithmetic mean of
    # 10**y across models, while the log 'Average' is the mean of y (a geometric mean
    # once exponentiated). The gallery/CSV read the former and the selectivity plot
    # the latter, so the two can show different averaged permeabilities.
    average = all_predictions['predictions']['Average'] = {}
    average_log = all_predictions['predictions_log']['Average'] = {}
    for prop in available_props:
        contributing = [m for m in selected_models if prop in all_predictions['predictions'][m]]
        if not contributing:
            continue
        average[prop] = np.mean([all_predictions['predictions'][m][prop] for m in contributing], axis=0)
        average_log[prop] = np.mean([all_predictions['predictions_log'][m][prop] for m in contributing], axis=0)

    report = validation_report + "\n"
    if model_errors:
        report += f"\n⚠️ Some models were not available: {len(set(model_errors))} property-model combinations\n"
    report += f"\n✅ Successfully predicted properties for {len(valid_smiles)} molecules using {len(selected_models)} model(s)."
    report += f"\n📊 Attempted to predict {len(available_props)} properties across all categories."
    report += "\n💻 All predictions run on CPU."
    progress(1.0, desc="Done!")
    return all_predictions, report
# ============= VISUALIZATION FUNCTIONS =============
def create_results_gallery_html(all_predictions, selected_view='Average', selected_category='All', max_display=30):
    """
    Create an HTML gallery showing top molecules with their structures grouped by category.

    Args:
        all_predictions: Dict produced by ``predict_properties``, or None.
        selected_view: Key of ``all_predictions['predictions']`` (a model name or 'Average').
        selected_category: 'All' or a single property category name.
        max_display: Maximum number of molecule cards to render.

    Returns:
        str: An HTML fragment. A short ``<p>`` message is returned when there
        is no data or the requested view does not exist.
    """
    # User-supplied SMILES (and JSON-provided labels) are HTML-escaped before
    # being embedded in the markup.
    from html import escape

    if all_predictions is None:
        return "<p>No data available.</p>"
    num_molecules = len(all_predictions['original_smiles'])
    display_count = min(num_molecules, max_display)
    # Get predictions for selected view
    if selected_view not in all_predictions['predictions']:
        return "<p>No predictions available for selected view.</p>"
    predictions = all_predictions['predictions'][selected_view]

    # Filter properties by category. .get() rather than [...]: PROPERTY_CATEGORIES
    # may be a defaultdict, and indexing it with an unexpected category (e.g. None)
    # would permanently insert an empty category into that shared global.
    if selected_category == 'All':
        display_categories = PROPERTY_CATEGORIES
    else:
        display_categories = {selected_category: PROPERTY_CATEGORIES.get(selected_category, [])}

    download_hint = ""
    if num_molecules > max_display:
        download_hint = (
            "<p style='color: #d97706; font-weight: 500; font-size: 1.05em;'>"
            f"⬇️ Download the CSV file below to see all {num_molecules} results.</p>"
        )
    parts = [f"""
<div style='font-family: Arial, sans-serif; color: #000;'>
<h3 style='color: #1a1a1a; margin-bottom: 10px;'>Prediction Results - {escape(str(selected_category))}</h3>
<p style='color: #333; font-size: 1.1em; font-weight: 500;'><strong>Showing top {display_count} of {num_molecules} molecules.</strong></p>
{download_hint}
"""]

    # One card per displayed molecule.
    for idx in range(display_count):
        smiles = all_predictions['original_smiles'][idx]
        mol_img = smiles_to_mol_image(smiles, img_size=(250, 250))
        if mol_img:
            structure_html = "<img src='" + mol_img + "' style='width: 250px; height: 250px; display: block;'/>"
        else:
            structure_html = "<p style='color: #ef4444; font-weight: 500;'>Invalid structure</p>"
        shown_smiles = escape(smiles[:100]) + ("..." if len(smiles) > 100 else "")
        parts.append(f"""
<div style='border: 2px solid #cbd5e1; border-radius: 12px; padding: 20px; margin: 20px 0; background: #ffffff; box-shadow: 0 2px 8px rgba(0,0,0,0.1);'>
<div style='display: flex; gap: 25px; align-items: flex-start;'>
<div style='flex-shrink: 0; background: #f8fafc; padding: 10px; border-radius: 8px;'>
{structure_html}
</div>
<div style='flex-grow: 1;'>
<h4 style='color: #0f172a; margin: 0 0 15px 0; font-size: 1.3em; border-bottom: 2px solid #e2e8f0; padding-bottom: 8px;'>Molecule {idx + 1}</h4>
<p style='color: #475569; margin-bottom: 15px; background: #f1f5f9; padding: 10px; border-radius: 6px; word-break: break-all;'>
<strong style='color: #1e293b;'>SMILES:</strong>
<code style='background: #e2e8f0; padding: 4px 8px; border-radius: 4px; font-size: 0.9em; color: #334155;'>{shown_smiles}</code>
</p>
""")
        # Display properties by category; categories with no predictions are skipped.
        for category, props in display_categories.items():
            category_props = [p for p in props if p in predictions]
            if not category_props:
                continue
            parts.append(f"""
<div style='margin-top: 15px; background: #f8fafc; padding: 12px; border-radius: 8px; border-left: 4px solid #3b82f6;'>
<strong style='color: #1e40af; font-size: 1.1em; display: block; margin-bottom: 8px;'>{escape(str(category))}:</strong>
""")
            for prop in category_props:
                prop_info = ALL_PROPERTIES[prop]
                value = predictions[prop][idx]
                parts.append(f"""
<div style='margin: 6px 0; padding: 6px 10px; background: #ffffff; border-radius: 4px;'>
<span style='color: #334155; font-weight: 500;'>• {escape(str(prop_info['full_name']))}</span>
<span style='color: #64748b;'>({escape(str(prop_info['unit']))}):</span>
<span style='color: #0369a1; font-weight: bold; font-size: 1.05em;'>{value:.3f}</span>
</div>
""")
            parts.append("</div>")
        # Close the flex-grow column, the flex row and the card.
        parts.append("""
</div>
</div>
</div>
""")
    parts.append("</div>")
    return "".join(parts)
def format_full_predictions_csv(all_predictions, selected_view='Average'):
    """Tabulate one view's predictions for CSV download.

    Args:
        all_predictions: Dict produced by ``predict_properties``, or None.
        selected_view: Key of ``all_predictions['predictions']``.

    Returns:
        pandas.DataFrame with a ``SMILES`` column followed by one
        ``"<full name> (<unit>)"`` column per predicted property, ordered by
        category; None when *all_predictions* is None.
    """
    if all_predictions is None:
        return None
    view_predictions = all_predictions['predictions'][selected_view]
    columns = {'SMILES': all_predictions['original_smiles']}
    for category_props in PROPERTY_CATEGORIES.values():
        for prop in category_props:
            if prop not in view_predictions:
                continue
            info = ALL_PROPERTIES[prop]
            columns[f"{info['full_name']} ({info['unit']})"] = view_predictions[prop]
    return pd.DataFrame(columns)
def create_selectivity_plot_with_images(all_predictions, selected_view='Average', selectivity_pair='CO2/CH4', max_display=30):
    """Plot gas1 permeability vs gas1/gas2 selectivity against the 2008 upper bound.

    Despite the name, no structure images are embedded. Each point carries a
    hover label with the (truncated) SMILES, both permeabilities in Barrer, the
    selectivity, and whether it lies above the upper-bound line.

    Returns:
        A plotly Figure, or None when there is nothing to plot: no predictions,
        unknown pair, missing view, or either gas missing from the log-scale
        predictions.
    """
    if all_predictions is None or selectivity_pair not in SELECTIVITY_BOUNDS:
        return None
    bounds = SELECTIVITY_BOUNDS[selectivity_pair]
    gas1, gas2 = bounds['gases']
    log_view = all_predictions['predictions_log'].get(selected_view)
    if log_view is None:
        return None
    if gas1 not in log_view or gas2 not in log_view:
        return None

    # Only the first max_display molecules are plotted.
    n_shown = min(len(all_predictions['original_smiles']), max_display)
    perm_a = np.maximum(10 ** log_view[gas1][:n_shown], 1e-10)
    perm_b = np.maximum(10 ** log_view[gas2][:n_shown], 1e-10)
    selectivity = perm_a / perm_b

    # The upper bound is a straight line in log-log space: log y = slope * log x + intercept.
    x_start, x_end = bounds['x']
    y_start, y_end = bounds['y']
    lx_start, lx_end = np.log10(x_start), np.log10(x_end)
    ly_start, ly_end = np.log10(y_start), np.log10(y_end)
    slope = (ly_start - ly_end) / (lx_start - lx_end)
    intercept = ly_start - slope * lx_start
    above_bound = np.log10(selectivity) > slope * np.log10(perm_a) + intercept

    labels = []
    for i, smiles in enumerate(all_predictions['original_smiles'][:n_shown]):
        shown = smiles if len(smiles) <= 80 else smiles[:77] + '...'
        labels.append(
            f"SMILES: {shown}<br>"
            f"{gas1}: {perm_a[i]:.3f} Barrer<br>"
            f"{gas2}: {perm_b[i]:.3f} Barrer<br>"
            f"Selectivity: {selectivity[i]:.3f}<br>"
            f"Status: {'Above Bound' if above_bound[i] else 'Below Bound'}"
        )

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=[x_start, x_end],
        y=[y_start, y_end],
        mode='lines',
        name='2008 Upper Bound',
        line=dict(color='red', width=3, dash='dash'),
        hoverinfo='name',
    ))
    point_groups = (
        (above_bound, 'Above Bound', 'green', 10),
        (~above_bound, 'Below Bound', 'blue', 8),
    )
    for mask, group_name, color, marker_size in point_groups:
        if not np.any(mask):
            continue
        fig.add_trace(go.Scatter(
            x=perm_a[mask],
            y=selectivity[mask],
            mode='markers',
            name=group_name,
            marker=dict(color=color, size=marker_size, symbol='circle'),
            text=[label for label, keep in zip(labels, mask) if keep],
            hovertemplate='%{text}<extra></extra>',
        ))

    fig.update_xaxes(title=f"{gas1} Permeability (Barrer)", type="log", gridcolor='lightgray')
    fig.update_yaxes(title=f"{gas1}/{gas2} Selectivity", type="log", gridcolor='lightgray')
    fig.update_layout(
        title=f"{gas1}/{gas2} Selectivity Plot (Top {n_shown} molecules)",
        hovermode='closest',
        showlegend=True,
        plot_bgcolor='white',
        height=600,
    )
    return fig
def generate_all_selectivity_plots(all_predictions, selected_view='Average', max_display=30):
    """Build one selectivity figure per gas pair in SELECTIVITY_BOUNDS.

    Returns:
        dict: ``{pair: figure}`` in SELECTIVITY_BOUNDS order, omitting pairs
        for which no figure could be built; ``{}`` when there are no predictions.
    """
    if all_predictions is None:
        return {}
    figures = {}
    for pair in SELECTIVITY_BOUNDS:
        figure = create_selectivity_plot_with_images(all_predictions, selected_view, pair, max_display)
        if figure is None:
            continue
        figures[pair] = figure
    return figures
def create_pca_plot(all_predictions, selected_view='Average', max_display=30):
    """Create a PCA plot of the displayed molecules, with all properties on hover.

    Morgan fingerprints of the first *max_display* molecules are projected onto
    two principal components. Points are coloured by their input position, and
    each hover label lists the (truncated) SMILES plus every predicted property
    of *selected_view*, grouped by category.

    Returns:
        A plotly Figure, or None when there are no predictions, fewer than two
        molecules to show (sklearn's 2-component PCA raises ValueError for a
        single sample), or *selected_view* has no predictions.
    """
    if all_predictions is None:
        return None
    display_count = min(len(all_predictions['original_smiles']), max_display)
    if display_count < 2:
        return None
    predictions = all_predictions['predictions'].get(selected_view)
    if predictions is None:
        return None

    X_fp = smiles_to_fingerprint(all_predictions['standardized_smiles'][:display_count])
    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(X_fp)

    hover_texts = []
    for idx, smiles in enumerate(all_predictions['original_smiles'][:display_count]):
        truncated = smiles if len(smiles) <= 60 else smiles[:57] + '...'
        pieces = [f"SMILES: {truncated}<br>"]
        for category, props in PROPERTY_CATEGORIES.items():
            pieces.append(f"<br><b>{category}:</b><br>")
            for prop in props:
                if prop in predictions:
                    prop_info = ALL_PROPERTIES[prop]
                    value = predictions[prop][idx]
                    pieces.append(f" {prop_info['full_name']}: {value:.3f} {prop_info['unit']}<br>")
        hover_texts.append("".join(pieces))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=X_pca[:, 0],
        y=X_pca[:, 1],
        mode='markers',
        marker=dict(
            size=10,
            color=np.arange(display_count),
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="Molecule #"),
        ),
        text=hover_texts,
        hovertemplate='%{text}<extra></extra>',
    ))
    fig.update_layout(
        title=f"PCA Visualization of Molecular Structures (Top {display_count} molecules)",
        xaxis_title=f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)",
        yaxis_title=f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)",
        hovermode='closest',
        plot_bgcolor='white',
        height=600,
    )
    return fig
# ============= GRADIO INTERFACE =============
# Offer only model families that loaded at least one property model.
available_models = [name for name in all_model_names if PRELOADED_MODELS.get(name)]
if not available_models:
    print("⚠️ WARNING: No models loaded!")
    # Fall back to listing every family so the UI still renders; predictions for
    # these will be reported as unavailable.
    available_models = all_model_names
def _write_csv_tempfile(df):
    """Write *df* to a new ``.csv`` temporary file and return the file's path.

    ``delete=False`` keeps the file on disk after the handle is closed so Gradio
    can serve it as a download; nothing in this module deletes it later.
    """
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv',
                                     encoding='utf-8', newline='') as temp_csv:
        df.to_csv(temp_csv, index=False)
    return temp_csv.name


def _read_smiles_file(file_input):
    """Return the SMILES strings contained in an uploaded file.

    Args:
        file_input: The value gr.File passes to a callback — a path string, or
            an object whose ``.name`` is the uploaded file's path. The file is
            always re-opened by path, so this does not depend on a wrapper
            object still being open or readable.

    Returns:
        list[str]: For ``.csv`` files (extension matched case-insensitively),
        the non-null values of the SMILES column (an exact ``SMILES`` header is
        preferred, otherwise any case-insensitive match). For any other file,
        every non-blank line, stripped.

    Raises:
        ValueError: A CSV file has no SMILES column. Errors from opening,
        decoding or parsing the file propagate unchanged.
    """
    file_path = file_input if isinstance(file_input, str) else file_input.name
    if file_path.lower().endswith('.csv'):
        df = pd.read_csv(file_path, encoding='utf-8-sig')
        if 'SMILES' in df.columns:
            column = 'SMILES'
        else:
            matches = [col for col in df.columns if str(col).strip().lower() == 'smiles']
            if not matches:
                raise ValueError("CSV file must contain a 'SMILES' column")
            column = matches[0]
        return df[column].dropna().astype(str).tolist()
    # 'utf-8-sig' also drops a leading byte-order mark, which would otherwise be
    # glued onto the first SMILES and make it invalid.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        return [line.strip() for line in f if line.strip()]


with gr.Blocks(title="Polymer Property Prediction", theme=gr.themes.Soft()) as iface:
    gr.Markdown("""
<div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 12px; margin-bottom: 20px;">
<h1 style="color: white; margin: 0; font-size: 2.5em; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);">🔬 Polymer Property Prediction</h1>
<p style="font-size: 1.2em; color: #f0f0f0; margin: 10px 0 0 0; text-shadow: 1px 1px 2px rgba(0,0,0,0.3);">Predict electronic, dielectric & optical, thermal, physical & thermodynamic, and gas permeability properties</p>
<div style="margin-top: 15px;">
<a href="https://github.com/liugangcode/torch-molecule" target="_blank"
style="color: #fff; text-decoration: none; background: rgba(255,255,255,0.2); padding: 10px 20px; border-radius: 20px; font-weight: 500; backdrop-filter: blur(10px);">
💻 Powered by torch-molecule & sklearn
</a>
</div>
</div>
""")
    # Input section - more compact layout.
    # NOTE(review): the original indentation was lost; placing the predict button and
    # status box in the right-hand column is a reconstruction of the intended layout.
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 📝 Input SMILES")
            smiles_text = gr.Textbox(
                label="Enter SMILES (one per line)",
                placeholder="Enter polymer SMILES strings, one per line...",
                lines=8,
                value=DEFAULT_SMILES
            )
            smiles_file = gr.File(
                label="Or upload a file (.txt, .csv, .smi)",
                file_types=[".txt", ".csv", ".smi"]
            )
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Model Selection")
            model_selector = gr.CheckboxGroup(
                choices=available_models,
                label="Select Models for Ensemble Prediction",
                value=[available_models[0]] if available_models else [],
                info="Choose one or more models. Multiple models will be averaged for robust predictions."
            )
            predict_btn = gr.Button("🔮 Predict Properties", variant="primary", size="lg")
            prediction_status = gr.Textbox(label="📊 Status", lines=4, show_label=True)

    # Results section
    gr.Markdown("""
<div style='margin-top: 30px; padding: 15px; background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); border-radius: 10px;'>
<h2 style='color: white; margin: 0; text-shadow: 1px 1px 2px rgba(0,0,0,0.2);'>📊 Prediction Results</h2>
<p style='color: #fef3c7; margin: 5px 0 0 0; font-size: 0.9em;'>Results include: Electronic (bandgap, ionization energy), Dielectric & Optical (refractive index), Thermal (Tg, Tm, conductivity), Physical (density, FFV, radius of gyration), and Gas Permeability properties</p>
</div>
""")
    with gr.Row():
        view_selector = gr.Radio(
            choices=['Average'],
            label="Select Model View",
            value='Average',
            visible=False
        )
        category_selector = gr.Radio(
            choices=['All'] + list(PROPERTY_CATEGORIES.keys()),
            label="Select Property Category",
            value='All',
            info="Choose a category to view specific properties"
        )
    results_html = gr.HTML(label="Results with Molecular Structures")
    download_btn = gr.DownloadButton(
        label="📥 Download All Results (CSV)",
        visible=False
    )

    # Visualization section
    gr.Markdown("""
<div style='margin-top: 30px; padding: 15px; background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); border-radius: 10px;'>
<h2 style='color: white; margin: 0; text-shadow: 1px 1px 2px rgba(0,0,0,0.2);'>📈 Interactive Visualizations</h2>
<p style='color: #f0f9ff; margin: 5px 0 0 0; font-size: 0.95em;'>Limited to top 30 molecules for performance</p>
</div>
""")
    with gr.Tabs():
        with gr.TabItem("🎯 Gas Selectivity Analysis"):
            gr.Markdown("""
<div style='background: #dbeafe; border-left: 4px solid #2563eb; padding: 12px; border-radius: 6px; margin-bottom: 15px;'>
<p style='margin: 0; color: #1e40af; font-weight: 500;'>
Analyze gas separation performance against the 2008 Robeson upper bounds
</p>
</div>
""")
            plot_selectivity_btn = gr.Button("📊 Generate All Selectivity Plots", variant="secondary", size="lg")
            with gr.Row():
                selectivity_pair_selector = gr.Radio(
                    choices=list(SELECTIVITY_BOUNDS.keys()),
                    label="Select Gas Pair to View",
                    value='CO2/CH4',
                    visible=False
                )
            selectivity_plot = gr.Plot(label="Selectivity Plot")
        with gr.TabItem("🗺️ PCA Visualization"):
            gr.Markdown("""
<div style='background: #dbeafe; border-left: 4px solid #2563eb; padding: 12px; border-radius: 6px; margin-bottom: 15px;'>
<p style='margin: 0; color: #1e40af; font-weight: 500;'>
Explore the chemical space using PCA dimensionality reduction on molecular fingerprints
</p>
</div>
""")
            plot_pca_btn = gr.Button("📊 Generate PCA Plot", variant="secondary")
            pca_plot = gr.Plot(label="PCA Plot")

    # Hidden state: the dict returned by predict_properties for the last run.
    all_predictions_state = gr.State(None)

    # Event handlers
    def on_predict(text_input, file_input, selected_models, progress=gr.Progress()):
        """Collect SMILES from the textbox and/or uploaded file and run predictions.

        ``progress`` is declared as a gr.Progress default so Gradio tracks this
        event's progress; it is forwarded to predict_properties.

        Returns the 6 outputs wired below: prediction state, gallery HTML,
        status text, view selector, category selector and download button.
        """
        def error_outputs(message):
            # Clear results, hide the view selector and download button, reset category.
            return (None, "", message, gr.Radio(visible=False),
                    gr.Radio(visible=True, value='All'), gr.DownloadButton(visible=False))

        smiles_list = []
        if text_input and text_input.strip():
            smiles_list.extend(line.strip() for line in text_input.strip().split('\n') if line.strip())
        if file_input is not None:
            try:
                smiles_list.extend(_read_smiles_file(file_input))
            except Exception as e:
                return error_outputs(f"❌ Error reading file: {str(e)}")
        if not smiles_list:
            return error_outputs("❌ Please provide SMILES strings.")

        # Remove duplicates while keeping first-seen order.
        unique_smiles = list(dict.fromkeys(smiles_list))
        all_predictions, report = predict_properties(unique_smiles, selected_models, progress=progress)
        if all_predictions is None:
            return error_outputs(report)

        results_gallery = create_results_gallery_html(all_predictions, 'Average', 'All', max_display=30)
        csv_path = _write_csv_tempfile(format_full_predictions_csv(all_predictions, 'Average'))
        view_options = ['Average'] + [m for m in selected_models if m in all_predictions['predictions']]
        return (
            all_predictions,
            results_gallery,
            report,
            gr.Radio(choices=view_options, value='Average', visible=True),
            gr.Radio(choices=['All'] + list(PROPERTY_CATEGORIES.keys()), value='All', visible=True),
            gr.DownloadButton(label="📥 Download All Results (CSV)", value=csv_path, visible=True),
        )

    def on_view_or_category_change(all_predictions, selected_view, selected_category):
        """Re-render the gallery and regenerate the CSV for the chosen view/category."""
        if all_predictions is None:
            return "", gr.DownloadButton(visible=False)
        results_gallery = create_results_gallery_html(all_predictions, selected_view, selected_category, max_display=30)
        if selected_view not in all_predictions['predictions']:
            # The gallery already explains the problem; there is nothing to export.
            return results_gallery, gr.DownloadButton(visible=False)
        csv_path = _write_csv_tempfile(format_full_predictions_csv(all_predictions, selected_view))
        return results_gallery, gr.DownloadButton(
            label=f"📥 Download {selected_view} Results (CSV)",
            value=csv_path,
            visible=True
        )

    def on_plot_selectivity(all_predictions, selected_view):
        """Generate all selectivity plots at once and show the first one."""
        if all_predictions is None:
            return {}, gr.Radio(visible=False), None
        all_plots = generate_all_selectivity_plots(all_predictions, selected_view, max_display=30)
        if not all_plots:
            return {}, gr.Radio(visible=False), None
        # Prefer CO2/CH4 as the initial plot when it could be built.
        first_pair = 'CO2/CH4' if 'CO2/CH4' in all_plots else next(iter(all_plots))
        return (
            all_plots,
            gr.Radio(choices=list(all_plots.keys()), value=first_pair, visible=True),
            all_plots[first_pair],
        )

    def on_selectivity_pair_change(all_plots_dict, selected_pair):
        """Switch between pre-generated selectivity plots."""
        if not all_plots_dict or selected_pair not in all_plots_dict:
            return None
        return all_plots_dict[selected_pair]

    def on_plot_pca(all_predictions, selected_view):
        """Build the PCA plot for the current predictions (None when unavailable)."""
        if all_predictions is None:
            return None
        return create_pca_plot(all_predictions, selected_view, max_display=30)

    # Hidden state for storing all selectivity plots ({pair: figure}).
    all_selectivity_plots_state = gr.State({})

    # Connect events
    predict_btn.click(
        on_predict,
        inputs=[smiles_text, smiles_file, model_selector],
        outputs=[all_predictions_state, results_html, prediction_status, view_selector, category_selector, download_btn]
    )
    view_selector.change(
        on_view_or_category_change,
        inputs=[all_predictions_state, view_selector, category_selector],
        outputs=[results_html, download_btn]
    )
    category_selector.change(
        on_view_or_category_change,
        inputs=[all_predictions_state, view_selector, category_selector],
        outputs=[results_html, download_btn]
    )
    plot_selectivity_btn.click(
        on_plot_selectivity,
        inputs=[all_predictions_state, view_selector],
        outputs=[all_selectivity_plots_state, selectivity_pair_selector, selectivity_plot]
    )
    selectivity_pair_selector.change(
        on_selectivity_pair_change,
        inputs=[all_selectivity_plots_state, selectivity_pair_selector],
        outputs=[selectivity_plot]
    )
    plot_pca_btn.click(
        on_plot_pca,
        inputs=[all_predictions_state, view_selector],
        outputs=[pca_plot]
    )
if __name__ == "__main__":
    # share=True asks Gradio for a temporary public share link in addition to the
    # local server. NOTE(review): on Hugging Face Spaces this flag is presumably
    # ignored, since the Space is already served publicly — confirm.
    iface.launch(share=True)