# File size: 7,336 Bytes
# 17e1388
# Adapter fine-tuning setup: wraps a pretrained base model (loaded from
# `model_path` at `load_epoch`) with property-conditioning adapter layers.
adapter:
  adapter:
    _target_: mattergen.adapter.GemNetTAdapter
    # Discrete atom-type corruption uses the masking formulation.
    atom_type_diffusion: mask
    denoise_atom_types: true
    gemnet:
      _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
      atom_embedding:
        _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
        emb_size: 512
        # Extra embedding slot for the [MASK] atom type used by mask diffusion.
        with_mask_type: true
      # Properties injected through the adapter conditioning pathway.
      condition_on_adapt:
      - ml_bulk_modulus
      # Radial cutoff (presumably Angstrom — TODO confirm against GemNetTCtrl).
      cutoff: 7.0
      emb_size_atom: 512
      emb_size_edge: 512
      latent_dim: 512
      max_cell_images_per_dim: 5
      max_neighbors: 50
      num_blocks: 4
      num_targets: 1
      # Build the neighbor graph on the fly rather than from precomputed edges.
      otf_graph: true
      regress_stress: true
      # NOTE(review): absolute, machine-specific path (AMLT scratch mount) —
      # this will break on other machines; consider a path relative to the repo.
      scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
    hidden_dim: 512
    # No base-model property embeddings; all conditioning for this run comes
    # from `property_embeddings_adapt` below.
    property_embeddings: {}
    property_embeddings_adapt:
      ml_bulk_modulus:
        _target_: mattergen.property_embeddings.PropertyEmbedding
        # Embedding used when the property value is provided (conditional branch).
        conditional_embedding_module:
          _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
          d_model: 512
        name: ml_bulk_modulus
        # Normalizes property values; statistics are presumably fitted from the
        # training data at runtime (see SetPropertyScalers callback in trainer).
        scaler:
          _target_: mattergen.common.utils.data_utils.StandardScalerTorch
        # Learned fallback embedding for the unconditional (dropped-label) branch.
        unconditional_embedding_module:
          _target_: mattergen.property_embeddings.EmbeddingVector
          hidden_dim: 512
  # Train all base-model weights, not only the newly added adapter layers.
  full_finetuning: true
  # Resume from the last saved epoch of the base checkpoint.
  load_epoch: last
  model_path: checkpoints/mattergen_base
# Data pipeline: cached Alexandria+MP-20 crystal dataset filtered to entries
# that have the ml_bulk_modulus label.
data_module:
  _recursive_: true
  _target_: mattergen.common.data.datamodule.CrystDataModule
  # Mean atomic density of the training data; matches `limit_density` used by
  # the lattice SDE under `lightning_module` below — keep the two in sync.
  average_density: 0.05771451654022283
  batch_size:
    train: 64
    val: 64
  # Drop structures missing the requested property labels.
  dataset_transforms:
  - _partial_: true
    _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
  # NOTE(review): `trainer.max_epochs` is 200 below — unclear which value the
  # consuming code honors here; verify this key is actually read.
  max_epochs: 2200
  # 0 workers = data loading in the main process (no worker subprocesses).
  num_workers:
    train: 0
    val: 0
  properties:
  - ml_bulk_modulus
  root_dir: datasets/cache/alex_mp_20
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/train
    # Same sparse-property filter as the top-level `dataset_transforms`.
    dataset_transforms:
    - _partial_: true
      _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    properties:
    - ml_bulk_modulus
    transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattice
    - _partial_: true
      _target_: mattergen.common.data.transform.set_chemical_system_string
  # Per-sample transforms (duplicated into train/val dataset configs above).
  transforms:
  - _partial_: true
    _target_: mattergen.common.data.transform.symmetrize_lattice
  - _partial_: true
    _target_: mattergen.common.data.transform.set_chemical_system_string
  val_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/val
    dataset_transforms:
    - _partial_: true
      _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    properties:
    - ml_bulk_modulus
    transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattice
    - _partial_: true
      _target_: mattergen.common.data.transform.set_chemical_system_string
# Diffusion model + optimization: joint corruption over discrete atom types
# (D3PM mask diffusion) and continuous lattice/position SDEs.
lightning_module:
  _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
  diffusion_module:
    _target_: mattergen.diffusion.diffusion_module.DiffusionModule
    corruption:
      _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
      discrete_corruptions:
        atomic_numbers:
          _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
          d3pm:
            _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
            # 100 elements + 1 mask state — presumably; TODO confirm against
            # MaskDiffusion's dim convention.
            dim: 101
            schedule:
              _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
              kind: standard
              num_steps: 1000
          # Atomic numbers start at 1 (H), not 0.
          offset: 1
      sdes:
        # Lattice corruption: variance-preserving SDE.
        cell:
          _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
          vpsde_config:
            beta_max: 20
            beta_min: 0.1
            # Matches `data_module.average_density` — keep the two in sync.
            limit_density: 0.05771451654022283
            limit_var_scaling_constant: 0.25
        # Fractional-coordinate corruption: wrapped VE SDE, variance scaled by
        # the per-structure atom count.
        pos:
          _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
          limit_info_key: num_atoms
          sigma_max: 5.0
          # Fractional coordinates live in [0, 1).
          wrapping_boundary: 1.0
    loss_fn:
      _target_: mattergen.common.loss.MaterialsLoss
      d3pm_hybrid_lambda: 0.01
      include_atomic_numbers: true
      include_cell: true
      include_pos: true
      reduce: sum
      # Relative weighting of the per-field loss terms.
      weights:
        atomic_numbers: 1.0
        cell: 1.0
        pos: 0.1
    # NOTE(review): this subtree is byte-for-byte identical to the
    # `adapter.adapter` config above — presumably the saved resolved config
    # duplicates it; edit both copies together if values change.
    model:
      _target_: mattergen.adapter.GemNetTAdapter
      atom_type_diffusion: mask
      denoise_atom_types: true
      gemnet:
        _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
        atom_embedding:
          _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
          emb_size: 512
          with_mask_type: true
        condition_on_adapt:
        - ml_bulk_modulus
        cutoff: 7.0
        emb_size_atom: 512
        emb_size_edge: 512
        latent_dim: 512
        max_cell_images_per_dim: 5
        max_neighbors: 50
        num_blocks: 4
        num_targets: 1
        otf_graph: true
        regress_stress: true
        # NOTE(review): machine-specific absolute path — see note on the
        # `adapter` stanza.
        scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
      hidden_dim: 512
      property_embeddings: {}
      property_embeddings_adapt:
        ml_bulk_modulus:
          _target_: mattergen.property_embeddings.PropertyEmbedding
          conditional_embedding_module:
            _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
            d_model: 512
          name: ml_bulk_modulus
          scaler:
            _target_: mattergen.common.utils.data_utils.StandardScalerTorch
          unconditional_embedding_module:
            _target_: mattergen.property_embeddings.EmbeddingVector
            hidden_dim: 512
    # Classifier-free-guidance-style label dropout: 20% of samples are trained
    # unconditionally; fields are dropped jointly (not i.i.d. per field).
    pre_corruption_fn:
      _target_: mattergen.property_embeddings.SetEmbeddingType
      dropout_fields_iid: false
      p_unconditional: 0.2
  optimizer_partial:
    _partial_: true
    _target_: torch.optim.Adam
    # Low LR appropriate for fine-tuning a pretrained model.
    lr: 5.0e-06
  scheduler_partials:
  - frequency: 1
    interval: epoch
    # Plateau detection on the training loss (not validation loss).
    monitor: loss_train
    scheduler:
      _partial_: true
      _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
      factor: 0.6
      min_lr: 1.0e-06
      patience: 100
      verbose: true
    strict: true
# PyTorch Lightning trainer: single-node 8-GPU DDP run with W&B logging.
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  accumulate_grad_batches: 1
  callbacks:
  - _target_: pytorch_lightning.callbacks.LearningRateMonitor
    log_momentum: false
    logging_interval: step
  # Keep the single best checkpoint by validation loss, plus `last.ckpt`
  # (`save_last: true`) for resuming.
  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    every_n_epochs: 1
    filename: '{epoch}-{loss_val:.2f}'
    mode: min
    monitor: loss_val
    save_last: true
    save_top_k: 1
    verbose: false
  # Throttled progress bar (refresh every 50 batches) to keep logs readable.
  - _target_: pytorch_lightning.callbacks.TQDMProgressBar
    refresh_rate: 50
  # Fits the property scalers (e.g. the StandardScalerTorch configured above)
  # from the dataset before training — presumably; TODO confirm in callback.
  - _target_: mattergen.common.data.callback.SetPropertyScalers
  check_val_every_n_epoch: 1
  devices: 8
  # Clip gradients element-wise by value, not by global norm.
  gradient_clip_algorithm: value
  gradient_clip_val: 0.5
  logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    job_type: train_finetune
    project: crystal-generation
    settings:
      _save_requirements: false
      _target_: wandb.Settings
      # `fork` start method for the W&B background process.
      start_method: fork
  # NOTE(review): differs from `data_module.max_epochs` (2200); the Trainer's
  # value governs the actual training length.
  max_epochs: 200
  num_nodes: 1
  # Full fp32 training (no mixed precision).
  precision: 32
  strategy:
    _target_: pytorch_lightning.strategies.ddp.DDPStrategy
    # Required when parts of the model (e.g. unused embedding branches) do not
    # receive gradients every step; adds some DDP overhead.
    find_unused_parameters: true