File size: 4,129 Bytes
3dfe804 e65baf8 15f9c7b 245457d 136d978 b3226b4 c8b0bc8 e0fa269 b3226b4 01eac72 b3226b4 136d978 684fecf b4d9ff9 684fecf e65baf8 684fecf 136d978 e65baf8 684fecf 136d978 684fecf 13967af b4d9ff9 13967af 684fecf 13967af 136d978 b303f65 920b8d5 b3226b4 3bc70f1 b3226b4 e0fa269 a06b2d0 b3226b4 e0fa269 245457d 920b8d5 684fecf b8a68c2 136d978 e0fa269 0cef615 e0fa269 245457d 3dfe804 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import streamlit as st
from scipy.spatial.distance import cdist
import ase
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
from dcmnet.modules import MessagePassingModel
from dcmnet.utils import clip_colors, apply_model
from dcmnet.data import prepare_batches
from dcmnet.plotting import plot_model
# Global seed; the PRNG key derived from it is split into `data_key`
# (passed to prepare_batches further down) and `train_key`.
RANDOM_NUMBER = 0
filename = "test"
_root_key = jax.random.PRNGKey(RANDOM_NUMBER)
data_key, train_key = jax.random.split(_root_key, num=2)

# Model hyperparameters.
features, max_degree, num_iterations = 16, 2, 2
num_basis_functions, cutoff = 8, 4.0

# Create models: a single MessagePassingModel with n_dcm=1
# (NOTE(review): presumably one distributed charge per atom — confirm in dcmnet).
_dcm1_hparams = dict(
    features=features,
    max_degree=max_degree,
    num_iterations=num_iterations,
    num_basis_functions=num_basis_functions,
    cutoff=cutoff,
)
DCM1 = MessagePassingModel(**_dcm1_hparams, n_dcm=1)
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
def get_grid_points(coordinates, padding=3.0, n_points=15, min_distance=2.5 - 1e-1):
    """Build a uniform 3-D grid of probe points around a molecule.

    The grid spans the axis-aligned bounding box of ``coordinates`` grown by
    ``padding`` on every side. It is sampled with ``n_points`` evenly spaced
    values per axis, giving ``n_points ** 3`` candidates. Candidates closer
    than ``min_distance`` to any atom are discarded.

    :param coordinates: array-like of shape (n_atoms, 3) with atom positions.
    :param padding: distance subtracted from the per-axis minimum and added to
        the per-axis maximum of ``coordinates`` (same units as ``coordinates``).
    :param n_points: number of grid values along each axis.
    :param min_distance: a point is kept only if its distance to every atom is
        at least this value (default: 2.5 minus a 0.1 tolerance).
    :return: array of shape (n_kept, 3) with the retained grid points.
    """
    coordinates = np.asarray(coordinates, dtype=float)

    # Bounding box of the molecule, enlarged by `padding` on both sides.
    lower = np.min(coordinates, axis=0) - padding
    upper = np.max(coordinates, axis=0) + padding

    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(lower, upper)]
    # meshgrid + stack + transpose + reshape lists every (x, y, z) combination
    # as one row; the row order matches the previous implementation.
    mesh = np.stack(np.meshgrid(*axes), axis=0)
    grid_points = np.reshape(mesh.T, (-1, 3))

    # Exclude points that are too close to any atom of the molecule.
    keep = np.all(cdist(grid_points, coordinates) >= min_distance, axis=-1)
    return grid_points[keep]
# Trained model parameters. NOTE(review): read_pickle unpickles the file,
# which can execute arbitrary code; only load weight files you trust.
test_weights = pd.read_pickle("wbs/best_0.0_params.pkl")

smiles = 'C1NCCCC1'
smiles_mol = Chem.MolFromSmiles(smiles)
if smiles_mol is None:
    raise ValueError(f"RDKit could not parse SMILES {smiles!r}")
# Add explicit hydrogens before embedding so they get 3-D positions too.
rdkit_mol = Chem.AddHs(smiles_mol)
elements = [a.GetSymbol() for a in rdkit_mol.GetAtoms()]

# Generate a conformation. Seeding with RANDOM_NUMBER keeps the geometry
# identical across Streamlit reruns (each widget interaction re-executes the
# script); EmbedMolecule returns -1 when it cannot embed the molecule.
if AllChem.EmbedMolecule(rdkit_mol, randomSeed=RANDOM_NUMBER) == -1:
    raise RuntimeError(f"RDKit failed to embed a 3-D conformer for {smiles!r}")
coordinates = rdkit_mol.GetConformer(0).GetPositions()
surface = get_grid_points(coordinates)

# Annotate each atom of the hydrogen-free molecule with its index
# (stored in the "atomNote" property, shown next to the atom in the drawing).
for atom in smiles_mol.GetAtoms():
    atom.SetProp("atomNote", str(atom.GetIdx()))
smiles_image = Draw.MolToImage(smiles_mol)
# display molecule
st.image(smiles_image)
vdw_surface = surface
# Fixed padded sizes for the single-molecule batch.
# NOTE(review): presumably these must match what the model / prepare_batches
# were built for; confirm before changing them.
max_N_atoms = 60
max_grid_points = 3143

# np.pad rejects negative pad widths with an opaque message, so fail early
# with a clear one. The grid from get_grid_points has up to 15**3 = 3375
# candidate points, which can exceed max_grid_points.
if len(coordinates) > max_N_atoms:
    raise ValueError(
        f"molecule has {len(coordinates)} atoms; at most {max_N_atoms} are supported"
    )
if len(vdw_surface) > max_grid_points:
    raise ValueError(
        f"{len(vdw_surface)} grid points generated; at most {max_grid_points} are supported"
    )

# Atomic numbers: use `elements` directly if they are already numeric,
# otherwise look them up as element symbols.
try:
    Z = [np.array([int(el) for el in elements])]
except (TypeError, ValueError):
    Z = [np.array([ase.data.atomic_numbers[el.capitalize()] for el in elements])]

# Zero-pad per-atom arrays to max_N_atoms and add a leading batch axis of 1.
pad_Z = np.array([np.pad(Z[0], (0, max_N_atoms - len(Z[0])))])
pad_coords = np.array([np.pad(coordinates, ((0, max_N_atoms - len(coordinates)), (0, 0)))])
# Padding rows of the surface are filled with 10000 (presumably to put them
# far from every atom — TODO confirm); `ngrid` below records the real count.
pad_vdw_surface = np.array([
    np.pad(
        vdw_surface,
        ((0, max_grid_points - len(vdw_surface)), (0, 0)),
        "constant",
        constant_values=(0, 10000),
    )
])

data_batch = dict(
    atomic_numbers=jnp.asarray(pad_Z),
    positions=jnp.asarray(pad_coords),
    # NOTE(review): reuses the padded atomic numbers as a placeholder for
    # `mono`; presumably unused at inference time — confirm.
    mono=jnp.asarray(pad_Z),
    ngrid=jnp.array([len(vdw_surface)]),
    # No reference ESP is available here; zeros act as a placeholder.
    esp=jnp.asarray([np.zeros(max_grid_points)]),
    vdw_surface=jnp.asarray(pad_vdw_surface),
)
batch_size = 1
psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
batchID = 0
batch = psi4_test_batches[batchID]

# Evaluate the n_dcm=1 model on the batch with plotting disabled; the
# returned dict provides the arrays shown below and the objects for the 3-D view.
dcm1results = plot_model(DCM1, test_weights, batch, batch_size, 1, plot=False)
dipo = dcm1results["dipo"]
mono = dcm1results["mono"]
atoms = dcm1results["atoms"]
dcmol = dcm1results["dcmol"]
st.write(dipo)
st.write(mono)

from ase.visualize import view

# Render `atoms` combined with `dcmol` as an X3D scene.
display_mol = view(atoms + dcmol, viewer="x3d")
st.write(type(display_mol))
# view(..., viewer="x3d") can return an IPython HTML wrapper instead of a
# str; st.html needs the markup itself, which such wrappers keep in `.data`.
# A plain string has no `.data` and is passed through unchanged.
st.html(getattr(display_mol, "data", display_mol))

# NOTE(review): looks like a leftover Streamlit demo widget; any interaction
# with it re-runs the whole script.
x = st.slider('Select a value')
st.write(x, 'squared is', x * x)
|