# DCMNet / app.py
# EricBoi's picture
# fc2a7e5
import streamlit as st
from ase.visualize import view
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
import streamlit.components.v1 as components
from scipy.spatial.distance import cdist
import ase
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
from dcmnet.modules import MessagePassingModel
from dcmnet.utils import clip_colors, apply_model
from dcmnet.data import prepare_batches
from dcmnet.plotting import plot_model
# Seed for the JAX PRNG; the root key is split into data_key and train_key.
RANDOM_NUMBER = 0
filename = "test"
data_key, train_key = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)

# Model hyperparameters.
features = 16
max_degree = 2
num_iterations = 2
num_basis_functions = 8
cutoff = 4.0

# Constructor arguments shared by both models; the two differ only in n_dcm.
_shared_model_kwargs = dict(
    features=features,
    max_degree=max_degree,
    num_iterations=num_iterations,
    num_basis_functions=num_basis_functions,
    cutoff=cutoff,
)

# Create models
DCM1 = MessagePassingModel(n_dcm=1, **_shared_model_kwargs)
DCM2 = MessagePassingModel(n_dcm=2, **_shared_model_kwargs)
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
def get_grid_points(coordinates, *, padding=3.0, points_per_axis=15,
                    min_distance=2.5, tolerance=1e-1):
    """
    Create a uniform 3D grid of points around a molecule.

    The grid spans the axis-aligned bounding box of ``coordinates`` enlarged
    by ``padding`` on every side, with ``points_per_axis`` evenly spaced
    samples along each axis. Grid points that lie closer than
    ``min_distance - tolerance`` to any atom are discarded.

    :param coordinates: array-like of shape (n_atoms, 3) with atom positions
    :param padding: margin added to each side of the bounding box, in the
        same length unit as ``coordinates``
    :param points_per_axis: number of grid samples along each axis
    :param min_distance: nominal minimum distance between a kept grid point
        and every atom
    :param tolerance: slack subtracted from ``min_distance`` before comparing
    :return: array of shape (n_kept, 3) with the remaining grid points
    :raises ValueError: if ``coordinates`` is not a non-empty (n_atoms, 3)
        array
    """
    coordinates = np.asarray(coordinates, dtype=float)
    if (coordinates.ndim != 2 or coordinates.shape[1] != 3
            or coordinates.shape[0] == 0):
        raise ValueError(
            "coordinates must be a non-empty array of shape (n_atoms, 3), "
            f"got shape {coordinates.shape}"
        )

    # Bounding box of the molecule, grown by the padding on both ends.
    lower = coordinates.min(axis=0) - padding
    upper = coordinates.max(axis=0) + padding
    axes = [np.linspace(lo, hi, points_per_axis)
            for lo, hi in zip(lower, upper)]

    # meshgrid (default 'xy' indexing) followed by the transpose fixes the
    # row order of the returned points; keep both steps together.
    grid_points = np.stack(np.meshgrid(*axes), axis=0)
    grid_points = np.reshape(grid_points.T, [-1, 3])

    # exclude points that are too close to the molecule
    far_enough = np.all(
        cdist(grid_points, coordinates) >= (min_distance - tolerance),
        axis=-1,
    )
    return grid_points[far_enough]
# Parameters for the two models, read from pickle files.
# NOTE(review): pd.read_pickle unpickles arbitrary Python objects; only load
# weight files that come from a trusted source.
dcm1_weights = pd.read_pickle("wbs/best_0.0_params.pkl")
dcm2_weights = pd.read_pickle("wbs/dcm2-best_1000.0_params.pkl")

# Molecule to process, given as a SMILES string.
smiles = 'C1NCCCC1'
smiles_mol = Chem.MolFromSmiles(smiles)
if smiles_mol is None:
    # MolFromSmiles returns None instead of raising when parsing fails.
    raise ValueError(f"RDKit could not parse SMILES string: {smiles!r}")
rdkit_mol = Chem.AddHs(smiles_mol)
elements = [a.GetSymbol() for a in rdkit_mol.GetAtoms()]

# Generate a conformation; EmbedMolecule returns -1 when embedding fails.
if AllChem.EmbedMolecule(rdkit_mol) < 0:
    raise RuntimeError(f"RDKit could not embed a 3D conformer for {smiles!r}")
coordinates = rdkit_mol.GetConformer(0).GetPositions()
surface = get_grid_points(coordinates)

# Annotate each atom of the parsed (pre-AddHs) molecule with its atom index
# via the "atomNote" property before drawing it.
for atom in smiles_mol.GetAtoms():
    atom.SetProp("atomNote", str(atom.GetIdx()))
smiles_image = Draw.MolToImage(smiles_mol)
# display molecule
st.image(smiles_image)
vdw_surface = surface

# Fixed sizes that the single-molecule batch is padded to.
max_N_atoms = 60
max_grid_points = 3143

# Atomic numbers: accept entries that are already numeric strings, otherwise
# look the element symbols up in ASE's table. int() raises ValueError for a
# non-numeric string, so only that exception triggers the fallback.
try:
    Z = [np.array([int(symbol) for symbol in elements])]
except ValueError:
    Z = [np.array([ase.data.atomic_numbers[symbol.capitalize()]
                   for symbol in elements])]

# Fail with a clear message instead of a negative pad width inside np.pad.
if len(Z[0]) > max_N_atoms or len(coordinates) > max_N_atoms:
    raise ValueError(
        f"molecule has {max(len(Z[0]), len(coordinates))} atoms, "
        f"more than the supported maximum of {max_N_atoms}"
    )
if len(vdw_surface) > max_grid_points:
    raise ValueError(
        f"grid has {len(vdw_surface)} points, "
        f"more than the supported maximum of {max_grid_points}"
    )

# Pad along the atom axis with zeros and add a leading batch axis of size 1.
pad_Z = np.array([np.pad(Z[0], (0, max_N_atoms - len(Z[0])))])
pad_coords = np.array(
    [np.pad(coordinates, ((0, max_N_atoms - len(coordinates)), (0, 0)))]
)

# Only trailing rows are added, so constant_values=(0, 10000) fills every
# padded grid row with 10000 (presumably to place it far from the molecule).
padded_surface = np.pad(
    vdw_surface,
    ((0, max_grid_points - len(vdw_surface)), (0, 0)),
    "constant",
    constant_values=(0, 10000),
)
pad_vdw_surface = np.array([padded_surface])

data_batch = dict(
    atomic_numbers=jnp.asarray(pad_Z),
    positions=jnp.asarray(pad_coords),
    # NOTE(review): "mono" is filled with the atomic numbers and "esp" with
    # zeros; these look like placeholders for inference - confirm.
    mono=jnp.asarray(pad_Z),
    ngrid=jnp.array([len(vdw_surface)]),
    esp=jnp.asarray([np.zeros(max_grid_points)]),
    vdw_surface=jnp.asarray(pad_vdw_surface),
)
# Wrap the single padded molecule into batches and pick the first one.
batch_size = 1
psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
batchID = 0
errors_train = []
batch = psi4_test_batches[batchID]

# Evaluate both models on the batch with plotting switched off; only the
# "atoms" and "dcmol" entries of the results are used below.
dcm1results = plot_model(DCM1, dcm1_weights, batch, batch_size, 1, plot=False)
dcm2results = plot_model(DCM2, dcm2_weights, batch, batch_size, 2, plot=False)
atoms = dcm1results["atoms"]
dcmol = dcm1results["dcmol"]
dcmol2 = dcm2results["dcmol"]

st.write("Click M to see the distributed charges")

# Render one HTML viewer per model: the atoms combined with that model's
# "dcmol" object, written in HTML format to an in-memory buffer.
for charges in (dcmol, dcmol2):
    output = StringIO()
    (atoms + charges).write(output, format="html")
    data = output.getvalue()
    components.html(data, width=1000, height=1000)