"""Streamlit demo: embed a small molecule with RDKit and show a random rotation.

The script loads a pickled parameter file, builds a 3D conformer for a SMILES
string, computes grid points around it, shows an index-annotated 2D depiction,
and displays a random SO(3) rotation sampled with e3x.

Streamlit re-executes this whole script on every widget interaction (e.g. the
slider at the bottom), so all of the work below runs again on each rerun.
"""
import functools
import warnings

import e3x
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import streamlit as st
from flax import linen as nn
from rdkit import Chem
from rdkit.Chem import AllChem, Draw

from dcmnet.modules import MessagePassingModel
from dcmnet.psi4_ import *  # noqa: F403 -- presumably provides get_grid_points

# Silence FutureWarnings before any of the computation below runs.
warnings.simplefilter(action="ignore", category=FutureWarning)

# Seed for RDKit's conformer embedding. It keeps the 3D geometry identical
# across Streamlit reruns, matching the fixed JAX PRNG key used further down.
EMBED_SEED = 42

# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only point this at trusted files.
test_weights = pd.read_pickle("wbs/best_0.0_params.pkl")

smiles = "CCNCC"
# Alternative test molecule: "CC(CC1=CC=CC=C1)NC"
smiles_mol = Chem.MolFromSmiles(smiles)
if smiles_mol is None:  # MolFromSmiles returns None on unparsable input
    st.error(f"RDKit could not parse SMILES {smiles!r}.")
    st.stop()

rdkit_mol = Chem.AddHs(smiles_mol)
elements = [atom.GetSymbol() for atom in rdkit_mol.GetAtoms()]

# Generate a 3D conformation. EmbedMolecule returns the new conformer id, or -1
# on failure; in that case GetConformer() would raise, so report and stop.
conf_id = AllChem.EmbedMolecule(rdkit_mol, randomSeed=EMBED_SEED)
if conf_id < 0:
    st.error(f"RDKit failed to embed a 3D conformer for {smiles!r}.")
    st.stop()
coordinates = rdkit_mol.GetConformer(conf_id).GetPositions()
surface = get_grid_points(coordinates)  # noqa: F405

# Annotate each atom of the 2D depiction with its index. "atomNote" is the
# atom property RDKit's drawing code renders beside the atom.
for atom in smiles_mol.GetAtoms():
    atom.SetProp("atomNote", str(atom.GetIdx()))

# st.image does not accept an RDKit Mol; render it to a PIL image first.
st.image(Draw.MolToImage(smiles_mol))

# Sample a random rotation from a fixed PRNG key so it is reproducible.
key = jax.random.PRNGKey(0)
key, rotation_key = jax.random.split(key)
rotation = e3x.so3.random_rotation(rotation_key)
st.write(rotation)

# Minimal slider demo; moving it triggers a full rerun of this script.
x = st.slider("Select a value")
st.write(x, "squared is", x * x)