Jacob Gershon commited on
Commit
29f8f7e
·
1 Parent(s): 59a9ccf

fixed but in partial seq diffusion

Browse files
Files changed (1) hide show
  1. utils/sampler.py +7 -1
utils/sampler.py CHANGED
@@ -291,7 +291,13 @@ class SEQDIFF_sampler:
291
  self.features['xyz_t'] = torch.full((1,1,len(self.args['sequence']),27,3), np.nan)
292
 
293
  self.features['mask_str'] = torch.zeros(len(self.args['sequence'])).long()[None,:].bool()
294
- self.features['mask_seq'] = torch.tensor([0 if x == 'X' else 1 for x in self.args['sequence']]).long()[None,:].bool()
 
 
 
 
 
 
295
  self.features['blank_mask'] = torch.ones(self.features['mask_str'].size()[-1])[None,:].bool()
296
 
297
  self.features['idx_pdb'] = torch.tensor([i for i in range(len(self.args['sequence']))])[None,:]
 
291
  self.features['xyz_t'] = torch.full((1,1,len(self.args['sequence']),27,3), np.nan)
292
 
293
  self.features['mask_str'] = torch.zeros(len(self.args['sequence'])).long()[None,:].bool()
294
+
295
+ #added check for if in partial diffusion mode will mask
296
+ if self.args['sampling_temp'] == 1.0:
297
+ self.features['mask_seq'] = torch.tensor([0 if x == 'X' else 1 for x in self.args['sequence']]).long()[None,:].bool()
298
+ else:
299
+ self.features['mask_seq'] = torch.zeros(len(self.args['sequence'])).long()[None,:].bool()
300
+
301
  self.features['blank_mask'] = torch.ones(self.features['mask_str'].size()[-1])[None,:].bool()
302
 
303
  self.features['idx_pdb'] = torch.tensor([i for i in range(len(self.args['sequence']))])[None,:]