danielzuegner-ms commited on
Commit
17e1388
·
verified ·
1 Parent(s): 2a2d887

Upload folder using huggingface_hub

Browse files
checkpoints/chemical_system/config.yaml ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adapter:
2
+ adapter:
3
+ _target_: mattergen.adapter.GemNetTAdapter
4
+ atom_type_diffusion: mask
5
+ denoise_atom_types: true
6
+ gemnet:
7
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
8
+ atom_embedding:
9
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
10
+ emb_size: 512
11
+ with_mask_type: true
12
+ condition_on_adapt:
13
+ - chemical_system
14
+ cutoff: 7.0
15
+ emb_size_atom: 512
16
+ emb_size_edge: 512
17
+ latent_dim: 512
18
+ max_cell_images_per_dim: 5
19
+ max_neighbors: 50
20
+ num_blocks: 4
21
+ num_targets: 1
22
+ otf_graph: true
23
+ regress_stress: true
24
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
25
+ hidden_dim: 512
26
+ property_embeddings: {}
27
+ property_embeddings_adapt:
28
+ chemical_system:
29
+ _target_: mattergen.property_embeddings.PropertyEmbedding
30
+ conditional_embedding_module:
31
+ _target_: mattergen.property_embeddings.ChemicalSystemMultiHotEmbedding
32
+ hidden_dim: 512
33
+ name: chemical_system
34
+ scaler:
35
+ _target_: torch.nn.Identity
36
+ unconditional_embedding_module:
37
+ _target_: mattergen.property_embeddings.EmbeddingVector
38
+ hidden_dim: 512
39
+ full_finetuning: true
40
+ load_epoch: last
41
+ model_path: checkpoints/mattergen_base
42
+ data_module:
43
+ _recursive_: true
44
+ _target_: mattergen.common.data.datamodule.CrystDataModule
45
+ average_density: 0.05771451654022283
46
+ batch_size:
47
+ train: 64
48
+ val: 64
49
+ dataset_transforms:
50
+ - _partial_: true
51
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
52
+ max_epochs: 2200
53
+ num_workers:
54
+ train: 0
55
+ val: 0
56
+ properties:
57
+ - chemical_system
58
+ root_dir: datasets/cache/alex_mp_20/
59
+ train_dataset:
60
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
61
+ cache_path: datasets/cache/alex_mp_20/train
62
+ dataset_transforms:
63
+ - _partial_: true
64
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
65
+ properties:
66
+ - chemical_system
67
+ transforms:
68
+ - _partial_: true
69
+ _target_: mattergen.common.data.transform.symmetrize_lattice
70
+ - _partial_: true
71
+ _target_: mattergen.common.data.transform.set_chemical_system_string
72
+ transforms:
73
+ - _partial_: true
74
+ _target_: mattergen.common.data.transform.symmetrize_lattice
75
+ - _partial_: true
76
+ _target_: mattergen.common.data.transform.set_chemical_system_string
77
+ val_dataset:
78
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
79
+ cache_path: datasets/cache/alex_mp_20/val
80
+ dataset_transforms:
81
+ - _partial_: true
82
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
83
+ properties:
84
+ - chemical_system
85
+ transforms:
86
+ - _partial_: true
87
+ _target_: mattergen.common.data.transform.symmetrize_lattice
88
+ - _partial_: true
89
+ _target_: mattergen.common.data.transform.set_chemical_system_string
90
+ lightning_module:
91
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
92
+ diffusion_module:
93
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
94
+ corruption:
95
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
96
+ discrete_corruptions:
97
+ atomic_numbers:
98
+ _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
99
+ d3pm:
100
+ _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
101
+ dim: 101
102
+ schedule:
103
+ _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
104
+ kind: standard
105
+ num_steps: 1000
106
+ offset: 1
107
+ sdes:
108
+ cell:
109
+ _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
110
+ vpsde_config:
111
+ beta_max: 20
112
+ beta_min: 0.1
113
+ limit_density: 0.05771451654022283
114
+ limit_var_scaling_constant: 0.25
115
+ pos:
116
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
117
+ limit_info_key: num_atoms
118
+ sigma_max: 5.0
119
+ wrapping_boundary: 1.0
120
+ loss_fn:
121
+ _target_: mattergen.common.loss.MaterialsLoss
122
+ d3pm_hybrid_lambda: 0.01
123
+ include_atomic_numbers: true
124
+ include_cell: true
125
+ include_pos: true
126
+ reduce: sum
127
+ weights:
128
+ atomic_numbers: 1.0
129
+ cell: 1.0
130
+ pos: 0.1
131
+ model:
132
+ _target_: mattergen.adapter.GemNetTAdapter
133
+ atom_type_diffusion: mask
134
+ denoise_atom_types: true
135
+ gemnet:
136
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
137
+ atom_embedding:
138
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
139
+ emb_size: 512
140
+ with_mask_type: true
141
+ condition_on_adapt:
142
+ - chemical_system
143
+ cutoff: 7.0
144
+ emb_size_atom: 512
145
+ emb_size_edge: 512
146
+ latent_dim: 512
147
+ max_cell_images_per_dim: 5
148
+ max_neighbors: 50
149
+ num_blocks: 4
150
+ num_targets: 1
151
+ otf_graph: true
152
+ regress_stress: true
153
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
154
+ hidden_dim: 512
155
+ property_embeddings: {}
156
+ property_embeddings_adapt:
157
+ chemical_system:
158
+ _target_: mattergen.property_embeddings.PropertyEmbedding
159
+ conditional_embedding_module:
160
+ _target_: mattergen.property_embeddings.ChemicalSystemMultiHotEmbedding
161
+ hidden_dim: 512
162
+ name: chemical_system
163
+ scaler:
164
+ _target_: torch.nn.Identity
165
+ unconditional_embedding_module:
166
+ _target_: mattergen.property_embeddings.EmbeddingVector
167
+ hidden_dim: 512
168
+ pre_corruption_fn:
169
+ _target_: mattergen.property_embeddings.SetEmbeddingType
170
+ dropout_fields_iid: false
171
+ p_unconditional: 0.2
172
+ optimizer_partial:
173
+ _partial_: true
174
+ _target_: torch.optim.Adam
175
+ lr: 5.0e-06
176
+ scheduler_partials:
177
+ - frequency: 1
178
+ interval: epoch
179
+ monitor: loss_train
180
+ scheduler:
181
+ _partial_: true
182
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
183
+ factor: 0.6
184
+ min_lr: 1.0e-06
185
+ patience: 100
186
+ verbose: true
187
+ strict: true
188
+ trainer:
189
+ _target_: pytorch_lightning.Trainer
190
+ accelerator: gpu
191
+ accumulate_grad_batches: 1
192
+ callbacks:
193
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
194
+ log_momentum: false
195
+ logging_interval: step
196
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
197
+ every_n_epochs: 1
198
+ filename: '{epoch}-{loss_val:.2f}'
199
+ mode: min
200
+ monitor: loss_val
201
+ save_last: true
202
+ save_top_k: 1
203
+ verbose: false
204
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
205
+ refresh_rate: 50
206
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
207
+ check_val_every_n_epoch: 1
208
+ devices: 8
209
+ gradient_clip_algorithm: value
210
+ gradient_clip_val: 0.5
211
+ logger:
212
+ _target_: pytorch_lightning.loggers.WandbLogger
213
+ job_type: train_finetune
214
+ project: crystal-generation
215
+ settings:
216
+ _save_requirements: false
217
+ _target_: wandb.Settings
218
+ start_method: fork
219
+ max_epochs: 200
220
+ num_nodes: 1
221
+ precision: 32
222
+ strategy:
223
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
224
+ find_unused_parameters: true
checkpoints/chemical_system_energy_above_hull/config.yaml ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adapter:
2
+ adapter:
3
+ _target_: mattergen.adapter.GemNetTAdapter
4
+ atom_type_diffusion: mask
5
+ denoise_atom_types: true
6
+ gemnet:
7
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
8
+ atom_embedding:
9
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
10
+ emb_size: 512
11
+ with_mask_type: true
12
+ condition_on_adapt:
13
+ - chemical_system
14
+ - energy_above_hull
15
+ cutoff: 7.0
16
+ emb_size_atom: 512
17
+ emb_size_edge: 512
18
+ latent_dim: 512
19
+ max_cell_images_per_dim: 5
20
+ max_neighbors: 50
21
+ num_blocks: 4
22
+ num_targets: 1
23
+ otf_graph: true
24
+ regress_stress: true
25
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
26
+ hidden_dim: 512
27
+ property_embeddings: {}
28
+ property_embeddings_adapt:
29
+ chemical_system:
30
+ _target_: mattergen.property_embeddings.PropertyEmbedding
31
+ conditional_embedding_module:
32
+ _target_: mattergen.property_embeddings.ChemicalSystemMultiHotEmbedding
33
+ hidden_dim: 512
34
+ name: chemical_system
35
+ scaler:
36
+ _target_: torch.nn.Identity
37
+ unconditional_embedding_module:
38
+ _target_: mattergen.property_embeddings.EmbeddingVector
39
+ hidden_dim: 512
40
+ energy_above_hull:
41
+ _target_: mattergen.property_embeddings.PropertyEmbedding
42
+ conditional_embedding_module:
43
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
44
+ d_model: 512
45
+ name: energy_above_hull
46
+ scaler:
47
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
48
+ unconditional_embedding_module:
49
+ _target_: mattergen.property_embeddings.EmbeddingVector
50
+ hidden_dim: 512
51
+ full_finetuning: true
52
+ load_epoch: last
53
+ model_path: checkpoints/mattergen_base
54
+ data_module:
55
+ _recursive_: true
56
+ _target_: mattergen.common.data.datamodule.CrystDataModule
57
+ average_density: 0.05771451654022283
58
+ batch_size:
59
+ train: 64
60
+ val: 64
61
+ dataset_transforms:
62
+ - _partial_: true
63
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
64
+ max_epochs: 2200
65
+ num_workers:
66
+ train: 0
67
+ val: 0
68
+ properties:
69
+ - chemical_system
70
+ - energy_above_hull
71
+ root_dir: datasets/cache/alex_mp_20
72
+ train_dataset:
73
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
74
+ cache_path: datasets/cache/alex_mp_20/train
75
+ dataset_transforms:
76
+ - _partial_: true
77
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
78
+ properties:
79
+ - chemical_system
80
+ - energy_above_hull
81
+ transforms:
82
+ - _partial_: true
83
+ _target_: mattergen.common.data.transform.symmetrize_lattice
84
+ - _partial_: true
85
+ _target_: mattergen.common.data.transform.set_chemical_system_string
86
+ transforms:
87
+ - _partial_: true
88
+ _target_: mattergen.common.data.transform.symmetrize_lattice
89
+ - _partial_: true
90
+ _target_: mattergen.common.data.transform.set_chemical_system_string
91
+ val_dataset:
92
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
93
+ cache_path: datasets/cache/alex_mp_20/val
94
+ dataset_transforms:
95
+ - _partial_: true
96
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
97
+ properties:
98
+ - chemical_system
99
+ - energy_above_hull
100
+ transforms:
101
+ - _partial_: true
102
+ _target_: mattergen.common.data.transform.symmetrize_lattice
103
+ - _partial_: true
104
+ _target_: mattergen.common.data.transform.set_chemical_system_string
105
+ lightning_module:
106
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
107
+ diffusion_module:
108
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
109
+ corruption:
110
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
111
+ discrete_corruptions:
112
+ atomic_numbers:
113
+ _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
114
+ d3pm:
115
+ _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
116
+ dim: 101
117
+ schedule:
118
+ _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
119
+ kind: standard
120
+ num_steps: 1000
121
+ offset: 1
122
+ sdes:
123
+ cell:
124
+ _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
125
+ vpsde_config:
126
+ beta_max: 20
127
+ beta_min: 0.1
128
+ limit_density: 0.05771451654022283
129
+ limit_var_scaling_constant: 0.25
130
+ pos:
131
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
132
+ limit_info_key: num_atoms
133
+ sigma_max: 5.0
134
+ wrapping_boundary: 1.0
135
+ loss_fn:
136
+ _target_: mattergen.common.loss.MaterialsLoss
137
+ d3pm_hybrid_lambda: 0.01
138
+ include_atomic_numbers: true
139
+ include_cell: true
140
+ include_pos: true
141
+ reduce: sum
142
+ weights:
143
+ atomic_numbers: 1.0
144
+ cell: 1.0
145
+ pos: 0.1
146
+ model:
147
+ _target_: mattergen.adapter.GemNetTAdapter
148
+ atom_type_diffusion: mask
149
+ denoise_atom_types: true
150
+ gemnet:
151
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
152
+ atom_embedding:
153
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
154
+ emb_size: 512
155
+ with_mask_type: true
156
+ condition_on_adapt:
157
+ - chemical_system
158
+ - energy_above_hull
159
+ cutoff: 7.0
160
+ emb_size_atom: 512
161
+ emb_size_edge: 512
162
+ latent_dim: 512
163
+ max_cell_images_per_dim: 5
164
+ max_neighbors: 50
165
+ num_blocks: 4
166
+ num_targets: 1
167
+ otf_graph: true
168
+ regress_stress: true
169
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
170
+ hidden_dim: 512
171
+ property_embeddings: {}
172
+ property_embeddings_adapt:
173
+ chemical_system:
174
+ _target_: mattergen.property_embeddings.PropertyEmbedding
175
+ conditional_embedding_module:
176
+ _target_: mattergen.property_embeddings.ChemicalSystemMultiHotEmbedding
177
+ hidden_dim: 512
178
+ name: chemical_system
179
+ scaler:
180
+ _target_: torch.nn.Identity
181
+ unconditional_embedding_module:
182
+ _target_: mattergen.property_embeddings.EmbeddingVector
183
+ hidden_dim: 512
184
+ energy_above_hull:
185
+ _target_: mattergen.property_embeddings.PropertyEmbedding
186
+ conditional_embedding_module:
187
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
188
+ d_model: 512
189
+ name: energy_above_hull
190
+ scaler:
191
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
192
+ unconditional_embedding_module:
193
+ _target_: mattergen.property_embeddings.EmbeddingVector
194
+ hidden_dim: 512
195
+ pre_corruption_fn:
196
+ _target_: mattergen.property_embeddings.SetEmbeddingType
197
+ dropout_fields_iid: false
198
+ p_unconditional: 0.2
199
+ optimizer_partial:
200
+ _partial_: true
201
+ _target_: torch.optim.Adam
202
+ lr: 5.0e-06
203
+ scheduler_partials:
204
+ - frequency: 1
205
+ interval: epoch
206
+ monitor: loss_train
207
+ scheduler:
208
+ _partial_: true
209
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
210
+ factor: 0.6
211
+ min_lr: 1.0e-06
212
+ patience: 100
213
+ verbose: true
214
+ strict: true
215
+ trainer:
216
+ _target_: pytorch_lightning.Trainer
217
+ accelerator: gpu
218
+ accumulate_grad_batches: 1
219
+ callbacks:
220
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
221
+ log_momentum: false
222
+ logging_interval: step
223
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
224
+ every_n_epochs: 1
225
+ filename: '{epoch}-{loss_val:.2f}'
226
+ mode: min
227
+ monitor: loss_val
228
+ save_last: true
229
+ save_top_k: 1
230
+ verbose: false
231
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
232
+ refresh_rate: 50
233
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
234
+ check_val_every_n_epoch: 5
235
+ devices: 8
236
+ gradient_clip_algorithm: value
237
+ gradient_clip_val: 0.5
238
+ logger:
239
+ _target_: pytorch_lightning.loggers.WandbLogger
240
+ job_type: train_finetune
241
+ project: crystal-generation
242
+ settings:
243
+ _save_requirements: false
244
+ _target_: wandb.Settings
245
+ start_method: fork
246
+ max_epochs: 200
247
+ num_nodes: 1
248
+ precision: 32
249
+ strategy:
250
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
251
+ find_unused_parameters: true
checkpoints/dft_band_gap/config.yaml ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adapter:
2
+ adapter:
3
+ _target_: mattergen.adapter.GemNetTAdapter
4
+ atom_type_diffusion: mask
5
+ denoise_atom_types: true
6
+ gemnet:
7
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
8
+ atom_embedding:
9
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
10
+ emb_size: 512
11
+ with_mask_type: true
12
+ condition_on_adapt:
13
+ - dft_band_gap
14
+ cutoff: 7.0
15
+ emb_size_atom: 512
16
+ emb_size_edge: 512
17
+ latent_dim: 512
18
+ max_cell_images_per_dim: 5
19
+ max_neighbors: 50
20
+ num_blocks: 4
21
+ num_targets: 1
22
+ otf_graph: true
23
+ regress_stress: true
24
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
25
+ hidden_dim: 512
26
+ property_embeddings: {}
27
+ property_embeddings_adapt:
28
+ dft_band_gap:
29
+ _target_: mattergen.property_embeddings.PropertyEmbedding
30
+ conditional_embedding_module:
31
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
32
+ d_model: 512
33
+ name: dft_band_gap
34
+ scaler:
35
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
36
+ unconditional_embedding_module:
37
+ _target_: mattergen.property_embeddings.EmbeddingVector
38
+ hidden_dim: 512
39
+ full_finetuning: true
40
+ load_epoch: last
41
+ model_path: checkpoints/mattergen_base
42
+ data_module:
43
+ _recursive_: true
44
+ _target_: mattergen.common.data.datamodule.CrystDataModule
45
+ average_density: 0.05771451654022283
46
+ batch_size:
47
+ train: 64
48
+ val: 64
49
+ dataset_transforms:
50
+ - _partial_: true
51
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
52
+ max_epochs: 2200
53
+ num_workers:
54
+ train: 0
55
+ val: 0
56
+ properties:
57
+ - dft_band_gap
58
+ root_dir: datasets/cache/alex_mp_20
59
+ train_dataset:
60
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
61
+ cache_path: datasets/cache/alex_mp_20/train
62
+ dataset_transforms:
63
+ - _partial_: true
64
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
65
+ properties:
66
+ - dft_band_gap
67
+ transforms:
68
+ - _partial_: true
69
+ _target_: mattergen.common.data.transform.symmetrize_lattice
70
+ - _partial_: true
71
+ _target_: mattergen.common.data.transform.set_chemical_system_string
72
+ transforms:
73
+ - _partial_: true
74
+ _target_: mattergen.common.data.transform.symmetrize_lattice
75
+ - _partial_: true
76
+ _target_: mattergen.common.data.transform.set_chemical_system_string
77
+ val_dataset:
78
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
79
+ cache_path: datasets/cache/alex_mp_20/val
80
+ dataset_transforms:
81
+ - _partial_: true
82
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
83
+ properties:
84
+ - dft_band_gap
85
+ transforms:
86
+ - _partial_: true
87
+ _target_: mattergen.common.data.transform.symmetrize_lattice
88
+ - _partial_: true
89
+ _target_: mattergen.common.data.transform.set_chemical_system_string
90
+ lightning_module:
91
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
92
+ diffusion_module:
93
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
94
+ corruption:
95
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
96
+ discrete_corruptions:
97
+ atomic_numbers:
98
+ _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
99
+ d3pm:
100
+ _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
101
+ dim: 101
102
+ schedule:
103
+ _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
104
+ kind: standard
105
+ num_steps: 1000
106
+ offset: 1
107
+ sdes:
108
+ cell:
109
+ _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
110
+ vpsde_config:
111
+ beta_max: 20
112
+ beta_min: 0.1
113
+ limit_density: 0.05771451654022283
114
+ limit_var_scaling_constant: 0.25
115
+ pos:
116
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
117
+ limit_info_key: num_atoms
118
+ sigma_max: 5.0
119
+ wrapping_boundary: 1.0
120
+ loss_fn:
121
+ _target_: mattergen.common.loss.MaterialsLoss
122
+ d3pm_hybrid_lambda: 0.01
123
+ include_atomic_numbers: true
124
+ include_cell: true
125
+ include_pos: true
126
+ reduce: sum
127
+ weights:
128
+ atomic_numbers: 1.0
129
+ cell: 1.0
130
+ pos: 0.1
131
+ model:
132
+ _target_: mattergen.adapter.GemNetTAdapter
133
+ atom_type_diffusion: mask
134
+ denoise_atom_types: true
135
+ gemnet:
136
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
137
+ atom_embedding:
138
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
139
+ emb_size: 512
140
+ with_mask_type: true
141
+ condition_on_adapt:
142
+ - dft_band_gap
143
+ cutoff: 7.0
144
+ emb_size_atom: 512
145
+ emb_size_edge: 512
146
+ latent_dim: 512
147
+ max_cell_images_per_dim: 5
148
+ max_neighbors: 50
149
+ num_blocks: 4
150
+ num_targets: 1
151
+ otf_graph: true
152
+ regress_stress: true
153
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
154
+ hidden_dim: 512
155
+ property_embeddings: {}
156
+ property_embeddings_adapt:
157
+ dft_band_gap:
158
+ _target_: mattergen.property_embeddings.PropertyEmbedding
159
+ conditional_embedding_module:
160
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
161
+ d_model: 512
162
+ name: dft_band_gap
163
+ scaler:
164
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
165
+ unconditional_embedding_module:
166
+ _target_: mattergen.property_embeddings.EmbeddingVector
167
+ hidden_dim: 512
168
+ pre_corruption_fn:
169
+ _target_: mattergen.property_embeddings.SetEmbeddingType
170
+ dropout_fields_iid: false
171
+ p_unconditional: 0.2
172
+ optimizer_partial:
173
+ _partial_: true
174
+ _target_: torch.optim.Adam
175
+ lr: 5.0e-06
176
+ scheduler_partials:
177
+ - frequency: 1
178
+ interval: epoch
179
+ monitor: loss_train
180
+ scheduler:
181
+ _partial_: true
182
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
183
+ factor: 0.6
184
+ min_lr: 1.0e-06
185
+ patience: 100
186
+ verbose: true
187
+ strict: true
188
+ trainer:
189
+ _target_: pytorch_lightning.Trainer
190
+ accelerator: gpu
191
+ accumulate_grad_batches: 1
192
+ callbacks:
193
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
194
+ log_momentum: false
195
+ logging_interval: step
196
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
197
+ every_n_epochs: 1
198
+ filename: '{epoch}-{loss_val:.2f}'
199
+ mode: min
200
+ monitor: loss_val
201
+ save_last: true
202
+ save_top_k: 1
203
+ verbose: false
204
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
205
+ refresh_rate: 50
206
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
207
+ check_val_every_n_epoch: 5
208
+ devices: 8
209
+ gradient_clip_algorithm: value
210
+ gradient_clip_val: 0.5
211
+ logger:
212
+ _target_: pytorch_lightning.loggers.WandbLogger
213
+ job_type: train_finetune
214
+ project: crystal-generation
215
+ settings:
216
+ _save_requirements: false
217
+ _target_: wandb.Settings
218
+ start_method: fork
219
+ max_epochs: 200
220
+ num_nodes: 1
221
+ precision: 32
222
+ strategy:
223
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
224
+ find_unused_parameters: true
checkpoints/dft_mag_density/config.yaml ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adapter:
2
+ adapter:
3
+ _target_: mattergen.adapter.GemNetTAdapter
4
+ atom_type_diffusion: mask
5
+ denoise_atom_types: true
6
+ gemnet:
7
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
8
+ atom_embedding:
9
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
10
+ emb_size: 512
11
+ with_mask_type: true
12
+ condition_on_adapt:
13
+ - dft_mag_density
14
+ cutoff: 7.0
15
+ emb_size_atom: 512
16
+ emb_size_edge: 512
17
+ latent_dim: 512
18
+ max_cell_images_per_dim: 5
19
+ max_neighbors: 50
20
+ num_blocks: 4
21
+ num_targets: 1
22
+ otf_graph: true
23
+ regress_stress: true
24
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
25
+ hidden_dim: 512
26
+ property_embeddings: {}
27
+ property_embeddings_adapt:
28
+ dft_mag_density:
29
+ _target_: mattergen.property_embeddings.PropertyEmbedding
30
+ conditional_embedding_module:
31
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
32
+ d_model: 512
33
+ name: dft_mag_density
34
+ scaler:
35
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
36
+ unconditional_embedding_module:
37
+ _target_: mattergen.property_embeddings.EmbeddingVector
38
+ hidden_dim: 512
39
+ full_finetuning: true
40
+ load_epoch: last
41
+ model_path: checkpoints/mattergen_base
42
+ data_module:
43
+ _recursive_: true
44
+ _target_: mattergen.common.data.datamodule.CrystDataModule
45
+ average_density: 0.05771451654022283
46
+ batch_size:
47
+ train: 64
48
+ val: 64
49
+ dataset_transforms:
50
+ - _partial_: true
51
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
52
+ max_epochs: 2200
53
+ num_workers:
54
+ train: 0
55
+ val: 0
56
+ properties:
57
+ - dft_mag_density
58
+ root_dir: datasets/cache/alex_mp_20
59
+ train_dataset:
60
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
61
+ cache_path: datasets/cache/alex_mp_20/train
62
+ dataset_transforms:
63
+ - _partial_: true
64
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
65
+ properties:
66
+ - dft_mag_density
67
+ transforms:
68
+ - _partial_: true
69
+ _target_: mattergen.common.data.transform.symmetrize_lattice
70
+ - _partial_: true
71
+ _target_: mattergen.common.data.transform.set_chemical_system_string
72
+ transforms:
73
+ - _partial_: true
74
+ _target_: mattergen.common.data.transform.symmetrize_lattice
75
+ - _partial_: true
76
+ _target_: mattergen.common.data.transform.set_chemical_system_string
77
+ val_dataset:
78
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
79
+ cache_path: datasets/cache/alex_mp_20/val
80
+ dataset_transforms:
81
+ - _partial_: true
82
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
83
+ properties:
84
+ - dft_mag_density
85
+ transforms:
86
+ - _partial_: true
87
+ _target_: mattergen.common.data.transform.symmetrize_lattice
88
+ - _partial_: true
89
+ _target_: mattergen.common.data.transform.set_chemical_system_string
90
+ lightning_module:
91
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
92
+ diffusion_module:
93
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
94
+ corruption:
95
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
96
+ discrete_corruptions:
97
+ atomic_numbers:
98
+ _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
99
+ d3pm:
100
+ _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
101
+ dim: 101
102
+ schedule:
103
+ _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
104
+ kind: standard
105
+ num_steps: 1000
106
+ offset: 1
107
+ sdes:
108
+ cell:
109
+ _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
110
+ vpsde_config:
111
+ beta_max: 20
112
+ beta_min: 0.1
113
+ limit_density: 0.05771451654022283
114
+ limit_var_scaling_constant: 0.25
115
+ pos:
116
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
117
+ limit_info_key: num_atoms
118
+ sigma_max: 5.0
119
+ wrapping_boundary: 1.0
120
+ loss_fn:
121
+ _target_: mattergen.common.loss.MaterialsLoss
122
+ d3pm_hybrid_lambda: 0.01
123
+ include_atomic_numbers: true
124
+ include_cell: true
125
+ include_pos: true
126
+ reduce: sum
127
+ weights:
128
+ atomic_numbers: 1.0
129
+ cell: 1.0
130
+ pos: 0.1
131
+ model:
132
+ _target_: mattergen.adapter.GemNetTAdapter
133
+ atom_type_diffusion: mask
134
+ denoise_atom_types: true
135
+ gemnet:
136
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
137
+ atom_embedding:
138
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
139
+ emb_size: 512
140
+ with_mask_type: true
141
+ condition_on_adapt:
142
+ - dft_mag_density
143
+ cutoff: 7.0
144
+ emb_size_atom: 512
145
+ emb_size_edge: 512
146
+ latent_dim: 512
147
+ max_cell_images_per_dim: 5
148
+ max_neighbors: 50
149
+ num_blocks: 4
150
+ num_targets: 1
151
+ otf_graph: true
152
+ regress_stress: true
153
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
154
+ hidden_dim: 512
155
+ property_embeddings: {}
156
+ property_embeddings_adapt:
157
+ dft_mag_density:
158
+ _target_: mattergen.property_embeddings.PropertyEmbedding
159
+ conditional_embedding_module:
160
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
161
+ d_model: 512
162
+ name: dft_mag_density
163
+ scaler:
164
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
165
+ unconditional_embedding_module:
166
+ _target_: mattergen.property_embeddings.EmbeddingVector
167
+ hidden_dim: 512
168
+ pre_corruption_fn:
169
+ _target_: mattergen.property_embeddings.SetEmbeddingType
170
+ dropout_fields_iid: false
171
+ p_unconditional: 0.2
172
+ optimizer_partial:
173
+ _partial_: true
174
+ _target_: torch.optim.Adam
175
+ lr: 5.0e-06
176
+ scheduler_partials:
177
+ - frequency: 1
178
+ interval: epoch
179
+ monitor: loss_train
180
+ scheduler:
181
+ _partial_: true
182
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
183
+ factor: 0.6
184
+ min_lr: 1.0e-06
185
+ patience: 100
186
+ verbose: true
187
+ strict: true
188
+ trainer:
189
+ _target_: pytorch_lightning.Trainer
190
+ accelerator: gpu
191
+ accumulate_grad_batches: 1
192
+ callbacks:
193
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
194
+ log_momentum: false
195
+ logging_interval: step
196
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
197
+ every_n_epochs: 1
198
+ filename: '{epoch}-{loss_val:.2f}'
199
+ mode: min
200
+ monitor: loss_val
201
+ save_last: true
202
+ save_top_k: 1
203
+ verbose: false
204
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
205
+ refresh_rate: 50
206
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
207
+ check_val_every_n_epoch: 1
208
+ devices: 8
209
+ gradient_clip_algorithm: value
210
+ gradient_clip_val: 0.5
211
+ logger:
212
+ _target_: pytorch_lightning.loggers.WandbLogger
213
+ job_type: train_finetune
214
+ project: crystal-generation
215
+ settings:
216
+ _save_requirements: false
217
+ _target_: wandb.Settings
218
+ start_method: fork
219
+ max_epochs: 200
220
+ num_nodes: 1
221
+ precision: 32
222
+ strategy:
223
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
224
+ find_unused_parameters: true
checkpoints/dft_mag_density_hhi_score/config.yaml ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adapter:
2
+ adapter:
3
+ _target_: mattergen.adapter.GemNetTAdapter
4
+ atom_type_diffusion: mask
5
+ denoise_atom_types: true
6
+ gemnet:
7
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
8
+ atom_embedding:
9
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
10
+ emb_size: 512
11
+ with_mask_type: true
12
+ condition_on_adapt:
13
+ - dft_mag_density
14
+ - hhi_score
15
+ cutoff: 7.0
16
+ emb_size_atom: 512
17
+ emb_size_edge: 512
18
+ latent_dim: 512
19
+ max_cell_images_per_dim: 5
20
+ max_neighbors: 50
21
+ num_blocks: 4
22
+ num_targets: 1
23
+ otf_graph: true
24
+ regress_stress: true
25
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
26
+ hidden_dim: 512
27
+ property_embeddings: {}
28
+ property_embeddings_adapt:
29
+ dft_mag_density:
30
+ _target_: mattergen.property_embeddings.PropertyEmbedding
31
+ conditional_embedding_module:
32
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
33
+ d_model: 512
34
+ name: dft_mag_density
35
+ scaler:
36
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
37
+ unconditional_embedding_module:
38
+ _target_: mattergen.property_embeddings.EmbeddingVector
39
+ hidden_dim: 512
40
+ hhi_score:
41
+ _target_: mattergen.property_embeddings.PropertyEmbedding
42
+ conditional_embedding_module:
43
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
44
+ d_model: 512
45
+ name: hhi_score
46
+ scaler:
47
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
48
+ unconditional_embedding_module:
49
+ _target_: mattergen.property_embeddings.EmbeddingVector
50
+ hidden_dim: 512
51
+ full_finetuning: true
52
+ load_epoch: last
53
+ model_path: checkpoints/mattergen_base
54
+ data_module:
55
+ _recursive_: true
56
+ _target_: mattergen.common.data.datamodule.CrystDataModule
57
+ average_density: 0.05771451654022283
58
+ batch_size:
59
+ train: 64
60
+ val: 64
61
+ dataset_transforms:
62
+ - _partial_: true
63
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
64
+ max_epochs: 2200
65
+ num_workers:
66
+ train: 0
67
+ val: 0
68
+ properties:
69
+ - dft_mag_density
70
+ - hhi_score
71
+ root_dir: datasets/cache/alex_mp_20
72
+ train_dataset:
73
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
74
+ cache_path: datasets/cache/alex_mp_20/train
75
+ dataset_transforms:
76
+ - _partial_: true
77
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
78
+ properties:
79
+ - dft_mag_density
80
+ - hhi_score
81
+ transforms:
82
+ - _partial_: true
83
+ _target_: mattergen.common.data.transform.symmetrize_lattice
84
+ - _partial_: true
85
+ _target_: mattergen.common.data.transform.set_chemical_system_string
86
+ transforms:
87
+ - _partial_: true
88
+ _target_: mattergen.common.data.transform.symmetrize_lattice
89
+ - _partial_: true
90
+ _target_: mattergen.common.data.transform.set_chemical_system_string
91
+ val_dataset:
92
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
93
+ cache_path: datasets/cache/alex_mp_20/val
94
+ dataset_transforms:
95
+ - _partial_: true
96
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
97
+ properties:
98
+ - dft_mag_density
99
+ - hhi_score
100
+ transforms:
101
+ - _partial_: true
102
+ _target_: mattergen.common.data.transform.symmetrize_lattice
103
+ - _partial_: true
104
+ _target_: mattergen.common.data.transform.set_chemical_system_string
105
+ lightning_module:
106
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
107
+ diffusion_module:
108
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
109
+ corruption:
110
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
111
+ discrete_corruptions:
112
+ atomic_numbers:
113
+ _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
114
+ d3pm:
115
+ _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
116
+ dim: 101
117
+ schedule:
118
+ _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
119
+ kind: standard
120
+ num_steps: 1000
121
+ offset: 1
122
+ sdes:
123
+ cell:
124
+ _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
125
+ vpsde_config:
126
+ beta_max: 20
127
+ beta_min: 0.1
128
+ limit_density: 0.05771451654022283
129
+ limit_var_scaling_constant: 0.25
130
+ pos:
131
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
132
+ limit_info_key: num_atoms
133
+ sigma_max: 5.0
134
+ wrapping_boundary: 1.0
135
+ loss_fn:
136
+ _target_: mattergen.common.loss.MaterialsLoss
137
+ d3pm_hybrid_lambda: 0.01
138
+ include_atomic_numbers: true
139
+ include_cell: true
140
+ include_pos: true
141
+ reduce: sum
142
+ weights:
143
+ atomic_numbers: 1.0
144
+ cell: 1.0
145
+ pos: 0.1
146
+ model:
147
+ _target_: mattergen.adapter.GemNetTAdapter
148
+ atom_type_diffusion: mask
149
+ denoise_atom_types: true
150
+ gemnet:
151
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
152
+ atom_embedding:
153
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
154
+ emb_size: 512
155
+ with_mask_type: true
156
+ condition_on_adapt:
157
+ - dft_mag_density
158
+ - hhi_score
159
+ cutoff: 7.0
160
+ emb_size_atom: 512
161
+ emb_size_edge: 512
162
+ latent_dim: 512
163
+ max_cell_images_per_dim: 5
164
+ max_neighbors: 50
165
+ num_blocks: 4
166
+ num_targets: 1
167
+ otf_graph: true
168
+ regress_stress: true
169
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
170
+ hidden_dim: 512
171
+ property_embeddings: {}
172
+ property_embeddings_adapt:
173
+ dft_mag_density:
174
+ _target_: mattergen.property_embeddings.PropertyEmbedding
175
+ conditional_embedding_module:
176
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
177
+ d_model: 512
178
+ name: dft_mag_density
179
+ scaler:
180
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
181
+ unconditional_embedding_module:
182
+ _target_: mattergen.property_embeddings.EmbeddingVector
183
+ hidden_dim: 512
184
+ hhi_score:
185
+ _target_: mattergen.property_embeddings.PropertyEmbedding
186
+ conditional_embedding_module:
187
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
188
+ d_model: 512
189
+ name: hhi_score
190
+ scaler:
191
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
192
+ unconditional_embedding_module:
193
+ _target_: mattergen.property_embeddings.EmbeddingVector
194
+ hidden_dim: 512
195
+ pre_corruption_fn:
196
+ _target_: mattergen.property_embeddings.SetEmbeddingType
197
+ dropout_fields_iid: false
198
+ p_unconditional: 0.2
199
+ optimizer_partial:
200
+ _partial_: true
201
+ _target_: torch.optim.Adam
202
+ lr: 5.0e-06
203
+ scheduler_partials:
204
+ - frequency: 1
205
+ interval: epoch
206
+ monitor: loss_train
207
+ scheduler:
208
+ _partial_: true
209
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
210
+ factor: 0.6
211
+ min_lr: 1.0e-06
212
+ patience: 100
213
+ verbose: true
214
+ strict: true
215
+ trainer:
216
+ _target_: pytorch_lightning.Trainer
217
+ accelerator: gpu
218
+ accumulate_grad_batches: 1
219
+ callbacks:
220
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
221
+ log_momentum: false
222
+ logging_interval: step
223
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
224
+ every_n_epochs: 1
225
+ filename: '{epoch}-{loss_val:.2f}'
226
+ mode: min
227
+ monitor: loss_val
228
+ save_last: true
229
+ save_top_k: 1
230
+ verbose: false
231
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
232
+ refresh_rate: 50
233
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
234
+ check_val_every_n_epoch: 5
235
+ devices: 8
236
+ gradient_clip_algorithm: value
237
+ gradient_clip_val: 0.5
238
+ logger:
239
+ _target_: pytorch_lightning.loggers.WandbLogger
240
+ job_type: train_finetune
241
+ project: crystal-generation
242
+ settings:
243
+ _save_requirements: false
244
+ _target_: wandb.Settings
245
+ start_method: fork
246
+ max_epochs: 200
247
+ num_nodes: 1
248
+ precision: 32
249
+ strategy:
250
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
251
+ find_unused_parameters: true
checkpoints/mattergen_base/config.yaml ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ auto_resume: true
2
+ checkpoint_path: null
3
+ data_module:
4
+ _recursive_: true
5
+ _target_: mattergen.common.data.datamodule.CrystDataModule
6
+ average_density: 0.05771451654022283
7
+ batch_size:
8
+ train: 32
9
+ val: 32
10
+ max_epochs: 2200
11
+ num_workers:
12
+ train: 0
13
+ val: 0
14
+ properties:
15
+ - dft_bulk_modulus
16
+ - dft_band_gap
17
+ - dft_mag_density
18
+ - ml_bulk_modulus
19
+ - hhi_score
20
+ - space_group
21
+ - energy_above_hull
22
+ root_dir: datasets/cache/alex_mp_20/
23
+ train_dataset:
24
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
25
+ cache_path: datasets/cache/alex_mp_20/train
26
+ properties:
27
+ - dft_bulk_modulus
28
+ - dft_band_gap
29
+ - dft_mag_density
30
+ - ml_bulk_modulus
31
+ - hhi_score
32
+ - space_group
33
+ - energy_above_hull
34
+ transforms:
35
+ - _partial_: true
36
+ _target_: mattergen.common.data.transform.symmetrize_lattice
37
+ - _partial_: true
38
+ _target_: mattergen.common.data.transform.set_chemical_system_string
39
+ transforms:
40
+ - _partial_: true
41
+ _target_: mattergen.common.data.transform.symmetrize_lattice
42
+ - _partial_: true
43
+ _target_: mattergen.common.data.transform.set_chemical_system_string
44
+ val_dataset:
45
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
46
+ cache_path: datasets/cache/alex_mp_20/val
47
+ properties:
48
+ - dft_bulk_modulus
49
+ - dft_band_gap
50
+ - dft_mag_density
51
+ - ml_bulk_modulus
52
+ - hhi_score
53
+ - space_group
54
+ - energy_above_hull
55
+ transforms:
56
+ - _partial_: true
57
+ _target_: mattergen.common.data.transform.symmetrize_lattice
58
+ - _partial_: true
59
+ _target_: mattergen.common.data.transform.set_chemical_system_string
60
+ lightning_module:
61
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
62
+ diffusion_module:
63
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
64
+ corruption:
65
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
66
+ discrete_corruptions:
67
+ atomic_numbers:
68
+ _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
69
+ d3pm:
70
+ _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
71
+ dim: 101
72
+ schedule:
73
+ _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
74
+ kind: standard
75
+ num_steps: 1000
76
+ offset: 1
77
+ sdes:
78
+ cell:
79
+ _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
80
+ vpsde_config:
81
+ beta_max: 20
82
+ beta_min: 0.1
83
+ limit_density: 0.05771451654022283
84
+ limit_var_scaling_constant: 0.25
85
+ pos:
86
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
87
+ limit_info_key: num_atoms
88
+ sigma_max: 5.0
89
+ wrapping_boundary: 1.0
90
+ loss_fn:
91
+ _target_: mattergen.common.loss.MaterialsLoss
92
+ d3pm_hybrid_lambda: 0.01
93
+ include_atomic_numbers: true
94
+ include_cell: true
95
+ include_pos: true
96
+ reduce: sum
97
+ weights:
98
+ atomic_numbers: 1.0
99
+ cell: 1.0
100
+ pos: 0.1
101
+ model:
102
+ _target_: mattergen.denoiser.GemNetTDenoiser
103
+ atom_type_diffusion: mask
104
+ denoise_atom_types: true
105
+ gemnet:
106
+ _target_: mattergen.common.gemnet.gemnet.GemNetT
107
+ atom_embedding:
108
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
109
+ emb_size: 512
110
+ with_mask_type: true
111
+ cutoff: 7.0
112
+ emb_size_atom: 512
113
+ emb_size_edge: 512
114
+ latent_dim: 512
115
+ max_cell_images_per_dim: 5
116
+ max_neighbors: 50
117
+ num_blocks: 4
118
+ num_targets: 1
119
+ otf_graph: true
120
+ regress_stress: true
121
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
122
+ hidden_dim: 512
123
+ property_embeddings: {}
124
+ property_embeddings_adapt: {}
125
+ pre_corruption_fn:
126
+ _target_: mattergen.property_embeddings.SetEmbeddingType
127
+ dropout_fields_iid: false
128
+ p_unconditional: 0.2
129
+ optimizer_partial:
130
+ _partial_: true
131
+ _target_: torch.optim.Adam
132
+ lr: 0.0001
133
+ scheduler_partials:
134
+ - frequency: 1
135
+ interval: epoch
136
+ monitor: loss_train
137
+ scheduler:
138
+ _partial_: true
139
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
140
+ factor: 0.6
141
+ min_lr: 1.0e-06
142
+ patience: 100
143
+ verbose: true
144
+ strict: true
145
+ load_original: false
146
+ params: {}
147
+ train: true
148
+ trainer:
149
+ _target_: pytorch_lightning.Trainer
150
+ accelerator: gpu
151
+ accumulate_grad_batches: 1
152
+ callbacks:
153
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
154
+ log_momentum: false
155
+ logging_interval: step
156
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
157
+ every_n_epochs: 1
158
+ filename: '{epoch}-{loss_val:.2f}'
159
+ mode: min
160
+ monitor: loss_val
161
+ save_last: true
162
+ save_top_k: 1
163
+ verbose: false
164
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
165
+ refresh_rate: 50
166
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
167
+ check_val_every_n_epoch: 5
168
+ devices: 8
169
+ gradient_clip_algorithm: value
170
+ gradient_clip_val: 0.5
171
+ logger:
172
+ _target_: pytorch_lightning.loggers.WandbLogger
173
+ job_type: train
174
+ project: crystal-generation
175
+ settings:
176
+ _save_requirements: false
177
+ _target_: wandb.Settings
178
+ start_method: fork
179
+ max_epochs: 2200
180
+ num_nodes: 2
181
+ precision: 32
182
+ strategy:
183
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
184
+ find_unused_parameters: true
checkpoints/ml_bulk_modulus/config.yaml ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adapter:
2
+ adapter:
3
+ _target_: mattergen.adapter.GemNetTAdapter
4
+ atom_type_diffusion: mask
5
+ denoise_atom_types: true
6
+ gemnet:
7
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
8
+ atom_embedding:
9
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
10
+ emb_size: 512
11
+ with_mask_type: true
12
+ condition_on_adapt:
13
+ - ml_bulk_modulus
14
+ cutoff: 7.0
15
+ emb_size_atom: 512
16
+ emb_size_edge: 512
17
+ latent_dim: 512
18
+ max_cell_images_per_dim: 5
19
+ max_neighbors: 50
20
+ num_blocks: 4
21
+ num_targets: 1
22
+ otf_graph: true
23
+ regress_stress: true
24
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
25
+ hidden_dim: 512
26
+ property_embeddings: {}
27
+ property_embeddings_adapt:
28
+ ml_bulk_modulus:
29
+ _target_: mattergen.property_embeddings.PropertyEmbedding
30
+ conditional_embedding_module:
31
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
32
+ d_model: 512
33
+ name: ml_bulk_modulus
34
+ scaler:
35
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
36
+ unconditional_embedding_module:
37
+ _target_: mattergen.property_embeddings.EmbeddingVector
38
+ hidden_dim: 512
39
+ full_finetuning: true
40
+ load_epoch: last
41
+ model_path: checkpoints/mattergen_base
42
+ data_module:
43
+ _recursive_: true
44
+ _target_: mattergen.common.data.datamodule.CrystDataModule
45
+ average_density: 0.05771451654022283
46
+ batch_size:
47
+ train: 64
48
+ val: 64
49
+ dataset_transforms:
50
+ - _partial_: true
51
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
52
+ max_epochs: 2200
53
+ num_workers:
54
+ train: 0
55
+ val: 0
56
+ properties:
57
+ - ml_bulk_modulus
58
+ root_dir: datasets/cache/alex_mp_20
59
+ train_dataset:
60
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
61
+ cache_path: datasets/cache/alex_mp_20/train
62
+ dataset_transforms:
63
+ - _partial_: true
64
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
65
+ properties:
66
+ - ml_bulk_modulus
67
+ transforms:
68
+ - _partial_: true
69
+ _target_: mattergen.common.data.transform.symmetrize_lattice
70
+ - _partial_: true
71
+ _target_: mattergen.common.data.transform.set_chemical_system_string
72
+ transforms:
73
+ - _partial_: true
74
+ _target_: mattergen.common.data.transform.symmetrize_lattice
75
+ - _partial_: true
76
+ _target_: mattergen.common.data.transform.set_chemical_system_string
77
+ val_dataset:
78
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
79
+ cache_path: datasets/cache/alex_mp_20/val
80
+ dataset_transforms:
81
+ - _partial_: true
82
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
83
+ properties:
84
+ - ml_bulk_modulus
85
+ transforms:
86
+ - _partial_: true
87
+ _target_: mattergen.common.data.transform.symmetrize_lattice
88
+ - _partial_: true
89
+ _target_: mattergen.common.data.transform.set_chemical_system_string
90
+ lightning_module:
91
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
92
+ diffusion_module:
93
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
94
+ corruption:
95
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
96
+ discrete_corruptions:
97
+ atomic_numbers:
98
+ _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
99
+ d3pm:
100
+ _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
101
+ dim: 101
102
+ schedule:
103
+ _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
104
+ kind: standard
105
+ num_steps: 1000
106
+ offset: 1
107
+ sdes:
108
+ cell:
109
+ _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
110
+ vpsde_config:
111
+ beta_max: 20
112
+ beta_min: 0.1
113
+ limit_density: 0.05771451654022283
114
+ limit_var_scaling_constant: 0.25
115
+ pos:
116
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
117
+ limit_info_key: num_atoms
118
+ sigma_max: 5.0
119
+ wrapping_boundary: 1.0
120
+ loss_fn:
121
+ _target_: mattergen.common.loss.MaterialsLoss
122
+ d3pm_hybrid_lambda: 0.01
123
+ include_atomic_numbers: true
124
+ include_cell: true
125
+ include_pos: true
126
+ reduce: sum
127
+ weights:
128
+ atomic_numbers: 1.0
129
+ cell: 1.0
130
+ pos: 0.1
131
+ model:
132
+ _target_: mattergen.adapter.GemNetTAdapter
133
+ atom_type_diffusion: mask
134
+ denoise_atom_types: true
135
+ gemnet:
136
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
137
+ atom_embedding:
138
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
139
+ emb_size: 512
140
+ with_mask_type: true
141
+ condition_on_adapt:
142
+ - ml_bulk_modulus
143
+ cutoff: 7.0
144
+ emb_size_atom: 512
145
+ emb_size_edge: 512
146
+ latent_dim: 512
147
+ max_cell_images_per_dim: 5
148
+ max_neighbors: 50
149
+ num_blocks: 4
150
+ num_targets: 1
151
+ otf_graph: true
152
+ regress_stress: true
153
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
154
+ hidden_dim: 512
155
+ property_embeddings: {}
156
+ property_embeddings_adapt:
157
+ ml_bulk_modulus:
158
+ _target_: mattergen.property_embeddings.PropertyEmbedding
159
+ conditional_embedding_module:
160
+ _target_: mattergen.diffusion.model_utils.NoiseLevelEncoding
161
+ d_model: 512
162
+ name: ml_bulk_modulus
163
+ scaler:
164
+ _target_: mattergen.common.utils.data_utils.StandardScalerTorch
165
+ unconditional_embedding_module:
166
+ _target_: mattergen.property_embeddings.EmbeddingVector
167
+ hidden_dim: 512
168
+ pre_corruption_fn:
169
+ _target_: mattergen.property_embeddings.SetEmbeddingType
170
+ dropout_fields_iid: false
171
+ p_unconditional: 0.2
172
+ optimizer_partial:
173
+ _partial_: true
174
+ _target_: torch.optim.Adam
175
+ lr: 5.0e-06
176
+ scheduler_partials:
177
+ - frequency: 1
178
+ interval: epoch
179
+ monitor: loss_train
180
+ scheduler:
181
+ _partial_: true
182
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
183
+ factor: 0.6
184
+ min_lr: 1.0e-06
185
+ patience: 100
186
+ verbose: true
187
+ strict: true
188
+ trainer:
189
+ _target_: pytorch_lightning.Trainer
190
+ accelerator: gpu
191
+ accumulate_grad_batches: 1
192
+ callbacks:
193
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
194
+ log_momentum: false
195
+ logging_interval: step
196
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
197
+ every_n_epochs: 1
198
+ filename: '{epoch}-{loss_val:.2f}'
199
+ mode: min
200
+ monitor: loss_val
201
+ save_last: true
202
+ save_top_k: 1
203
+ verbose: false
204
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
205
+ refresh_rate: 50
206
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
207
+ check_val_every_n_epoch: 1
208
+ devices: 8
209
+ gradient_clip_algorithm: value
210
+ gradient_clip_val: 0.5
211
+ logger:
212
+ _target_: pytorch_lightning.loggers.WandbLogger
213
+ job_type: train_finetune
214
+ project: crystal-generation
215
+ settings:
216
+ _save_requirements: false
217
+ _target_: wandb.Settings
218
+ start_method: fork
219
+ max_epochs: 200
220
+ num_nodes: 1
221
+ precision: 32
222
+ strategy:
223
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
224
+ find_unused_parameters: true
checkpoints/space_group/config.yaml ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adapter:
2
+ adapter:
3
+ _target_: mattergen.adapter.GemNetTAdapter
4
+ atom_type_diffusion: mask
5
+ denoise_atom_types: true
6
+ gemnet:
7
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
8
+ atom_embedding:
9
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
10
+ emb_size: 512
11
+ with_mask_type: true
12
+ condition_on_adapt:
13
+ - space_group
14
+ cutoff: 7.0
15
+ emb_size_atom: 512
16
+ emb_size_edge: 512
17
+ latent_dim: 512
18
+ max_cell_images_per_dim: 5
19
+ max_neighbors: 50
20
+ num_blocks: 4
21
+ num_targets: 1
22
+ otf_graph: true
23
+ regress_stress: true
24
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
25
+ hidden_dim: 512
26
+ property_embeddings: {}
27
+ property_embeddings_adapt:
28
+ space_group:
29
+ _target_: mattergen.property_embeddings.PropertyEmbedding
30
+ conditional_embedding_module:
31
+ _target_: mattergen.property_embeddings.SpaceGroupEmbeddingVector
32
+ hidden_dim: 512
33
+ name: space_group
34
+ scaler:
35
+ _target_: torch.nn.Identity
36
+ unconditional_embedding_module:
37
+ _target_: mattergen.property_embeddings.EmbeddingVector
38
+ hidden_dim: 512
39
+ full_finetuning: true
40
+ load_epoch: last
41
+ model_path: checkpoints/mattergen_base
42
+ data_module:
43
+ _recursive_: true
44
+ _target_: mattergen.common.data.datamodule.CrystDataModule
45
+ average_density: 0.05771451654022283
46
+ batch_size:
47
+ train: 64
48
+ val: 64
49
+ dataset_transforms:
50
+ - _partial_: true
51
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
52
+ max_epochs: 2200
53
+ num_workers:
54
+ train: 0
55
+ val: 0
56
+ properties:
57
+ - space_group
58
+ root_dir: datasets/cache/alex_mp_20
59
+ train_dataset:
60
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
61
+ cache_path: datasets/cache/alex_mp_20/train
62
+ dataset_transforms:
63
+ - _partial_: true
64
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
65
+ properties:
66
+ - space_group
67
+ transforms:
68
+ - _partial_: true
69
+ _target_: mattergen.common.data.transform.symmetrize_lattice
70
+ - _partial_: true
71
+ _target_: mattergen.common.data.transform.set_chemical_system_string
72
+ transforms:
73
+ - _partial_: true
74
+ _target_: mattergen.common.data.transform.symmetrize_lattice
75
+ - _partial_: true
76
+ _target_: mattergen.common.data.transform.set_chemical_system_string
77
+ val_dataset:
78
+ _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
79
+ cache_path: datasets/cache/alex_mp_20/val
80
+ dataset_transforms:
81
+ - _partial_: true
82
+ _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
83
+ properties:
84
+ - space_group
85
+ transforms:
86
+ - _partial_: true
87
+ _target_: mattergen.common.data.transform.symmetrize_lattice
88
+ - _partial_: true
89
+ _target_: mattergen.common.data.transform.set_chemical_system_string
90
+ lightning_module:
91
+ _target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
92
+ diffusion_module:
93
+ _target_: mattergen.diffusion.diffusion_module.DiffusionModule
94
+ corruption:
95
+ _target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
96
+ discrete_corruptions:
97
+ atomic_numbers:
98
+ _target_: mattergen.diffusion.corruption.d3pm_corruption.D3PMCorruption
99
+ d3pm:
100
+ _target_: mattergen.diffusion.d3pm.d3pm.MaskDiffusion
101
+ dim: 101
102
+ schedule:
103
+ _target_: mattergen.diffusion.d3pm.d3pm.create_discrete_diffusion_schedule
104
+ kind: standard
105
+ num_steps: 1000
106
+ offset: 1
107
+ sdes:
108
+ cell:
109
+ _target_: mattergen.common.diffusion.corruption.LatticeVPSDE.from_vpsde_config
110
+ vpsde_config:
111
+ beta_max: 20
112
+ beta_min: 0.1
113
+ limit_density: 0.05771451654022283
114
+ limit_var_scaling_constant: 0.25
115
+ pos:
116
+ _target_: mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE
117
+ limit_info_key: num_atoms
118
+ sigma_max: 5.0
119
+ wrapping_boundary: 1.0
120
+ loss_fn:
121
+ _target_: mattergen.common.loss.MaterialsLoss
122
+ d3pm_hybrid_lambda: 0.01
123
+ include_atomic_numbers: true
124
+ include_cell: true
125
+ include_pos: true
126
+ reduce: sum
127
+ weights:
128
+ atomic_numbers: 1.0
129
+ cell: 1.0
130
+ pos: 0.1
131
+ model:
132
+ _target_: mattergen.adapter.GemNetTAdapter
133
+ atom_type_diffusion: mask
134
+ denoise_atom_types: true
135
+ gemnet:
136
+ _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
137
+ atom_embedding:
138
+ _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
139
+ emb_size: 512
140
+ with_mask_type: true
141
+ condition_on_adapt:
142
+ - space_group
143
+ cutoff: 7.0
144
+ emb_size_atom: 512
145
+ emb_size_edge: 512
146
+ latent_dim: 512
147
+ max_cell_images_per_dim: 5
148
+ max_neighbors: 50
149
+ num_blocks: 4
150
+ num_targets: 1
151
+ otf_graph: true
152
+ regress_stress: true
153
+ scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
154
+ hidden_dim: 512
155
+ property_embeddings: {}
156
+ property_embeddings_adapt:
157
+ space_group:
158
+ _target_: mattergen.property_embeddings.PropertyEmbedding
159
+ conditional_embedding_module:
160
+ _target_: mattergen.property_embeddings.SpaceGroupEmbeddingVector
161
+ hidden_dim: 512
162
+ name: space_group
163
+ scaler:
164
+ _target_: torch.nn.Identity
165
+ unconditional_embedding_module:
166
+ _target_: mattergen.property_embeddings.EmbeddingVector
167
+ hidden_dim: 512
168
+ pre_corruption_fn:
169
+ _target_: mattergen.property_embeddings.SetEmbeddingType
170
+ dropout_fields_iid: false
171
+ p_unconditional: 0.2
172
+ optimizer_partial:
173
+ _partial_: true
174
+ _target_: torch.optim.Adam
175
+ lr: 5.0e-06
176
+ scheduler_partials:
177
+ - frequency: 1
178
+ interval: epoch
179
+ monitor: loss_train
180
+ scheduler:
181
+ _partial_: true
182
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
183
+ factor: 0.6
184
+ min_lr: 1.0e-06
185
+ patience: 100
186
+ verbose: true
187
+ strict: true
188
+ trainer:
189
+ _target_: pytorch_lightning.Trainer
190
+ accelerator: gpu
191
+ accumulate_grad_batches: 1
192
+ callbacks:
193
+ - _target_: pytorch_lightning.callbacks.LearningRateMonitor
194
+ log_momentum: false
195
+ logging_interval: step
196
+ - _target_: pytorch_lightning.callbacks.ModelCheckpoint
197
+ every_n_epochs: 1
198
+ filename: '{epoch}-{loss_val:.2f}'
199
+ mode: min
200
+ monitor: loss_val
201
+ save_last: true
202
+ save_top_k: 1
203
+ verbose: false
204
+ - _target_: pytorch_lightning.callbacks.TQDMProgressBar
205
+ refresh_rate: 50
206
+ - _target_: mattergen.common.data.callback.SetPropertyScalers
207
+ check_val_every_n_epoch: 1
208
+ devices: 8
209
+ gradient_clip_algorithm: value
210
+ gradient_clip_val: 0.5
211
+ logger:
212
+ _target_: pytorch_lightning.loggers.WandbLogger
213
+ job_type: train_finetune
214
+ project: crystal-generation
215
+ settings:
216
+ _save_requirements: false
217
+ _target_: wandb.Settings
218
+ start_method: fork
219
+ max_epochs: 200
220
+ num_nodes: 1
221
+ precision: 32
222
+ strategy:
223
+ _target_: pytorch_lightning.strategies.ddp.DDPStrategy
224
+ find_unused_parameters: true