"""Streamlit demo: predict and display distributed charges for one molecule.

The script builds a 3D conformer for a hard-coded SMILES string with RDKit,
generates a grid of points around it, runs the DCM1 and DCM2 message-passing
models through ``plot_model`` and embeds the resulting structures as HTML.
"""
import functools

try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO

import ase
import e3x
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from ase.visualize import view
from flax import linen as nn
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from scipy.spatial.distance import cdist

from dcmnet.modules import MessagePassingModel
from dcmnet.utils import clip_colors, apply_model
from dcmnet.data import prepare_batches
from dcmnet.plotting import plot_model

RANDOM_NUMBER = 0
filename = "test"
data_key, train_key = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)

# Model hyperparameters.
features = 16
max_degree = 2
num_iterations = 2
num_basis_functions = 8
cutoff = 4.0

# Both models share every hyperparameter except ``n_dcm``.
_model_kwargs = dict(
    features=features,
    max_degree=max_degree,
    num_iterations=num_iterations,
    num_basis_functions=num_basis_functions,
    cutoff=cutoff,
)

# Create models
DCM1 = MessagePassingModel(n_dcm=1, **_model_kwargs)
DCM2 = MessagePassingModel(n_dcm=2, **_model_kwargs)


def get_grid_points(coordinates, padding=3.0, n_points=15, min_distance=2.5 - 1e-1):
    """Create a uniform grid of points around a molecule.

    The grid spans the axis-aligned bounding box of ``coordinates`` enlarged
    by ``padding`` on every side. Grid points closer than ``min_distance`` to
    any atom are discarded.

    :param coordinates: array-like of shape (n_atoms, 3) with atom positions.
    :param padding: margin added below the minimum and above the maximum
        coordinate along each axis.
    :param n_points: number of grid points per axis (before filtering).
    :param min_distance: minimum allowed distance between a kept grid point
        and every atom.
    :return: array of shape (n_kept, 3) with the remaining grid points.
    """
    bounds = np.array([np.min(coordinates, axis=0), np.max(coordinates, axis=0)])
    # Row 0 (minimum) is shifted down, row 1 (maximum) is shifted up.
    bounds = bounds + np.array([-1, 1])[:, None] * padding
    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(bounds[0], bounds[1])]
    grid_points = np.stack(np.meshgrid(*axes), axis=0)
    grid_points = np.reshape(grid_points.T, [-1, 3])
    # Exclude points that are too close to the molecule.
    far_enough = np.all(cdist(grid_points, coordinates) >= min_distance, axis=-1)
    return grid_points[far_enough]


def _show_atoms_html(structure, width=1000, height=1000):
    """Write ``structure`` in "html" format and embed the result in the page."""
    buffer = StringIO()
    structure.write(buffer, format="html")
    components.html(buffer.getvalue(), width=width, height=height)


# NOTE: read_pickle unpickles arbitrary objects; only load weight files that
# come from a trusted source.
dcm1_weights = pd.read_pickle("wbs/best_0.0_params.pkl")
dcm2_weights = pd.read_pickle("wbs/dcm2-best_1000.0_params.pkl")

smiles = 'C1NCCCC1'
smiles_mol = Chem.MolFromSmiles(smiles)
rdkit_mol = Chem.AddHs(smiles_mol)
elements = [a.GetSymbol() for a in rdkit_mol.GetAtoms()]

# Generate a conformation. EmbedMolecule reports failure by returning -1, so
# fail here with a clear message instead of later in GetConformer.
conformer_id = AllChem.EmbedMolecule(rdkit_mol)
if conformer_id < 0:
    raise ValueError(
        "RDKit could not embed a 3D conformer for SMILES {!r}".format(smiles)
    )
coordinates = rdkit_mol.GetConformer(conformer_id).GetPositions()
surface = get_grid_points(coordinates)

# Label each heavy atom in the 2D drawing with its index via the "atomNote"
# property.
for atom in smiles_mol.GetAtoms():
    atom.SetProp("atomNote", str(atom.GetIdx()))

smiles_image = Draw.MolToImage(smiles_mol)
# display molecule
st.image(smiles_image)

vdw_surface = surface
# Fixed array sizes that every input is padded to.
# NOTE(review): presumably the sizes the weights were trained with - confirm.
max_N_atoms = 60
max_grid_points = 3143

# np.pad would reject a negative pad width with an unhelpful message, so
# check the sizes explicitly first.
if len(elements) > max_N_atoms:
    raise ValueError(
        "molecule has {} atoms but at most {} are supported".format(
            len(elements), max_N_atoms
        )
    )
if len(vdw_surface) > max_grid_points:
    raise ValueError(
        "grid has {} points but at most {} are supported".format(
            len(vdw_surface), max_grid_points
        )
    )

# Accept either numeric strings or element symbols in ``elements``.
try:
    Z = [np.array([int(symbol) for symbol in elements])]
except ValueError:
    Z = [
        np.array(
            [ase.data.atomic_numbers[symbol.capitalize()] for symbol in elements]
        )
    ]

pad_Z = np.array([np.pad(Z[0], ((0, max_N_atoms - len(Z[0]))))])
pad_coords = np.array(
    [np.pad(coordinates, ((0, max_N_atoms - len(coordinates)), (0, 0)))]
)
# Rows appended after the real grid points are filled with 10000.
padded_surface = np.pad(
    vdw_surface,
    ((0, max_grid_points - len(vdw_surface)), (0, 0)),
    "constant",
    constant_values=(0, 10000),
)
pad_vdw_surface = np.array([padded_surface])

data_batch = dict(
    atomic_numbers=jnp.asarray(pad_Z),
    positions=jnp.asarray(pad_coords),
    # NOTE(review): "mono" is filled with the padded atomic numbers and "esp"
    # with zeros; these look like placeholders for inference - confirm.
    mono=jnp.asarray(pad_Z),
    ngrid=jnp.array([len(vdw_surface)]),
    esp=jnp.asarray([np.zeros(max_grid_points)]),
    vdw_surface=jnp.asarray(pad_vdw_surface),
)

batch_size = 1
psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)

batchID = 0
errors_train = []
batch = psi4_test_batches[batchID]

dcm1results = plot_model(DCM1, dcm1_weights, batch, batch_size, 1, plot=False)
dcm2results = plot_model(DCM2, dcm2_weights, batch, batch_size, 2, plot=False)

atoms = dcm1results["atoms"]
dcmol = dcm1results["dcmol"]
dcmol2 = dcm2results["dcmol"]

st.write("Click M to see the distributed charges")

# One viewer per model: DCM1 charges first, then DCM2.
_show_atoms_html(atoms + dcmol)
_show_atoms_html(atoms + dcmol2)