|
import streamlit as st |
|
|
|
import functools |
|
import e3x |
|
from flax import linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import optax |
|
|
|
|
|
import pandas as pd |
|
from dcmnet.modules import MessagePassingModel |
|
|
|
from rdkit import Chem |
|
from rdkit.Chem import AllChem |
|
|
|
from dcmnet.psi4_ import * |
|
|
|
|
|
# Load saved weights from disk. NOTE(review): presumably these are trained
# MessagePassingModel parameters; they are not used in this section — confirm.
# pd.read_pickle unpickles arbitrary objects, so only point it at trusted files.
test_weights = pd.read_pickle("wbs/best_0.0_params.pkl")

smiles = 'CCNCC'

# Parse the SMILES; MolFromSmiles returns None for an unparsable string.
smiles_mol = Chem.MolFromSmiles(smiles)
if smiles_mol is None:
    st.error(f"Invalid SMILES string: {smiles}")
    st.stop()

# Add explicit hydrogens so every atom (including H) gets a 3D position.
rdkit_mol = Chem.AddHs(smiles_mol)
elements = [a.GetSymbol() for a in rdkit_mol.GetAtoms()]

# Streamlit re-executes the whole script on every widget interaction, so a
# fixed seed keeps the embedded conformer identical across reruns.
# EmbedMolecule returns -1 on failure, leaving the molecule without a conformer.
if AllChem.EmbedMolecule(rdkit_mol, randomSeed=0) == -1:
    st.error(f"Could not generate 3D coordinates for {smiles}")
    st.stop()

coordinates = rdkit_mol.GetConformer(0).GetPositions()

# get_grid_points comes from the `dcmnet.psi4_` star import; presumably it
# builds sample points around the molecule from its coordinates — TODO confirm.
surface = get_grid_points(coordinates)
|
|
|
from rdkit.Chem import Draw

# Annotate each atom of the hydrogen-free molecule with its index so the
# indices are visible in the rendered drawing.
for atom in smiles_mol.GetAtoms():
    atom.SetProp("atomNote", str(atom.GetIdx()))

# st.image accepts images (PIL, arrays, bytes, URLs), not RDKit Mol objects,
# so render the molecule to a PIL image first.
st.image(Draw.MolToImage(smiles_mol))
|
|
|
|
|
|
|
import warnings

# Suppress FutureWarning messages for the rest of the script run.
warnings.simplefilter(action='ignore', category=FutureWarning)

# Derive a sub-key from a fixed seed (0), sample a random rotation with it,
# and show the result in the app. `key` keeps the remaining split half.
key, rotation_key = jax.random.split(jax.random.PRNGKey(0))
rotation = e3x.so3.random_rotation(rotation_key)
st.write(rotation)
|
|
|
|
|
|
|
|
|
|
|
# Demo widget: read a value from a slider and echo it alongside its square.
x = st.slider('Select a value')
x_squared = x * x
st.write(x, 'squared is', x_squared)
|
|
|
|