# Listing metadata: file size 5,485 bytes; id 17e1388 (presumably a commit hash).
# Run-level switches. NOTE(review): their exact effect is defined by the
# training entry point that reads this file, which is not visible here.
auto_resume: true
# Explicit null: no checkpoint path is set in this config.
checkpoint_path: null
# Data-module section: `_target_` names CrystDataModule, and the nested
# train_dataset / val_dataset entries each name
# CrystalDataset.from_cache_path with a cache directory under root_dir.
# NOTE(review): `_target_` / `_partial_` / `_recursive_` look like Hydra
# instantiate keys — confirm against the code that loads this file.
data_module:
  _recursive_: true
  _target_: mattergen.common.data.datamodule.CrystDataModule
  # Same literal as the `limit_density` value under
  # lightning_module.diffusion_module.corruption.sdes.cell.vpsde_config;
  # presumably the two are meant to stay in sync — TODO confirm.
  average_density: 0.05771451654022283
  # Per-split batch sizes.
  batch_size:
    train: 32
    val: 32
  # Same value as trainer.max_epochs in this file.
  max_epochs: 2200
  # Per-split worker counts; both are 0.
  num_workers:
    train: 0
    val: 0
  # This list is repeated verbatim under train_dataset.properties and
  # val_dataset.properties below; presumably all three must match.
  properties:
  - dft_bulk_modulus
  - dft_band_gap
  - dft_mag_density
  - ml_bulk_modulus
  - hhi_score
  - space_group
  - energy_above_hull
  # Parent directory of the train/ and val/ cache paths used below.
  root_dir: datasets/cache/alex_mp_20/
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/train
    properties:
    - dft_bulk_modulus
    - dft_band_gap
    - dft_mag_density
    - ml_bulk_modulus
    - hhi_score
    - space_group
    - energy_above_hull
    # Same two transform entries as data_module.transforms and
    # val_dataset.transforms.
    transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattice
    - _partial_: true
      _target_: mattergen.common.data.transform.set_chemical_system_string
  # Module-level transform list; the same two entries also appear inside
  # train_dataset and val_dataset.
  transforms:
  - _partial_: true
    _target_: mattergen.common.data.transform.symmetrize_lattice
  - _partial_: true
    _target_: mattergen.common.data.transform.set_chemical_system_string
  val_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/val
    properties:
    - dft_bulk_modulus
    - dft_band_gap
    - dft_mag_density
    - ml_bulk_modulus
    - hhi_score
    - space_group
    - energy_above_hull
    # Same two transform entries as data_module.transforms and
    # train_dataset.transforms.
    transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattice
    - _partial_: true
      _target_: mattergen.common.data.transform.set_chemical_system_string
# Model/training-module section: `_target_` names DiffusionLightningModule,
# which is configured with a diffusion module (corruption, loss, denoiser
# model, pre-corruption hook), an optimizer and one LR scheduler entry.
lightning_module:
  _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
  diffusion_module:
    _target_: mattergen.diffusion.diffusion_module.DiffusionModule
    # Corruption is split into a discrete part (atomic_numbers) and an
    # `sdes` part (cell, pos).
    corruption:
      _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
      discrete_corruptions:
        atomic_numbers:
          _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
          d3pm:
            _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
            # NOTE(review): presumably 100 element classes plus one mask
            # state — confirm against MaskDiffusion.
            dim: 101
            schedule:
              _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
              kind: standard
              num_steps: 1000
          # NOTE(review): presumably shifts between 1-based atomic numbers
          # and 0-based class indices — confirm against D3PMCorruption.
          offset: 1
      sdes:
        cell:
          _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
          vpsde_config:
            beta_max: 20
            beta_min: 0.1
            # Same literal as data_module.average_density in this file.
            limit_density: 0.05771451654022283
            limit_var_scaling_constant: 0.25
        pos:
          _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
          limit_info_key: num_atoms
          sigma_max: 5.0
          wrapping_boundary: 1.0
    # Loss over the same three fields that are corrupted above
    # (atomic_numbers, cell, pos), each enabled and individually weighted.
    loss_fn:
      _target_: mattergen.common.loss.MaterialsLoss
      d3pm_hybrid_lambda: 0.01
      include_atomic_numbers: true
      include_cell: true
      include_pos: true
      reduce: sum
      weights:
        atomic_numbers: 1.0
        cell: 1.0
        pos: 0.1
    # Denoiser model wrapping a GemNetT backbone.
    model:
      _target_: mattergen.denoiser.GemNetTDenoiser
      atom_type_diffusion: mask
      denoise_atom_types: true
      gemnet:
        _target_: mattergen.common.gemnet.gemnet.GemNetT
        atom_embedding:
          _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
          # Same width as emb_size_atom, emb_size_edge, latent_dim and
          # hidden_dim below (all 512).
          emb_size: 512
          with_mask_type: true
        cutoff: 7.0
        emb_size_atom: 512
        emb_size_edge: 512
        latent_dim: 512
        max_cell_images_per_dim: 5
        max_neighbors: 50
        num_blocks: 4
        num_targets: 1
        otf_graph: true
        regress_stress: true
        # NOTE(review): absolute, machine-specific path — confirm it resolves
        # (or is overridden) in the environment that loads this config.
        scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
      hidden_dim: 512
      # Both property-embedding mappings are empty in this config.
      property_embeddings: {}
      property_embeddings_adapt: {}
    pre_corruption_fn:
      _target_: mattergen.property_embeddings.SetEmbeddingType
      dropout_fields_iid: false
      p_unconditional: 0.2
  # Adam with lr 1e-4. `_partial_: true` presumably defers construction
  # until the module supplies the parameters — TODO confirm.
  optimizer_partial:
    _partial_: true
    _target_: torch.optim.Adam
    lr: 0.0001
  # One scheduler entry: ReduceLROnPlateau monitoring `loss_train`, with
  # interval `epoch` and frequency 1.
  scheduler_partials:
  - frequency: 1
    interval: epoch
    monitor: loss_train
    scheduler:
      _partial_: true
      _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
      factor: 0.6
      min_lr: 1.0e-06
      patience: 100
      # NOTE(review): `verbose` is reportedly deprecated for LR schedulers in
      # newer PyTorch releases — confirm against the pinned torch version.
      verbose: true
    strict: true
# Remaining run-level switches. NOTE(review): the meaning of `load_original`
# and `train` is defined by the consuming entry point, not visible here.
load_original: false
# Explicit empty mapping (not null).
params: {}
train: true
# Trainer section: `_target_` names pytorch_lightning.Trainer, configured
# for 8 GPU devices on each of 2 nodes with a DDP strategy, four callbacks
# and a Weights & Biases logger.
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  accumulate_grad_batches: 1
  # NOTE(review): list order may matter to the consumer; keep as is.
  callbacks:
  - _target_: pytorch_lightning.callbacks.LearningRateMonitor
    log_momentum: false
    logging_interval: step
  # Checkpointing keyed on `loss_val` (mode: min, save_top_k: 1), with
  # save_last enabled as well.
  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    every_n_epochs: 1
    filename: '{epoch}-{loss_val:.2f}'
    mode: min
    monitor: loss_val
    save_last: true
    save_top_k: 1
    verbose: false
  - _target_: pytorch_lightning.callbacks.TQDMProgressBar
    refresh_rate: 50
  - _target_: mattergen.common.data.callback.SetPropertyScalers
  # NOTE(review): ModelCheckpoint above uses every_n_epochs: 1 and monitors
  # `loss_val`, while validation is set to every 5 epochs here — confirm the
  # two interact as intended.
  check_val_every_n_epoch: 5
  devices: 8
  # Gradient clipping by value, threshold 0.5.
  gradient_clip_algorithm: value
  gradient_clip_val: 0.5
  logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    job_type: train
    project: crystal-generation
    settings:
      _save_requirements: false
      _target_: wandb.Settings
      start_method: fork
  # Same value as data_module.max_epochs in this file.
  max_epochs: 2200
  num_nodes: 2
  precision: 32
  strategy:
    _target_: pytorch_lightning.strategies.ddp.DDPStrategy
    # NOTE(review): presumably needed because some parameters receive no
    # gradient in some steps — confirm before changing.
    find_unused_parameters: true