File size: 4,666 Bytes
3dfe804 52c9ab4 e65baf8 15f9c7b 245457d 136d978 b3226b4 c8b0bc8 e0fa269 b3226b4 01eac72 b3226b4 949746a b3226b4 136d978 684fecf b4d9ff9 684fecf e65baf8 684fecf 949746a 136d978 e65baf8 684fecf 136d978 684fecf 13967af b4d9ff9 13967af 684fecf 13967af 136d978 b303f65 920b8d5 b3226b4 3bc70f1 b3226b4 e0fa269 949746a fc2a7e5 b3226b4 e0fa269 949746a 245457d 136d978 e0fa269 949746a 7eb2af5 43fdd4d 949746a e0fa269 949746a 52c9ab4 b3fbae8 949746a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import streamlit as st
from ase.visualize import view
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
import streamlit.components.v1 as components
from scipy.spatial.distance import cdist
import ase
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
from dcmnet.modules import MessagePassingModel
from dcmnet.utils import clip_colors, apply_model
from dcmnet.data import prepare_batches
from dcmnet.plotting import plot_model
# Seed from which every JAX PRNG key in this script is derived.
RANDOM_NUMBER = 0
filename = "test"
_base_key = jax.random.PRNGKey(RANDOM_NUMBER)
data_key, train_key = jax.random.split(_base_key, num=2)
del _base_key

# Message-passing hyperparameters shared by the models built below.
features, max_degree, num_iterations = 16, 2, 2
num_basis_functions, cutoff = 8, 4.0
# Create models: both share every hyperparameter and differ only in n_dcm.
_shared_hparams = dict(
    features=features,
    max_degree=max_degree,
    num_iterations=num_iterations,
    num_basis_functions=num_basis_functions,
    cutoff=cutoff,
)
DCM1 = MessagePassingModel(**_shared_hparams, n_dcm=1)
DCM2 = MessagePassingModel(**_shared_hparams, n_dcm=2)
del _shared_hparams
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
def get_grid_points(coordinates, padding=3.0, n_points=15, min_distance=2.5 - 1e-1):
    """
    Create a uniform grid of probe points around a molecule.

    The grid spans the axis-aligned bounding box of ``coordinates`` widened by
    ``padding`` on both ends of every axis, sampled with ``n_points`` evenly
    spaced values per axis. Grid points closer than ``min_distance`` to any
    atom are discarded.

    :param coordinates: array-like of shape (n_atoms, 3) with atomic positions.
    :param padding: distance added below the minimum and above the maximum
        coordinate on each axis (default 3.0).
    :param n_points: number of samples per axis; the unfiltered grid has
        ``n_points ** 3`` points (default 15).
    :param min_distance: a grid point is kept only if its distance to every
        atom is >= this value (default 2.4).
    :return: ndarray of shape (n_kept, 3) with the retained grid points;
        shape (0, 3) when every point is too close to the molecule.
    """
    coordinates = np.asarray(coordinates, dtype=float)
    lower = coordinates.min(axis=0) - padding
    upper = coordinates.max(axis=0) + padding
    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(lower, upper)]
    # Stack the meshgrid to (3, n, n, n), reverse the axes with .T and flatten
    # to rows of (x, y, z); this keeps the historical point ordering.
    grid_points = np.reshape(np.stack(np.meshgrid(*axes), axis=0).T, (-1, 3))
    # Exclude points that are too close to any atom of the molecule.
    keep = np.all(cdist(grid_points, coordinates) >= min_distance, axis=-1)
    return grid_points[keep]
# Pretrained parameters for the n_dcm=1 and n_dcm=2 models.
# NOTE(review): pd.read_pickle unpickles arbitrary objects; only load trusted files.
dcm1_weights = pd.read_pickle("wbs/best_0.0_params.pkl")
dcm2_weights = pd.read_pickle("wbs/dcm2-best_1000.0_params.pkl")

# Build the molecule from SMILES; the model input uses the hydrogen-added copy.
smiles = 'C1NCCCC1'
smiles_mol = Chem.MolFromSmiles(smiles)
rdkit_mol = Chem.AddHs(smiles_mol)
elements = [a.GetSymbol() for a in rdkit_mol.GetAtoms()]

# Generate a conformation. A fixed seed keeps the geometry (and therefore the
# predicted charges) identical across Streamlit reruns. EmbedMolecule returns
# -1 on failure, which would otherwise surface later as an opaque
# GetConformer error.
conf_id = AllChem.EmbedMolecule(rdkit_mol, randomSeed=RANDOM_NUMBER)
if conf_id < 0:
    raise RuntimeError(f"RDKit could not embed a 3D conformer for SMILES {smiles!r}")
coordinates = rdkit_mol.GetConformer(conf_id).GetPositions()
surface = get_grid_points(coordinates)

# Annotate each atom of the 2D depiction with its atom index ("atomNote").
for atom in smiles_mol.GetAtoms():
    atom.SetProp("atomNote", str(atom.GetIdx()))
smiles_image = Draw.MolToImage(smiles_mol)
# display molecule
st.image(smiles_image)
# Pad every per-molecule array to the fixed sizes used for the model inputs.
vdw_surface = surface
max_N_atoms = 60
max_grid_points = 3143

try:
    # Element labels may already be atomic numbers (ints or numeric strings) ...
    Z = [np.array([int(el) for el in elements])]
except (ValueError, TypeError):
    # ... otherwise treat them as chemical symbols such as "C" or "n".
    Z = [np.array([ase.data.atomic_numbers[el.capitalize()] for el in elements])]

n_atoms = len(Z[0])
n_grid = len(vdw_surface)
# np.pad raises an opaque error on a negative pad width; fail early and clearly.
if max(n_atoms, len(coordinates)) > max_N_atoms:
    raise ValueError(
        f"Molecule has {max(n_atoms, len(coordinates))} atoms; at most {max_N_atoms} are supported"
    )
if n_grid > max_grid_points:
    raise ValueError(
        f"Surface has {n_grid} grid points; at most {max_grid_points} are supported"
    )

# Each padded array gets a leading batch axis of size 1 (a single molecule).
pad_Z = np.array([np.pad(Z[0], (0, max_N_atoms - len(Z[0])))])
pad_coords = np.array([np.pad(coordinates, ((0, max_N_atoms - len(coordinates)), (0, 0)))])
# Padding rows of the surface are filled with 10000, i.e. placed far from the
# molecule (presumably so they cannot interact with the atoms; confirm).
pad_vdw_surface = np.array([
    np.pad(
        vdw_surface,
        ((0, max_grid_points - n_grid), (0, 0)),
        "constant",
        constant_values=(0, 10000),
    )
])

data_batch = dict(
    atomic_numbers=jnp.asarray(pad_Z),
    positions=jnp.asarray(pad_coords),
    # NOTE(review): mono reuses the atomic numbers; confirm this is intended.
    mono=jnp.asarray(pad_Z),
    # Number of real (unpadded) surface points.
    ngrid=jnp.array([n_grid]),
    # NOTE(review): no reference ESP is available here, so zeros are passed.
    esp=jnp.asarray([np.zeros(max_grid_points)]),
    vdw_surface=jnp.asarray(pad_vdw_surface),
)
batch_size = 1
psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
# Only one batch is built above (one molecule, batch_size=1), so take the first.
batchID = 0
errors_train = []
batch = psi4_test_batches[batchID]

# Evaluate both pretrained models on the same batch. plot=False is passed, so
# plot_model is used here only for the results it returns.
dcm1results, dcm2results = (
    plot_model(model, weights, batch, batch_size, n_dcm, plot=False)
    for model, weights, n_dcm in ((DCM1, dcm1_weights, 1), (DCM2, dcm2_weights, 2))
)
# The atoms come from the first model's results; each model contributes its own charge sites.
atoms = dcm1results["atoms"]
dcmol = dcm1results["dcmol"]
dcmol2 = dcm2results["dcmol"]
st.write("Click M to see the distributed charges")

# Render atoms + charge sites as an embedded HTML viewer, once per model.
for charges in (dcmol, dcmol2):
    output = StringIO()
    (atoms + charges).write(output, format="html")
    data = output.getvalue()
    components.html(data, width=1000, height=1000)
|