# Provenance recovered from page residue that preceded this config
# (kept only as comments so the file parses as YAML):
#   file size: 5,485 bytes
#   revision (presumably a commit id - confirm): 17e1388
---
# Top-level run flags.
auto_resume: true
# null: no explicit checkpoint path is given.
checkpoint_path: null
# Data module configuration: batch sizes, workers, property columns and the
# train/val dataset specs (each built via its `_target_`).
data_module:
  _recursive_: true
  _target_: mattergen.common.data.datamodule.CrystDataModule
  average_density: 0.05771451654022283
  batch_size:
    train: 32
    val: 32
  # Same value as trainer.max_epochs.
  max_epochs: 2200
  num_workers:
    train: 0
    val: 0
  # The train and val datasets below list exactly the same property names.
  properties:
    - dft_bulk_modulus
    - dft_band_gap
    - dft_mag_density
    - ml_bulk_modulus
    - hhi_score
    - space_group
    - energy_above_hull
  root_dir: datasets/cache/alex_mp_20/
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/train
    properties:
      - dft_bulk_modulus
      - dft_band_gap
      - dft_mag_density
      - ml_bulk_modulus
      - hhi_score
      - space_group
      - energy_above_hull
    transforms:
      - _partial_: true
        _target_: mattergen.common.data.transform.symmetrize_lattice
      - _partial_: true
        _target_: mattergen.common.data.transform.set_chemical_system_string
  # Data-module-level transforms; same two entries as each dataset's list.
  transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattice
    - _partial_: true
      _target_: mattergen.common.data.transform.set_chemical_system_string
  val_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/val
    properties:
      - dft_bulk_modulus
      - dft_band_gap
      - dft_mag_density
      - ml_bulk_modulus
      - hhi_score
      - space_group
      - energy_above_hull
    transforms:
      - _partial_: true
        _target_: mattergen.common.data.transform.symmetrize_lattice
      - _partial_: true
        _target_: mattergen.common.data.transform.set_chemical_system_string
# Lightning module: diffusion corruption processes, loss, denoiser model,
# optimizer and LR-scheduler configuration.
lightning_module:
  _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
  diffusion_module:
    _target_: mattergen.diffusion.diffusion_module.DiffusionModule
    corruption:
      _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
      # Discrete (D3PM mask) corruption of atomic numbers.
      discrete_corruptions:
        atomic_numbers:
          _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
          d3pm:
            _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
            dim: 101
            schedule:
              _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
              kind: standard
              num_steps: 1000
          offset: 1
      # Continuous SDE corruption of the cell and of the atom positions.
      sdes:
        cell:
          _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
          vpsde_config:
            beta_max: 20
            beta_min: 0.1
            # Same value as data_module.average_density.
            limit_density: 0.05771451654022283
            limit_var_scaling_constant: 0.25
        pos:
          _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
          limit_info_key: num_atoms
          sigma_max: 5.0
          wrapping_boundary: 1.0
    loss_fn:
      _target_: mattergen.common.loss.MaterialsLoss
      d3pm_hybrid_lambda: 0.01
      include_atomic_numbers: true
      include_cell: true
      include_pos: true
      reduce: sum
      # Per-field loss weights.
      weights:
        atomic_numbers: 1.0
        cell: 1.0
        pos: 0.1
    model:
      _target_: mattergen.denoiser.GemNetTDenoiser
      atom_type_diffusion: mask
      denoise_atom_types: true
      gemnet:
        _target_: mattergen.common.gemnet.gemnet.GemNetT
        atom_embedding:
          _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
          emb_size: 512
          with_mask_type: true
        cutoff: 7.0
        emb_size_atom: 512
        emb_size_edge: 512
        latent_dim: 512
        max_cell_images_per_dim: 5
        max_neighbors: 50
        num_blocks: 4
        num_targets: 1
        otf_graph: true
        regress_stress: true
        # NOTE(review): absolute path from the machine that produced this
        # config; likely needs overriding when run elsewhere - confirm.
        scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
      hidden_dim: 512
      property_embeddings: {}
      property_embeddings_adapt: {}
    pre_corruption_fn:
      _target_: mattergen.property_embeddings.SetEmbeddingType
      dropout_fields_iid: false
      p_unconditional: 0.2
  # Partially-instantiated optimizer (`_partial_: true`); presumably bound to
  # the model parameters by the lightning module - confirm.
  optimizer_partial:
    _partial_: true
    _target_: torch.optim.Adam
    lr: 0.0001
  scheduler_partials:
    - frequency: 1
      interval: epoch
      monitor: loss_train
      scheduler:
        _partial_: true
        _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
        factor: 0.6
        min_lr: 1.0e-06
        patience: 100
        # NOTE(review): recent PyTorch releases deprecate `verbose` for this
        # scheduler - confirm against the pinned torch version.
        verbose: true
      strict: true
load_original: false
params: {}
train: true
# PyTorch Lightning Trainer configuration: 2 nodes x 8 GPUs with DDP.
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  accumulate_grad_batches: 1
  callbacks:
    - _target_: pytorch_lightning.callbacks.LearningRateMonitor
      log_momentum: false
      logging_interval: step
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      every_n_epochs: 1
      # Must stay quoted: an unquoted value starting with `{` would be parsed
      # as a YAML flow mapping.
      filename: '{epoch}-{loss_val:.2f}'
      mode: min
      monitor: loss_val
      save_last: true
      save_top_k: 1
      verbose: false
    - _target_: pytorch_lightning.callbacks.TQDMProgressBar
      refresh_rate: 50
    - _target_: mattergen.common.data.callback.SetPropertyScalers
  check_val_every_n_epoch: 5
  devices: 8
  gradient_clip_algorithm: value
  gradient_clip_val: 0.5
  logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    job_type: train
    project: crystal-generation
    settings:
      # NOTE(review): underscore-prefixed (private) wandb setting; confirm
      # the installed wandb version still accepts it.
      _save_requirements: false
      _target_: wandb.Settings
      start_method: fork
  # Same value as data_module.max_epochs.
  max_epochs: 2200
  num_nodes: 2
  precision: 32
  strategy:
    _target_: pytorch_lightning.strategies.ddp.DDPStrategy
    find_unused_parameters: true