|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import os |
|
|
import tempfile |
|
|
import json |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from pathlib import Path |
|
|
import pickle |
|
|
import joblib |
|
|
from collections import defaultdict |
|
|
|
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import Draw, AllChem |
|
|
|
|
|
import plotly.graph_objects as go |
|
|
import plotly.express as px |
|
|
from sklearn.decomposition import PCA |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# torch_molecule is an optional dependency.  The GREA / GCN / GIN predictors are
# built from its classes, so when the import fails we only record that fact in
# TORCH_MOLECULE_AVAILABLE; model loading consults the flag and skips those
# model families instead of crashing at start-up.
try:
    from torch_molecule import GREAMolecularPredictor, GNNMolecularPredictor
    import torch_molecule
    TORCH_MOLECULE_AVAILABLE = True
    print('torch_molecule version: ', torch_molecule.__version__)
except ImportError:
    TORCH_MOLECULE_AVAILABLE = False
    print("Warning: torch_molecule not available. Some models may not work.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the property metadata table.  Each entry maps a property abbreviation to
# a dict that carries at least a 'category' key (and, as used further down the
# file, 'full_name' and 'unit').  JSON is UTF-8 by specification, so the
# encoding is stated explicitly rather than left to the platform locale.
with open('property_mapping.json', 'r', encoding='utf-8') as f:
    PROPERTY_MAPPING = json.load(f)

# ALL_PROPERTIES: abbreviation -> metadata for every property the app predicts.
# PROPERTY_CATEGORIES: category name -> list of abbreviations in that category.
# Entries in the 'Gas Transport Properties' category are deliberately excluded.
ALL_PROPERTIES = {}
PROPERTY_CATEGORIES = defaultdict(list)

for prop_abbr, prop_info in PROPERTY_MAPPING.items():
    if prop_info['category'] != 'Gas Transport Properties':
        ALL_PROPERTIES[prop_abbr] = prop_info
        PROPERTY_CATEGORIES[prop_info['category']].append(prop_abbr)

# Abbreviations of the properties in the 'Permeability Properties' category.
PERMEABILITY_PROPERTIES = [
    abbr
    for abbr, info in ALL_PROPERTIES.items()
    if info['category'] == 'Permeability Properties'
]

print(f"Loaded {len(ALL_PROPERTIES)} properties (excluding Gas Transport Properties)")
print(f"Categories: {list(PROPERTY_CATEGORIES.keys())}")
|
|
|
|
|
# Model families the app tries to load.  'GREA', 'GCN' and 'GIN' are loaded
# through torch_molecule; the remaining names are loaded with joblib.
all_model_names = ['GREA', 'GCN', 'GIN', 'RandomForest', 'GaussianProcess']

# When True, raw predictions for 'Permeability Properties' are treated as
# log10 values and converted back with 10**x before being reported.
TRAIN_IN_LOG = True

# HuggingFace Hub repository from which the model checkpoints are downloaded.
HF_REPO_ID = "liuganghuggingface/polymer-prediction-gas-models"

# Example input pre-filled in the SMILES textbox (one SMILES per line).
DEFAULT_SMILES = """*c1cc2c(cc1*)C1(C(C)C)c3ccccc3C2(C(C)C)c2cc3c(cc21)Oc1cc2nc(*)c(*)nc2cc1O3
*CN1CN(*)Cc2cc3c(cc21)C1c2ccccc2C3c2cc(*)c(*)cc21
*C(=C(*)c1ccc2c(c1)C(C)(C)C(C)(C)C2(C)C)c1ccccc1"""
|
|
|
|
|
|
|
|
# Endpoints of the upper-bound line drawn on each selectivity plot, keyed by
# gas pair.  For each pair:
#   'x'     - permeability of the first gas at the two endpoints of the line
#   'y'     - selectivity at those endpoints, written as x / <divisor>
#             (the divisor is presumably the second gas's permeability at the
#             same point -- TODO confirm against the source of these values)
#   'gases' - (first gas, second gas); used as the property keys when the
#             plot code looks up predictions
SELECTIVITY_BOUNDS = {
    'CO2/CH4': {
        'x': [1.00E+05, 1.00E-02],
        'y': [1.00E+05/2.21E+04, 1.00E-02/4.88E-06],
        'gases': ('CO2', 'CH4')
    },
    'H2/CH4': {
        'x': [5.00E+04, 2.50E+00],
        'y': [5.00E+04/8.67E+04, 2.50E+00/5.64E-04],
        'gases': ('H2', 'CH4')
    },
    'O2/N2': {
        'x': [5.00E+04, 1.00E-03],
        'y': [5.00E+04/2.78E+04, 1.00E-03/2.43E-05],
        'gases': ('O2', 'N2')
    },
    'H2/N2': {
        'x': [1.00E+05, 1.00E-01],
        'y': [1.00E+05/1.02E+05, 1.00E-01/9.21E-06],
        'gases': ('H2', 'N2')
    },
    'CO2/N2': {
        'x': [1.00E+06, 1.00E-04],
        'y': [1.00E+06/3.05E+05, 1.00E-04/1.05E-08],
        'gases': ('CO2', 'N2')
    }
}
|
|
|
|
|
|
|
|
|
|
|
def smiles_to_mol_image(smiles, img_size=(200, 200)):
    """Render a SMILES string as a PNG ``data:`` URI.

    Args:
        smiles: SMILES string to draw.
        img_size: ``(width, height)`` passed to the RDKit drawing call.

    Returns:
        A ``data:image/png;base64,...`` string, or ``None`` when RDKit cannot
        parse the SMILES.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None

    png_buffer = BytesIO()
    Draw.MolToImage(molecule, size=img_size).save(png_buffer, format="PNG")
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + encoded_png
|
|
|
|
|
def validate_smiles(smiles_list):
    """Split SMILES strings into parseable and unparseable entries.

    Blank entries are skipped.  Positions refer to the index in the input
    list, so skipped blanks still count towards the reported line number.

    Returns:
        ``(valid, invalid, report)`` where ``valid`` holds
        ``(index, stripped_smiles, canonical_smiles)`` tuples, ``invalid``
        holds ``(index, stripped_smiles)`` tuples and ``report`` is a
        Markdown-flavoured summary string.
    """
    valid_smiles, invalid_smiles = [], []

    for position, raw_entry in enumerate(smiles_list):
        candidate = raw_entry.strip()
        if not candidate:
            continue

        parsed = Chem.MolFromSmiles(candidate)
        if parsed is None:
            invalid_smiles.append((position, candidate))
            continue

        canonical = Chem.MolToSmiles(parsed, isomericSmiles=True)
        valid_smiles.append((position, candidate, canonical))

    report_parts = [
        f"✅ Valid SMILES: {len(valid_smiles)}\n",
        f"❌ Invalid SMILES: {len(invalid_smiles)}\n",
    ]
    if invalid_smiles:
        report_parts.append("\n**Invalid SMILES detected:**\n")
        report_parts.extend(
            f" - Line {position + 1}: `{bad_entry}`\n"
            for position, bad_entry in invalid_smiles
        )
        report_parts.append(
            "\n⚠️ **Please remove or correct the invalid SMILES before proceeding.**"
        )

    return valid_smiles, invalid_smiles, "".join(report_parts)
|
|
|
|
|
def smiles_to_fingerprint(smiles_list, n_bits=2048):
    """Convert SMILES strings to Morgan (radius 2) bit-vector fingerprints.

    Args:
        smiles_list: Iterable of SMILES strings.
        n_bits: Length of each fingerprint.

    Returns:
        A 2-D array with one row of ``n_bits`` entries per input SMILES.
        SMILES that RDKit cannot parse yield an all-zero row.  An empty input
        yields an array of shape ``(0, n_bits)`` so callers can always rely on
        the second dimension.
    """
    fingerprints = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Keep row alignment with the input even for unparseable entries.
            fingerprints.append(np.zeros(n_bits))
        else:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=n_bits)
            fingerprints.append(np.array(fp))

    if not fingerprints:
        # np.array([]) would have shape (0,), losing the feature dimension.
        return np.zeros((0, n_bits))
    return np.array(fingerprints)
|
|
|
|
|
|
|
|
|
|
|
def load_all_models():
    """Download and load every (model family, property) checkpoint.

    For each name in ``all_model_names`` and each property in
    ``ALL_PROPERTIES`` the file ``<model>_<property>.pt`` (torch_molecule
    families) or ``<model>_<property>.pkl`` (all other families) is fetched
    from ``HF_REPO_ID`` and loaded on CPU.  Loading is best-effort: a failure
    for one combination is printed and that combination is left out.

    Returns:
        ``{model_name: {prop_abbr: (model, kind)}}`` where ``kind`` is
        ``'torch_molecule'`` or ``'sklearn'``.  Every model name is present as
        a key, possibly mapping to an empty dict.
    """
    print("Loading models from HuggingFace Hub (CPU only)...")

    # Constructors for the torch_molecule-backed families.  Wrapped in lambdas
    # so the class names are only touched when torch_molecule was imported.
    gnn_builders = {
        'GREA': lambda: GREAMolecularPredictor(device='cpu'),
        'GCN': lambda: GNNMolecularPredictor(gnn_type='gcn-virtual', device='cpu'),
        'GIN': lambda: GNNMolecularPredictor(gnn_type='gin-virtual', device='cpu'),
    }

    loaded_models = {}
    for model_name in all_model_names:
        family = loaded_models[model_name] = {}
        build_gnn = gnn_builders.get(model_name)

        # Without torch_molecule the whole family is unusable; decide once per
        # family instead of once per property.
        if build_gnn is not None and not TORCH_MOLECULE_AVAILABLE:
            continue

        for prop_abbr in ALL_PROPERTIES:
            file_stem = f"{model_name.lower()}_{prop_abbr.lower()}"

            try:
                if build_gnn is not None:
                    filename = f"{file_stem}.pt"
                    print(f" Downloading {filename}...")
                    model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)

                    model = build_gnn()
                    model.load_from_local(model_path)
                    family[prop_abbr] = (model, 'torch_molecule')
                else:
                    filename = f"{file_stem}.pkl"
                    print(f" Downloading {filename}...")
                    model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)

                    # SECURITY: joblib.load unpickles the file and can execute
                    # arbitrary code.  Acceptable only while HF_REPO_ID points
                    # at a repository whose contents are trusted.
                    model = joblib.load(model_path)
                    family[prop_abbr] = (model, 'sklearn')

                print(f" ✓ Loaded {model_name} for {prop_abbr}")

            except Exception as e:
                # Deliberate best-effort: a missing or broken checkpoint must
                # not prevent the remaining models from loading.
                print(f" ❌ Error loading {model_name} for {prop_abbr}: {e}")

    print("Model loading complete!")
    return loaded_models


# Loaded once at import time so every request reuses the same model objects.
PRELOADED_MODELS = load_all_models()
|
|
|
|
|
|
|
|
|
|
|
def predict_properties(smiles_list, selected_models, progress=gr.Progress()):
    """Predict every configured property for a list of SMILES strings.

    Args:
        smiles_list: Raw SMILES strings; blank entries are ignored.
        selected_models: Names of the model families to run.
        progress: Gradio progress tracker used to report status.

    Returns:
        ``(all_predictions, report)``.  ``all_predictions`` is ``None`` when no
        model is selected, when any SMILES is invalid, or when no SMILES is
        given.  Otherwise it is a dict with the keys ``'original_smiles'``,
        ``'standardized_smiles'``, ``'predictions'`` and ``'predictions_log'``;
        the last two map a view name (each selected model plus ``'Average'``)
        to ``{property: array}``.  ``report`` is a human-readable status text.
    """
    if not selected_models:
        return None, "❌ Please select at least one model."

    progress(0.1, desc="Validating SMILES...")
    valid_smiles, invalid_smiles, validation_report = validate_smiles(smiles_list)

    if invalid_smiles:
        return None, validation_report
    if not valid_smiles:
        return None, "❌ No SMILES provided."

    _, original_smiles, standardized_smiles = zip(*valid_smiles)
    original_smiles = list(original_smiles)
    standardized_smiles = list(standardized_smiles)

    all_predictions = {
        'original_smiles': original_smiles,
        'standardized_smiles': standardized_smiles,
        'predictions': {},
        'predictions_log': {}
    }

    # Fingerprints are only needed by the fingerprint-based model families.
    X_fp = None
    if any(name in selected_models for name in ('RandomForest', 'GaussianProcess')):
        progress(0.2, desc="Computing fingerprints...")
        X_fp = smiles_to_fingerprint(standardized_smiles)

    model_errors = []
    available_props = list(ALL_PROPERTIES.keys())
    total_predictions = len(available_props) * len(selected_models)
    pred_count = 0

    for model_name in selected_models:
        model_preds = all_predictions['predictions'][model_name] = {}
        model_preds_log = all_predictions['predictions_log'][model_name] = {}
        family_models = PRELOADED_MODELS.get(model_name, {})

        for prop in available_props:
            progress(0.2 + 0.7 * pred_count / total_predictions,
                     desc=f"Predicting {prop} with {model_name}...")
            pred_count += 1

            if prop not in family_models:
                model_errors.append(f"{model_name} for {prop}")
                continue

            model, model_type = family_models[prop]

            try:
                if model_type == 'torch_molecule':
                    with torch.no_grad():
                        predictions = model.predict(standardized_smiles)['prediction']
                else:
                    predictions = model.predict(X_fp)

                # Coerce to ndarray so list-like outputs are flattened and
                # exponentiated the same way as native arrays.
                predictions = np.asarray(predictions)
                if predictions.ndim > 1:
                    predictions = predictions.flatten()

                is_log_permeability = (
                    ALL_PROPERTIES[prop]['category'] == 'Permeability Properties'
                    and TRAIN_IN_LOG
                )
                if is_log_permeability:
                    # Model output is log10(permeability); report linear values.
                    model_preds[prop] = 10.0 ** predictions
                    model_preds_log[prop] = predictions
                else:
                    model_preds[prop] = predictions
                    # Magnitude-based log with a floor, so zero or negative
                    # values do not produce -inf / NaN.
                    model_preds_log[prop] = np.log10(np.maximum(np.abs(predictions), 1e-10))

            except Exception as e:
                # Best-effort: one failing model/property must not abort the run.
                print(f"Error predicting {model_name} for {prop}: {e}")
                model_errors.append(f"{model_name} for {prop}")

    progress(0.9, desc="Computing averages...")
    all_predictions['predictions']['Average'] = {}
    all_predictions['predictions_log']['Average'] = {}

    for prop in available_props:
        contributors = [
            name for name in selected_models
            if prop in all_predictions['predictions'][name]
        ]
        if not contributors:
            continue

        # Linear and log views are averaged independently of each other.
        for key in ('predictions', 'predictions_log'):
            stacked = [all_predictions[key][name][prop] for name in contributors]
            if len(stacked) > 1:
                all_predictions[key]['Average'][prop] = np.mean(np.array(stacked), axis=0)
            else:
                all_predictions[key]['Average'][prop] = stacked[0]

    report = validation_report + "\n"
    if model_errors:
        report += f"\n⚠️ Some models were not available: {len(set(model_errors))} property-model combinations\n"
    report += f"\n✅ Successfully predicted properties for {len(valid_smiles)} molecules using {len(selected_models)} model(s)."
    report += f"\n📊 Attempted to predict {len(available_props)} properties across all categories."
    report += "\n💻 All predictions run on CPU."

    progress(1.0, desc="Done!")
    return all_predictions, report
|
|
|
|
|
|
|
|
|
|
|
def create_results_gallery_html(all_predictions, selected_view='Average', selected_category='All', max_display=30):
    """Build the HTML gallery of molecules with their predicted properties.

    Args:
        all_predictions: Result dict produced by the prediction step, or
            ``None``.
        selected_view: Key into ``all_predictions['predictions']`` (a model
            name or ``'Average'``).
        selected_category: ``'All'`` or a single property category name.
        max_display: Maximum number of molecules rendered.

    Returns:
        An HTML string.  A short placeholder paragraph is returned when there
        is no data or the requested view does not exist.
    """
    # Imported locally because the name `html` is not imported at file level.
    from html import escape

    if all_predictions is None:
        return "<p>No data available.</p>"

    num_molecules = len(all_predictions['original_smiles'])
    display_count = min(num_molecules, max_display)

    if selected_view not in all_predictions['predictions']:
        return "<p>No predictions available for selected view.</p>"

    predictions = all_predictions['predictions'][selected_view]

    if selected_category == 'All':
        display_categories = PROPERTY_CATEGORIES
    else:
        # .get() instead of indexing: PROPERTY_CATEGORIES is a defaultdict, so
        # indexing an unknown name would insert an empty category globally.
        display_categories = {selected_category: PROPERTY_CATEGORIES.get(selected_category, [])}

    overflow_note = ""
    if num_molecules > max_display:
        overflow_note = (
            "<p style='color: #d97706; font-weight: 500; font-size: 1.05em;'>"
            f"⬇️ Download the CSV file below to see all {num_molecules} results.</p>"
        )

    # Collect fragments and join once instead of repeated string +=.
    parts = [f"""
    <div style='font-family: Arial, sans-serif; color: #000;'>
        <h3 style='color: #1a1a1a; margin-bottom: 10px;'>Prediction Results - {escape(str(selected_category))}</h3>
        <p style='color: #333; font-size: 1.1em; font-weight: 500;'><strong>Showing top {display_count} of {num_molecules} molecules.</strong></p>
        {overflow_note}
    """]

    for idx in range(display_count):
        smiles = all_predictions['original_smiles'][idx]
        mol_img = smiles_to_mol_image(smiles, img_size=(250, 250))

        if mol_img:
            image_html = "<img src='" + mol_img + "' style='width: 250px; height: 250px; display: block;'/>"
        else:
            image_html = "<p style='color: #ef4444; font-weight: 500;'>Invalid structure</p>"

        # SMILES is user input: escape it before embedding it in markup.
        shown_smiles = escape(smiles[:100]) + ("..." if len(smiles) > 100 else "")

        parts.append(f"""
        <div style='border: 2px solid #cbd5e1; border-radius: 12px; padding: 20px; margin: 20px 0; background: #ffffff; box-shadow: 0 2px 8px rgba(0,0,0,0.1);'>
            <div style='display: flex; gap: 25px; align-items: flex-start;'>
                <div style='flex-shrink: 0; background: #f8fafc; padding: 10px; border-radius: 8px;'>
                    {image_html}
                </div>
                <div style='flex-grow: 1;'>
                    <h4 style='color: #0f172a; margin: 0 0 15px 0; font-size: 1.3em; border-bottom: 2px solid #e2e8f0; padding-bottom: 8px;'>Molecule {idx + 1}</h4>
                    <p style='color: #475569; margin-bottom: 15px; background: #f1f5f9; padding: 10px; border-radius: 6px; word-break: break-all;'>
                        <strong style='color: #1e293b;'>SMILES:</strong>
                        <code style='background: #e2e8f0; padding: 4px 8px; border-radius: 4px; font-size: 0.9em; color: #334155;'>{shown_smiles}</code>
                    </p>
        """)

        for category, props in display_categories.items():
            category_props = [p for p in props if p in predictions]
            if not category_props:
                continue

            parts.append(f"""
                    <div style='margin-top: 15px; background: #f8fafc; padding: 12px; border-radius: 8px; border-left: 4px solid #3b82f6;'>
                        <strong style='color: #1e40af; font-size: 1.1em; display: block; margin-bottom: 8px;'>{escape(str(category))}:</strong>
            """)

            for prop in category_props:
                value = predictions[prop][idx]
                # Names and units come from the bundled property mapping and
                # are inserted as-is (they may intentionally contain markup).
                prop_info = ALL_PROPERTIES[prop]
                parts.append(f"""
                        <div style='margin: 6px 0; padding: 6px 10px; background: #ffffff; border-radius: 4px;'>
                            <span style='color: #334155; font-weight: 500;'>• {prop_info['full_name']}</span>
                            <span style='color: #64748b;'>({prop_info['unit']}):</span>
                            <span style='color: #0369a1; font-weight: bold; font-size: 1.05em;'>{value:.3f}</span>
                        </div>
                """)

            parts.append("</div>")

        parts.append("""
                </div>
            </div>
        </div>
        """)

    parts.append("</div>")
    return "".join(parts)
|
|
|
|
|
def format_full_predictions_csv(all_predictions, selected_view='Average'):
    """Assemble the downloadable results table for one view.

    The table has a ``SMILES`` column followed by one column per predicted
    property, ordered by category and titled ``"<full name> (<unit>)"``.

    Returns:
        A ``pandas.DataFrame``, or ``None`` when ``all_predictions`` is
        ``None``.
    """
    if all_predictions is None:
        return None

    table = pd.DataFrame({'SMILES': all_predictions['original_smiles']})
    view_predictions = all_predictions['predictions'][selected_view]

    for category_props in PROPERTY_CATEGORIES.values():
        for prop in category_props:
            if prop not in view_predictions:
                continue
            info = ALL_PROPERTIES[prop]
            column_title = f"{info['full_name']} ({info['unit']})"
            table[column_title] = view_predictions[prop]

    return table
|
|
|
|
|
def create_selectivity_plot_with_images(all_predictions, selected_view='Average', selectivity_pair='CO2/CH4', max_display=30):
    """Plot selectivity against permeability for one gas pair (log-log axes).

    The first ``max_display`` molecules are drawn as markers, coloured by
    whether they lie above or below the upper-bound line defined in
    ``SELECTIVITY_BOUNDS``.  Hover text shows the SMILES and the values.

    Returns:
        A plotly figure, or ``None`` when there are no predictions, the pair
        is unknown, the view is missing, or either gas has no prediction.
    """
    if all_predictions is None or selectivity_pair not in SELECTIVITY_BOUNDS:
        return None

    bound_spec = SELECTIVITY_BOUNDS[selectivity_pair]
    gas1, gas2 = bound_spec['gases']

    log_views = all_predictions['predictions_log']
    if selected_view not in log_views:
        return None

    view_log = log_views[selected_view]
    if not (gas1 in view_log and gas2 in view_log):
        return None

    num_molecules = len(all_predictions['original_smiles'])
    display_count = min(num_molecules, max_display)

    # Back to linear permeability, floored so the ratio and logs stay finite.
    gas1_perm = np.maximum(10**view_log[gas1][:display_count], 1e-10)
    gas2_perm = np.maximum(10**view_log[gas2][:display_count], 1e-10)
    selectivity = gas1_perm / gas2_perm

    x1, x2 = bound_spec['x']
    y1, y2 = bound_spec['y']

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=[x1, x2],
        y=[y1, y2],
        mode='lines',
        name='2008 Upper Bound',
        line=dict(color='red', width=3, dash='dash'),
        hoverinfo='name'
    ))

    # The bound is a straight line in log-log space: log(y) = a * log(x) + b.
    x_log = np.log10(gas1_perm)
    y_log = np.log10(selectivity)
    x1_log, x2_log = np.log10(x1), np.log10(x2)
    y1_log, y2_log = np.log10(y1), np.log10(y2)
    a = (y1_log - y2_log) / (x1_log - x2_log)
    b = y1_log - a * x1_log
    above_bound = y_log > a * x_log + b

    hover_texts = []
    for i, smiles in enumerate(all_predictions['original_smiles'][:display_count]):
        shown = smiles if len(smiles) <= 80 else smiles[:77] + '...'
        status = "Above Bound" if above_bound[i] else "Below Bound"
        hover_texts.append(
            f"SMILES: {shown}<br>"
            f"{gas1}: {gas1_perm[i]:.3f} Barrer<br>"
            f"{gas2}: {gas2_perm[i]:.3f} Barrer<br>"
            f"Selectivity: {selectivity[i]:.3f}<br>"
            f"Status: {status}"
        )

    marker_groups = (
        (above_bound, 'Above Bound', 'green', 10),
        (~above_bound, 'Below Bound', 'blue', 8),
    )
    for mask, label, colour, marker_size in marker_groups:
        if not np.any(mask):
            continue
        fig.add_trace(go.Scatter(
            x=gas1_perm[mask],
            y=selectivity[mask],
            mode='markers',
            name=label,
            marker=dict(color=colour, size=marker_size, symbol='circle'),
            text=[text for text, keep in zip(hover_texts, mask) if keep],
            hovertemplate='%{text}<extra></extra>'
        ))

    fig.update_xaxes(
        title=f"{gas1} Permeability (Barrer)",
        type="log",
        gridcolor='lightgray'
    )
    fig.update_yaxes(
        title=f"{gas1}/{gas2} Selectivity",
        type="log",
        gridcolor='lightgray'
    )
    fig.update_layout(
        title=f"{gas1}/{gas2} Selectivity Plot (Top {display_count} molecules)",
        hovermode='closest',
        showlegend=True,
        plot_bgcolor='white',
        height=600
    )

    return fig
|
|
|
|
|
def generate_all_selectivity_plots(all_predictions, selected_view='Average', max_display=30):
    """Build the selectivity figure for every gas pair in SELECTIVITY_BOUNDS.

    Returns:
        ``{pair_name: figure}`` containing only the pairs for which a figure
        could be produced; an empty dict when ``all_predictions`` is ``None``.
    """
    if all_predictions is None:
        return {}

    figures = (
        (pair, create_selectivity_plot_with_images(all_predictions, selected_view, pair, max_display))
        for pair in SELECTIVITY_BOUNDS
    )
    return {pair: figure for pair, figure in figures if figure is not None}
|
|
|
|
|
def create_pca_plot(all_predictions, selected_view='Average', max_display=30):
    """Project molecular fingerprints to 2-D with PCA and plot them.

    Args:
        all_predictions: Result dict produced by the prediction step, or
            ``None``.
        selected_view: Key into ``all_predictions['predictions']`` whose values
            are listed in the hover text.  An unknown view yields hover text
            without property values.
        max_display: Maximum number of molecules plotted.

    Returns:
        A plotly figure, or ``None`` when ``all_predictions`` is ``None``.
        With fewer than two molecules an empty figure with an explanatory
        title is returned, because a 2-component PCA needs at least two
        samples.
    """
    if all_predictions is None:
        return None

    num_molecules = len(all_predictions['original_smiles'])
    display_count = min(num_molecules, max_display)

    if display_count < 2:
        # PCA(n_components=2) raises ValueError for fewer than 2 samples.
        placeholder = go.Figure()
        placeholder.update_layout(
            title="PCA Visualization requires at least 2 molecules",
            plot_bgcolor='white',
            height=600
        )
        return placeholder

    smiles_list = all_predictions['standardized_smiles'][:display_count]
    X_fp = smiles_to_fingerprint(smiles_list)

    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(X_fp)

    predictions = all_predictions['predictions'].get(selected_view, {})
    hover_texts = []

    for idx in range(display_count):
        smiles = all_predictions['original_smiles'][idx]
        truncated = smiles if len(smiles) <= 60 else smiles[:57] + '...'
        hover_parts = [f"SMILES: {truncated}<br>"]

        for category, props in PROPERTY_CATEGORIES.items():
            hover_parts.append(f"<br><b>{category}:</b><br>")
            for prop in props:
                if prop in predictions:
                    value = predictions[prop][idx]
                    prop_info = ALL_PROPERTIES[prop]
                    hover_parts.append(f" {prop_info['full_name']}: {value:.3f} {prop_info['unit']}<br>")

        hover_texts.append("".join(hover_parts))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=X_pca[:, 0],
        y=X_pca[:, 1],
        mode='markers',
        marker=dict(
            size=10,
            # Colour encodes input order, not a property value.
            color=np.arange(display_count),
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="Molecule #")
        ),
        text=hover_texts,
        hovertemplate='%{text}<extra></extra>'
    ))

    fig.update_layout(
        title=f"PCA Visualization of Molecular Structures (Top {display_count} molecules)",
        xaxis_title=f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)",
        yaxis_title=f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)",
        hovermode='closest',
        plot_bgcolor='white',
        height=600
    )

    return fig
|
|
|
|
|
|
|
|
|
|
|
# Offer only the model families for which at least one checkpoint was loaded.
available_models = [name for name in all_model_names if PRELOADED_MODELS.get(name)]

if not available_models:
    # Nothing loaded: still list every model so the UI can be built.
    print("⚠️ WARNING: No models loaded!")
    available_models = all_model_names
|
|
|
|
|
# Gradio UI: layout, event handlers and event wiring.
# NOTE(review): the nesting of widgets inside rows/columns/tabs below was
# reconstructed from statement order -- confirm it matches the intended layout.
with gr.Blocks(title="Polymer Property Prediction", theme=gr.themes.Soft()) as iface:
    gr.Markdown("""
    <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 12px; margin-bottom: 20px;">
        <h1 style="color: white; margin: 0; font-size: 2.5em; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);">🔬 Polymer Property Prediction</h1>
        <p style="font-size: 1.2em; color: #f0f0f0; margin: 10px 0 0 0; text-shadow: 1px 1px 2px rgba(0,0,0,0.3);">Predict electronic, dielectric & optical, thermal, physical & thermodynamic, and gas permeability properties</p>
        <div style="margin-top: 15px;">
            <a href="https://github.com/liugangcode/torch-molecule" target="_blank"
               style="color: #fff; text-decoration: none; background: rgba(255,255,255,0.2); padding: 10px 20px; border-radius: 20px; font-weight: 500; backdrop-filter: blur(10px);">
                💻 Powered by torch-molecule & sklearn
            </a>
        </div>
    </div>
    """)

    # ---- Inputs -----------------------------------------------------------
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 📝 Input SMILES")
            smiles_text = gr.Textbox(
                label="Enter SMILES (one per line)",
                placeholder="Enter polymer SMILES strings, one per line...",
                lines=8,
                value=DEFAULT_SMILES
            )

            smiles_file = gr.File(
                label="Or upload a file (.txt, .csv, .smi)",
                file_types=[".txt", ".csv", ".smi"]
            )

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Model Selection")
            model_selector = gr.CheckboxGroup(
                choices=available_models,
                label="Select Models for Ensemble Prediction",
                value=[available_models[0]] if available_models else [],
                info="Choose one or more models. Multiple models will be averaged for robust predictions."
            )

    predict_btn = gr.Button("🔮 Predict Properties", variant="primary", size="lg")

    prediction_status = gr.Textbox(label="📊 Status", lines=4, show_label=True)

    # ---- Results ----------------------------------------------------------
    gr.Markdown("""
    <div style='margin-top: 30px; padding: 15px; background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); border-radius: 10px;'>
        <h2 style='color: white; margin: 0; text-shadow: 1px 1px 2px rgba(0,0,0,0.2);'>📊 Prediction Results</h2>
        <p style='color: #fef3c7; margin: 5px 0 0 0; font-size: 0.9em;'>Results include: Electronic (bandgap, ionization energy), Dielectric & Optical (refractive index), Thermal (Tg, Tm, conductivity), Physical (density, FFV, radius of gyration), and Gas Permeability properties</p>
    </div>
    """)

    with gr.Row():
        # Hidden until a prediction has produced per-model views.
        view_selector = gr.Radio(
            choices=['Average'],
            label="Select Model View",
            value='Average',
            visible=False
        )

        category_selector = gr.Radio(
            choices=['All'] + list(PROPERTY_CATEGORIES.keys()),
            label="Select Property Category",
            value='All',
            info="Choose a category to view specific properties"
        )

    results_html = gr.HTML(label="Results with Molecular Structures")

    download_btn = gr.DownloadButton(
        label="📥 Download All Results (CSV)",
        visible=False
    )

    # ---- Visualizations ---------------------------------------------------
    gr.Markdown("""
    <div style='margin-top: 30px; padding: 15px; background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); border-radius: 10px;'>
        <h2 style='color: white; margin: 0; text-shadow: 1px 1px 2px rgba(0,0,0,0.2);'>📈 Interactive Visualizations</h2>
        <p style='color: #f0f9ff; margin: 5px 0 0 0; font-size: 0.95em;'>Limited to top 30 molecules for performance</p>
    </div>
    """)

    with gr.Tabs():
        with gr.TabItem("🎯 Gas Selectivity Analysis"):
            gr.Markdown("""
            <div style='background: #dbeafe; border-left: 4px solid #2563eb; padding: 12px; border-radius: 6px; margin-bottom: 15px;'>
                <p style='margin: 0; color: #1e40af; font-weight: 500;'>
                    Analyze gas separation performance against the 2008 Robeson upper bounds
                </p>
            </div>
            """)

            plot_selectivity_btn = gr.Button("📊 Generate All Selectivity Plots", variant="secondary", size="lg")

            with gr.Row():
                # Hidden until plots have been generated.
                selectivity_pair_selector = gr.Radio(
                    choices=list(SELECTIVITY_BOUNDS.keys()),
                    label="Select Gas Pair to View",
                    value='CO2/CH4',
                    visible=False
                )

            selectivity_plot = gr.Plot(label="Selectivity Plot")

        with gr.TabItem("🗺️ PCA Visualization"):
            gr.Markdown("""
            <div style='background: #dbeafe; border-left: 4px solid #2563eb; padding: 12px; border-radius: 6px; margin-bottom: 15px;'>
                <p style='margin: 0; color: #1e40af; font-weight: 500;'>
                    Explore the chemical space using PCA dimensionality reduction on molecular fingerprints
                </p>
            </div>
            """)

            plot_pca_btn = gr.Button("📊 Generate PCA Plot", variant="secondary")
            pca_plot = gr.Plot(label="PCA Plot")

    # Per-session state: the latest prediction dict and the generated figures.
    all_predictions_state = gr.State(None)
    all_selectivity_plots_state = gr.State({})

    # ---- Helpers shared by the handlers -----------------------------------
    def _failed_prediction_outputs(message):
        """Return the on_predict outputs for a request that produced no results."""
        return (
            None,
            "",
            message,
            gr.Radio(visible=False),
            gr.Radio(visible=True, value='All'),
            gr.DownloadButton(visible=False),
        )

    def _export_results_csv(all_predictions, selected_view):
        """Write one view to a temporary CSV file and return the file path.

        The file is intentionally not deleted: the download button serves it
        after this function has returned.
        """
        df_full = format_full_predictions_csv(all_predictions, selected_view)
        # Write through the open handle; re-opening the file by name while the
        # handle is still open fails on some platforms.
        with tempfile.NamedTemporaryFile(
            mode='w', delete=False, suffix='.csv', newline='', encoding='utf-8'
        ) as temp_csv:
            df_full.to_csv(temp_csv, index=False)
        return temp_csv.name

    def _read_uploaded_smiles(file_input):
        """Return the SMILES strings found in an uploaded file.

        ``.csv`` files contribute their ``SMILES`` column (nothing when the
        column is absent); any other file contributes its non-blank lines.
        """
        # The upload arrives either as a path string or as an object exposing
        # the path through ``.name``; always read from the path.
        file_path = file_input if isinstance(file_input, str) else file_input.name

        if file_path.lower().endswith('.csv'):
            df = pd.read_csv(file_path)
            if 'SMILES' not in df.columns:
                return []
            return df['SMILES'].dropna().astype(str).tolist()

        with open(file_path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    # ---- Event handlers ---------------------------------------------------
    def on_predict(text_input, file_input, selected_models, progress=gr.Progress()):
        """Collect SMILES from textbox and upload, run predictions, fill the UI.

        ``progress`` is declared here so Gradio binds the tracker to this
        event; it is forwarded to the prediction routine.
        """
        smiles_list = []
        if text_input and text_input.strip():
            smiles_list.extend(
                line.strip() for line in text_input.strip().split('\n') if line.strip()
            )

        if file_input is not None:
            try:
                smiles_list.extend(_read_uploaded_smiles(file_input))
            except Exception as e:
                return _failed_prediction_outputs(f"❌ Error reading file: {str(e)}")

        if not smiles_list:
            return _failed_prediction_outputs("❌ Please provide SMILES strings.")

        # Drop duplicates while keeping the first-seen order.
        unique_smiles = list(dict.fromkeys(smiles_list))

        all_predictions, report = predict_properties(unique_smiles, selected_models, progress=progress)
        if all_predictions is None:
            return _failed_prediction_outputs(report)

        results_gallery = create_results_gallery_html(all_predictions, 'Average', 'All', max_display=30)
        csv_path = _export_results_csv(all_predictions, 'Average')

        view_options = ['Average'] + [m for m in selected_models if m in all_predictions['predictions']]
        view_selector_update = gr.Radio(
            choices=view_options,
            value='Average',
            visible=True
        )
        category_selector_update = gr.Radio(
            choices=['All'] + list(PROPERTY_CATEGORIES.keys()),
            value='All',
            visible=True
        )

        return (
            all_predictions,
            results_gallery,
            report,
            view_selector_update,
            category_selector_update,
            gr.DownloadButton(
                label="📥 Download All Results (CSV)",
                value=csv_path,
                visible=True
            )
        )

    def on_view_or_category_change(all_predictions, selected_view, selected_category):
        """Re-render the gallery and refresh the CSV for the chosen view."""
        if all_predictions is None:
            return "", gr.DownloadButton(visible=False)

        results_gallery = create_results_gallery_html(all_predictions, selected_view, selected_category, max_display=30)
        csv_path = _export_results_csv(all_predictions, selected_view)

        return results_gallery, gr.DownloadButton(
            label=f"📥 Download {selected_view} Results (CSV)",
            value=csv_path,
            visible=True
        )

    def on_plot_selectivity(all_predictions, selected_view):
        """Generate all selectivity plots at once and show the first one."""
        if all_predictions is None:
            return {}, gr.Radio(visible=False), None

        all_plots = generate_all_selectivity_plots(all_predictions, selected_view, max_display=30)
        if not all_plots:
            return {}, gr.Radio(visible=False), None

        first_pair = 'CO2/CH4' if 'CO2/CH4' in all_plots else next(iter(all_plots))

        return (
            all_plots,
            gr.Radio(
                choices=list(all_plots.keys()),
                value=first_pair,
                visible=True
            ),
            all_plots[first_pair]
        )

    def on_selectivity_pair_change(all_plots_dict, selected_pair):
        """Switch between pre-generated selectivity plots."""
        if not all_plots_dict or selected_pair not in all_plots_dict:
            return None
        return all_plots_dict[selected_pair]

    def on_plot_pca(all_predictions, selected_view):
        """Build the PCA figure for the current predictions, if any."""
        if all_predictions is None:
            return None
        return create_pca_plot(all_predictions, selected_view, max_display=30)

    # ---- Event wiring -----------------------------------------------------
    predict_btn.click(
        on_predict,
        inputs=[smiles_text, smiles_file, model_selector],
        outputs=[all_predictions_state, results_html, prediction_status, view_selector, category_selector, download_btn]
    )

    view_selector.change(
        on_view_or_category_change,
        inputs=[all_predictions_state, view_selector, category_selector],
        outputs=[results_html, download_btn]
    )

    category_selector.change(
        on_view_or_category_change,
        inputs=[all_predictions_state, view_selector, category_selector],
        outputs=[results_html, download_btn]
    )

    plot_selectivity_btn.click(
        on_plot_selectivity,
        inputs=[all_predictions_state, view_selector],
        outputs=[all_selectivity_plots_state, selectivity_pair_selector, selectivity_plot]
    )

    selectivity_pair_selector.change(
        on_selectivity_pair_change,
        inputs=[all_selectivity_plots_state, selectivity_pair_selector],
        outputs=[selectivity_plot]
    )

    plot_pca_btn.click(
        on_plot_pca,
        inputs=[all_predictions_state, view_selector],
        outputs=[pca_plot]
    )
|
|
|
|
|
# Start the Gradio server only when the file is executed as a script.
# share=True is passed to launch(), which asks Gradio for a public share link.
if __name__ == "__main__":
    iface.launch(share=True)
|
|
|
|
|
|