# DCMNet / app.py
# Uploaded by EricBoi (commit 684fecf, 1.42 kB).
import streamlit as st
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
from dcmnet.modules import MessagePassingModel
from rdkit import Chem
from rdkit.Chem import AllChem
from dcmnet.psi4_ import *
# Pretrained model parameters.
# NOTE(review): unpickling can execute arbitrary code, so only load weight
# files that ship with this app.
test_weights = pd.read_pickle("wbs/best_0.0_params.pkl")

# Molecule to analyse, given as a SMILES string.
smiles = 'CCNCC'
# smiles = 'CC(CC1=CC=CC=C1)NC'
smiles_mol = Chem.MolFromSmiles(smiles)
if smiles_mol is None:
    # MolFromSmiles signals a parse failure by returning None, not by raising.
    st.error(f"Could not parse SMILES string: {smiles!r}")
    st.stop()

# Add explicit hydrogens before generating 3D coordinates.
rdkit_mol = Chem.AddHs(smiles_mol)
elements = [a.GetSymbol() for a in rdkit_mol.GetAtoms()]

# Generate a 3D conformation. A fixed seed keeps the geometry identical across
# Streamlit reruns (the script reruns on every widget interaction).
# EmbedMolecule returns the new conformer id, or -1 if embedding failed.
if AllChem.EmbedMolecule(rdkit_mol, randomSeed=0) == -1:
    st.error(f"RDKit could not generate a 3D conformer for {smiles!r}")
    st.stop()
coordinates = rdkit_mol.GetConformer(0).GetPositions()
surface = get_grid_points(coordinates)

# Label every (heavy) atom of the 2D molecule with its index via the
# "atomNote" property, so atoms can be identified in the drawing.
for atom in smiles_mol.GetAtoms():
    atom.SetProp("atomNote", str(atom.GetIdx()))

# Display the 2D structure. st.image expects image data (e.g. a PIL image),
# not an RDKit Mol, so render the molecule first.
from rdkit.Chem import Draw

st.image(Draw.MolToImage(smiles_mol))
# Silence FutureWarning from this point on. Warning filters only affect
# warnings raised after they are installed, so anything emitted while the
# libraries above were being imported is not covered by this.
import warnings

warnings.simplefilter("ignore", FutureWarning)

# Split a fixed root PRNG key (seed 0) and use one half to sample a random
# rotation with e3x, then show it.
key, rotation_key = jax.random.split(jax.random.PRNGKey(0))
rotation = e3x.so3.random_rotation(rotation_key)
st.write(rotation)

# Demo widget: echo the chosen slider value together with its square.
x = st.slider('Select a value')
st.write(x, 'squared is', x * x)