|
import streamlit as st |
|
|
|
from ase.visualize import view |
|
|
|
try: |
|
from StringIO import StringIO |
|
except ImportError: |
|
from io import StringIO |
|
|
|
import streamlit.components.v1 as components |
|
from scipy.spatial.distance import cdist |
|
|
|
import ase |
|
|
|
import functools |
|
import e3x |
|
from flax import linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import optax |
|
|
|
|
|
import pandas as pd |
|
from dcmnet.modules import MessagePassingModel |
|
from dcmnet.utils import clip_colors, apply_model |
|
from dcmnet.data import prepare_batches |
|
from dcmnet.plotting import plot_model |
|
|
|
# Seed for the PRNG keys split below (data_key is consumed by prepare_batches).
RANDOM_NUMBER = 0
filename = "test"
data_key, train_key = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)

# Hyperparameters shared by both message-passing models; the two models
# differ only in their ``n_dcm`` argument.
features = 16
max_degree = 2
num_iterations = 2
num_basis_functions = 8
cutoff = 4.0

# NOTE(review): n_dcm presumably is the number of distributed charges per
# atom (the app later shows "distributed charges") -- confirm in dcmnet.
DCM1, DCM2 = (
    MessagePassingModel(
        features=features,
        max_degree=max_degree,
        num_iterations=num_iterations,
        num_basis_functions=num_basis_functions,
        cutoff=cutoff,
        n_dcm=n_charges,
    )
    for n_charges in (1, 2)
)
|
|
|
|
|
from rdkit import Chem |
|
from rdkit.Chem import AllChem |
|
from rdkit.Chem import Draw |
|
|
|
def get_grid_points(coordinates, padding=3.0, n_points=15, min_distance=2.5 - 1e-1):
    """
    Create a uniform grid of points around a molecule.

    The grid spans the axis-aligned bounding box of ``coordinates``, enlarged
    by ``padding`` on every side, with ``n_points`` evenly spaced values per
    axis. Grid points closer than ``min_distance`` to any atom are dropped.

    :param coordinates: array-like of shape (n_atoms, n_dims) holding the atom
        positions (n_dims is 3 for Cartesian coordinates).
    :param padding: distance subtracted from the per-axis minimum and added to
        the per-axis maximum before building the grid.
    :param n_points: number of grid values per axis; the unfiltered grid has
        ``n_points ** n_dims`` points.
    :param min_distance: a grid point is kept only if its distance to every
        atom is at least this value.
    :return: ndarray of shape (n_kept, n_dims) with the retained grid points.
    :raises ValueError: if ``coordinates`` is not a non-empty 2-D array.
    """
    coords = np.asarray(coordinates, dtype=float)
    if coords.ndim != 2 or coords.shape[0] == 0:
        raise ValueError(
            "coordinates must be a non-empty array of shape (n_atoms, n_dims), "
            f"got shape {coords.shape}"
        )

    lower = coords.min(axis=0) - padding
    upper = coords.max(axis=0) + padding
    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(lower, upper)]

    # meshgrid's default "xy" indexing combined with the full transpose keeps
    # the original point ordering; afterwards the last axis holds the
    # coordinate components, so each row is one grid point.
    mesh = np.stack(np.meshgrid(*axes), axis=0)
    grid_points = np.reshape(mesh.T, [-1, coords.shape[1]])

    keep = np.all(cdist(grid_points, coords) >= min_distance, axis=-1)
    return grid_points[keep]
|
|
|
|
|
# Parameters later passed to plot_model for DCM1 and DCM2.
# NOTE(review): pd.read_pickle unpickles arbitrary objects -- only load these
# files from a trusted source.
dcm1_weights = pd.read_pickle("wbs/best_0.0_params.pkl")
dcm2_weights = pd.read_pickle("wbs/dcm2-best_1000.0_params.pkl")

smiles = 'C1NCCCC1'

smiles_mol = Chem.MolFromSmiles(smiles)
if smiles_mol is None:
    # MolFromSmiles returns None (instead of raising) on an unparsable SMILES.
    raise ValueError(f"RDKit could not parse SMILES {smiles!r}")
rdkit_mol = Chem.AddHs(smiles_mol)
elements = [a.GetSymbol() for a in rdkit_mol.GetAtoms()]

# EmbedMolecule returns the new conformer id, or -1 when 3D embedding fails;
# without this check GetConformer(0) below fails with an unhelpful error.
if AllChem.EmbedMolecule(rdkit_mol) < 0:
    raise ValueError(f"RDKit failed to generate 3D coordinates for {smiles!r}")
coordinates = rdkit_mol.GetConformer(0).GetPositions()
surface = get_grid_points(coordinates)

# Annotate every atom of the 2D depiction with its index.
for atom in smiles_mol.GetAtoms():
    atom.SetProp("atomNote", str(atom.GetIdx()))

smiles_image = Draw.MolToImage(smiles_mol)

st.image(smiles_image)
|
|
|
|
|
# Model inputs are zero-padded to these fixed sizes.
# NOTE(review): presumably these must match what the loaded parameters /
# prepare_batches expect -- confirm before changing them.
vdw_surface = surface
max_N_atoms = 60
max_grid_points = 3143

# np.pad rejects negative pad widths with an opaque error, so check the
# capacities up front and fail with a readable message instead.
if len(coordinates) > max_N_atoms:
    raise ValueError(
        f"molecule has {len(coordinates)} atoms; at most {max_N_atoms} are supported"
    )
if len(vdw_surface) > max_grid_points:
    raise ValueError(
        f"{len(vdw_surface)} grid points generated; at most {max_grid_points} are supported"
    )

try:
    # Accept elements already given as integer strings (atomic numbers)...
    Z = [np.array([int(_) for _ in elements])]
except ValueError:
    # ...otherwise map element symbols (e.g. "C", "N", "H") to atomic numbers.
    Z = [np.array([ase.data.atomic_numbers[_.capitalize()] for _ in elements])]

pad_Z = np.array([np.pad(Z[0], (0, max_N_atoms - len(Z[0])))])
pad_coords = np.array([np.pad(coordinates, ((0, max_N_atoms - len(coordinates)), (0, 0)))])

# Unused grid rows are filled with 10000 (constant_values=(before, after)).
# NOTE(review): presumably this parks dummy points far from the molecule;
# `ngrid` below still records the real number of grid points.
pad_vdw_surface = np.array([
    np.pad(
        vdw_surface,
        ((0, max_grid_points - len(vdw_surface)), (0, 0)),
        "constant",
        constant_values=(0, 10000),
    )
])

data_batch = dict(
    atomic_numbers=jnp.asarray(pad_Z),
    positions=jnp.asarray(pad_coords),
    # NOTE(review): the monopole entry is filled with the atomic numbers --
    # confirm this placeholder is intended.
    mono=jnp.asarray(pad_Z),
    ngrid=jnp.array([len(vdw_surface)]),
    # ESP values are not computed here; zeros are supplied as a placeholder.
    esp=jnp.asarray([np.zeros(max_grid_points)]),
    vdw_surface=jnp.asarray(pad_vdw_surface),
)
|
|
|
# Wrap the single molecule into batches of one and take the first batch.
batch_size = 1
psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)

batchID = 0
errors_train = []
batch = psi4_test_batches[batchID]

# Evaluate both models on the same batch without plotting; the positional
# argument after batch_size equals the n_dcm each model was built with.
dcm1results, dcm2results = (
    plot_model(model, params, batch, batch_size, n_charges, plot=False)
    for model, params, n_charges in (
        (DCM1, dcm1_weights, 1),
        (DCM2, dcm2_weights, 2),
    )
)

atoms = dcm1results["atoms"]
dcmol, dcmol2 = (result["dcmol"] for result in (dcm1results, dcm2results))
|
|
|
|
|
|
|
st.write("Click M to see the distributed charges")

# Render the molecule together with each model's charges as an ASE HTML
# viewer: first the DCM1 charges, then the DCM2 charges.
for charge_atoms in (dcmol, dcmol2):
    output = StringIO()
    (atoms + charge_atoms).write(output, format="html")
    data = output.getvalue()
    components.html(data, width=1000, height=1000)
|
|
|
|
|
|