martirossyan commited on
Commit
ff74679
·
verified ·
1 Parent(s): f65b5b7

Upload 2 files

Browse files
Files changed (2) hide show
  1. VESBD-ODE/checkpoint.ckpt +3 -0
  2. VESBD-ODE/train.yaml +154 -0
VESBD-ODE/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f16700a860c1e88d117a64c1a41396027927a7f45185fbb5eaf2d4ce43317b8e
3
+ size 148095124
VESBD-ODE/train.yaml ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ si:
3
+ class_path: omg.si.stochastic_interpolants.StochasticInterpolants
4
+ init_args:
5
+ stochastic_interpolants:
6
+ # chemical species
7
+ - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
8
+ # fractional coordinates
9
+ - class_path: omg.si.single_stochastic_interpolant_os.SingleStochasticInterpolantOS
10
+ init_args:
11
+ interpolant:
12
+ class_path: omg.si.interpolants.PeriodicScoreBasedDiffusionModelInterpolantVE
13
+ init_args:
14
+ sigma:
15
+ class_path: omg.si.sigma.GeometricSigma
16
+ init_args:
17
+ sigma_min: 0.007753186833706728
18
+ sigma_max: 0.5165059747015202
19
+ epsilon: null
20
+ differential_equation_type: "ODE"
21
+ integrator_kwargs:
22
+ method: "euler"
23
+ velocity_annealing_factor: 0.0030999124784898413
24
+ correct_center_of_mass_motion: true
25
+ predict_velocity: true
26
+ # lattice vectors
27
+ - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
28
+ init_args:
29
+ interpolant: omg.si.interpolants.TrigonometricInterpolant
30
+ gamma:
31
+ class_path: omg.si.gamma.LatentGammaSqrt
32
+ init_args:
33
+ a: 0.024482789522429726
34
+ epsilon:
35
+ class_path: omg.si.epsilon.VanishingEpsilon
36
+ init_args:
37
+ c: 9.940425570212101
38
+ mu: 0.24041599621265147
39
+ sigma: 0.021132860336543085
40
+ differential_equation_type: "SDE"
41
+ integrator_kwargs:
42
+ method: "euler"
43
+ dt: 0.0026332451961934566
44
+ velocity_annealing_factor: 14.933642154361792
45
+ correct_center_of_mass_motion: false
46
+ data_fields:
47
+ # if the order of the data_fields changes,
48
+ # the order of the above StochasticInterpolant inputs must also change
49
+ - "species"
50
+ - "pos"
51
+ - "cell"
52
+ integration_time_steps: 380
53
+ relative_si_costs:
54
+ species_loss: 0.0
55
+ pos_loss_b: 0.979954187812053
56
+ cell_loss_b: 0.01866918394074503
57
+ cell_loss_z: 0.0013766282472020075
58
+ sampler:
59
+ class_path: omg.sampler.sample_from_rng.SampleFromRNG
60
+ init_args:
61
+ pos_distribution:
62
+ class_path: omg.sampler.distributions.NormalDistribution
63
+ init_args:
64
+ scale: 8.955438982782663
65
+ cell_distribution:
66
+ class_path: omg.sampler.distributions.InformedLatticeDistribution
67
+ init_args:
68
+ dataset_name: perov_5
69
+ species_distribution:
70
+ class_path: omg.sampler.distributions.MirrorData
71
+ model:
72
+ class_path: omg.model.model.Model
73
+ init_args:
74
+ encoder:
75
+ class_path: omg.model.encoders.cspnet_full.CSPNetFull
76
+ head:
77
+ class_path: omg.model.heads.pass_through.PassThrough
78
+ time_embedder:
79
+ class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
80
+ init_args:
81
+ dim: 256
82
+ use_min_perm_dist: False
83
+ float_32_matmul_precision: "high"
84
+ validation_mode: "match_rate"
85
+ dataset_name: "perov_5"
86
+ data:
87
+ train_dataset:
88
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
89
+ init_args:
90
+ dataset:
91
+ class_path: omg.datamodule.datamodule.DataModule
92
+ init_args:
93
+ lmdb_paths:
94
+ - "data/perov_5/train.lmdb"
95
+ niggli: False
96
+ val_dataset:
97
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
98
+ init_args:
99
+ dataset:
100
+ class_path: omg.datamodule.datamodule.DataModule
101
+ init_args:
102
+ lmdb_paths:
103
+ - "data/perov_5/val.lmdb"
104
+ niggli: False
105
+ predict_dataset:
106
+ class_path: omg.datamodule.dataloader.OMGTorchDataset
107
+ init_args:
108
+ dataset:
109
+ class_path: omg.datamodule.datamodule.DataModule
110
+ init_args:
111
+ lmdb_paths:
112
+ - "data/perov_5/test.lmdb"
113
+ niggli: False
114
+ batch_size: 256
115
+ num_workers: 4
116
+ pin_memory: True
117
+ persistent_workers: True
118
+ trainer:
119
+ callbacks:
120
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
121
+ init_args:
122
+ filename: "best_val_loss_total"
123
+ save_top_k: 1
124
+ monitor: "val_loss_total"
125
+ save_weights_only: true
126
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
127
+ init_args:
128
+ filename: "best_val_match_rate"
129
+ save_top_k: 1
130
+ monitor: "match_rate"
131
+ save_weights_only: true
132
+ mode: 'max'
133
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
134
+ init_args:
135
+ filename: "best_val_rmsd"
136
+ save_top_k: 1
137
+ monitor: "mean_rmsd"
138
+ save_weights_only: true
139
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
140
+ init_args:
141
+ save_top_k: -1 # Store every checkpoint after 100 epochs.
142
+ monitor: "val_loss_total"
143
+ every_n_epochs: 100
144
+ save_weights_only: false
145
+ gradient_clip_val: 0.5
146
+ num_sanity_val_steps: 0
147
+ precision: "32-true"
148
+ max_epochs: 6000
149
+ enable_progress_bar: false
150
+ check_val_every_n_epoch: 100
151
+ optimizer:
152
+ class_path: torch.optim.Adam
153
+ init_args:
154
+ lr: 0.0077762908469486665