# File size: 7,293 Bytes
# Revision: 17e1388
# Adapter settings used when fine-tuning from the checkpoint in model_path.
adapter:
  # Denoiser wrapped with an extra conditioning input (chemical_system).
  adapter:
    _target_: mattergen.adapter.GemNetTAdapter
    atom_type_diffusion: mask
    denoise_atom_types: true
    gemnet:
      _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
      atom_embedding:
        _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
        emb_size: 512
        with_mask_type: true
      # Properties fed into the network through the adapter layers.
      condition_on_adapt:
        - chemical_system
      cutoff: 7.0
      emb_size_atom: 512
      emb_size_edge: 512
      latent_dim: 512
      max_cell_images_per_dim: 5
      max_neighbors: 50
      num_blocks: 4
      num_targets: 1
      otf_graph: true
      regress_stress: true
      # NOTE(review): absolute path from the original training machine; it
      # likely has to be repointed at this checkout's gemnet-dT.json — confirm.
      scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
    hidden_dim: 512
    # No properties are conditioned on in the base model itself.
    property_embeddings: {}
    # Embeddings for the properties added during fine-tuning.
    property_embeddings_adapt:
      chemical_system:
        _target_: mattergen.property_embeddings.PropertyEmbedding
        conditional_embedding_module:
          _target_: mattergen.property_embeddings.ChemicalSystemMultiHotEmbedding
          hidden_dim: 512
        name: chemical_system
        scaler:
          _target_: torch.nn.Identity
        unconditional_embedding_module:
          _target_: mattergen.property_embeddings.EmbeddingVector
          hidden_dim: 512
  # Presumably trains all weights rather than only the adapter — confirm.
  full_finetuning: true
  load_epoch: last
  model_path: checkpoints/mattergen_base
# Data loading for the alex_mp_20 cached dataset, filtered to chemical_system.
data_module:
  _recursive_: true
  _target_: mattergen.common.data.datamodule.CrystDataModule
  # Same value as lightning_module...sdes.cell.vpsde_config.limit_density;
  # keep the two in sync.
  average_density: 0.05771451654022283
  batch_size:
    train: 64
    val: 64
  dataset_transforms:
    - _partial_: true
      _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
  # NOTE(review): differs from trainer.max_epochs (200); confirm which value
  # the data module actually needs.
  max_epochs: 2200
  num_workers:
    train: 0
    val: 0
  properties:
    - chemical_system
  root_dir: datasets/cache/alex_mp_20/
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/train
    dataset_transforms:
      - _partial_: true
        _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    properties:
      - chemical_system
    transforms:
      - _partial_: true
        _target_: mattergen.common.data.transform.symmetrize_lattice
      - _partial_: true
        _target_: mattergen.common.data.transform.set_chemical_system_string
  transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattice
    - _partial_: true
      _target_: mattergen.common.data.transform.set_chemical_system_string
  val_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/val
    dataset_transforms:
      - _partial_: true
        _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    properties:
      - chemical_system
    transforms:
      - _partial_: true
        _target_: mattergen.common.data.transform.symmetrize_lattice
      - _partial_: true
        _target_: mattergen.common.data.transform.set_chemical_system_string
# Diffusion training module: corruption processes, loss, denoiser and optimizer.
lightning_module:
  _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
  diffusion_module:
    _target_: mattergen.diffusion.diffusion_module.DiffusionModule
    corruption:
      _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
      # Discrete (D3PM mask) diffusion over atom types.
      discrete_corruptions:
        atomic_numbers:
          _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
          d3pm:
            _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
            dim: 101
            schedule:
              _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
              kind: standard
              num_steps: 1000
          offset: 1
      # Continuous SDEs for the lattice and the atom positions.
      sdes:
        cell:
          _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
          vpsde_config:
            beta_max: 20
            beta_min: 0.1
            # Same value as data_module.average_density; keep in sync.
            limit_density: 0.05771451654022283
            limit_var_scaling_constant: 0.25
        pos:
          _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
          limit_info_key: num_atoms
          sigma_max: 5.0
          wrapping_boundary: 1.0
    loss_fn:
      _target_: mattergen.common.loss.MaterialsLoss
      d3pm_hybrid_lambda: 0.01
      include_atomic_numbers: true
      include_cell: true
      include_pos: true
      reduce: sum
      weights:
        atomic_numbers: 1.0
        cell: 1.0
        pos: 0.1
    # Identical to adapter.adapter; keep the two definitions in sync.
    model:
      _target_: mattergen.adapter.GemNetTAdapter
      atom_type_diffusion: mask
      denoise_atom_types: true
      gemnet:
        _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
        atom_embedding:
          _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
          emb_size: 512
          with_mask_type: true
        condition_on_adapt:
          - chemical_system
        cutoff: 7.0
        emb_size_atom: 512
        emb_size_edge: 512
        latent_dim: 512
        max_cell_images_per_dim: 5
        max_neighbors: 50
        num_blocks: 4
        num_targets: 1
        otf_graph: true
        regress_stress: true
        # NOTE(review): absolute path from the original training machine; it
        # likely has to be repointed at this checkout's gemnet-dT.json — confirm.
        scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
      hidden_dim: 512
      property_embeddings: {}
      property_embeddings_adapt:
        chemical_system:
          _target_: mattergen.property_embeddings.PropertyEmbedding
          conditional_embedding_module:
            _target_: mattergen.property_embeddings.ChemicalSystemMultiHotEmbedding
            hidden_dim: 512
          name: chemical_system
          scaler:
            _target_: torch.nn.Identity
          unconditional_embedding_module:
            _target_: mattergen.property_embeddings.EmbeddingVector
            hidden_dim: 512
    pre_corruption_fn:
      _target_: mattergen.property_embeddings.SetEmbeddingType
      dropout_fields_iid: false
      # Presumably the probability of using the unconditional embedding
      # during training — confirm against SetEmbeddingType.
      p_unconditional: 0.2
  optimizer_partial:
    _partial_: true
    _target_: torch.optim.Adam
    lr: 5.0e-06
  # Lightning lr-scheduler config entries (frequency/interval/monitor/strict).
  scheduler_partials:
    - frequency: 1
      interval: epoch
      monitor: loss_train
      scheduler:
        _partial_: true
        _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
        factor: 0.6
        min_lr: 1.0e-06
        patience: 100
        # NOTE(review): newer PyTorch releases deprecate `verbose` on LR
        # schedulers — confirm against the installed version.
        verbose: true
      strict: true
# PyTorch Lightning Trainer: single node, 8 GPUs, DDP, fp32.
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  accumulate_grad_batches: 1
  callbacks:
    - _target_: pytorch_lightning.callbacks.LearningRateMonitor
      log_momentum: false
      logging_interval: step
    # Keeps the best checkpoint by loss_val plus the last one.
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      every_n_epochs: 1
      # Quoted so the braces are not parsed as a YAML flow mapping.
      filename: '{epoch}-{loss_val:.2f}'
      mode: min
      monitor: loss_val
      save_last: true
      save_top_k: 1
      verbose: false
    - _target_: pytorch_lightning.callbacks.TQDMProgressBar
      refresh_rate: 50
    - _target_: mattergen.common.data.callback.SetPropertyScalers
  check_val_every_n_epoch: 1
  devices: 8
  gradient_clip_algorithm: value
  gradient_clip_val: 0.5
  logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    job_type: train_finetune
    project: crystal-generation
    settings:
      _save_requirements: false
      _target_: wandb.Settings
      start_method: fork
  max_epochs: 200
  num_nodes: 1
  precision: 32
  strategy:
    _target_: pytorch_lightning.strategies.ddp.DDPStrategy
    find_unused_parameters: true