qilowoq committed on
Commit 9edddc7 · 1 Parent(s): 8cd5495

Update README.md

Files changed (1)
  1. README.md +15 -1
README.md CHANGED
@@ -40,7 +40,21 @@ model_output = model(**encoded_input)
 
 Sequence embeddings can be produced as follows:
 
-TBA (just mean pool not including special tokens)
+```python
+def sequence_embeddings(encoded_input, model_output):
+    mask = encoded_input['attention_mask'].float()
+    d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()} # dict of sep tokens
+    # make sep token invisible
+    for i in d:
+        mask[i, d[i]] = 0
+    mask[:, 0] = 0.0 # make cls token invisible
+    mask = mask.unsqueeze(-1).expand(model_output.last_hidden_state.size())
+    sum_embeddings = torch.sum(model_output.last_hidden_state * mask, 1)
+    sum_mask = torch.clamp(mask.sum(1), min=1e-9)
+    return sum_embeddings / sum_mask
+
+seq_embeds = sequence_embeddings(encoded_input, model_output)
+```
 
 ### Fine-tune
 
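
Review note: the added `sequence_embeddings` snippet relies on `torch` being imported earlier in the README. As a sanity check on its masking logic, the following dependency-free sketch reproduces the same pooling on plain Python lists. It assumes right-padded inputs in which the first real token is CLS and the last real token is SEP. The function name and the toy data are illustrative and are not part of the README.

```python
def mean_pool_without_special_tokens(hidden_states, attention_mask):
    """Mean-pool token vectors per sequence, skipping the first real token
    (assumed CLS), the last real token (assumed SEP), and all padding."""
    pooled = []
    for vectors, mask in zip(hidden_states, attention_mask):
        # Positions of real (non-padding) tokens in this sequence.
        real = [i for i, m in enumerate(mask) if m]
        # Drop the first and last real positions (CLS and SEP by assumption).
        kept = real[1:-1]
        dim = len(vectors[0])
        if not kept:
            # No content tokens left: fall back to a zero vector, which is
            # what the clamped division yields in the torch version.
            pooled.append([0.0] * dim)
            continue
        pooled.append(
            [sum(vectors[i][j] for i in kept) / len(kept) for j in range(dim)]
        )
    return pooled


# Toy batch: 2 sequences, 4 positions, hidden size 2 (hypothetical values).
toy_hidden = [
    [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],  # CLS, a, b, SEP
    [[9.0, 9.0], [5.0, 6.0], [9.0, 9.0], [0.0, 0.0]],  # CLS, a, SEP, PAD
]
toy_mask = [
    [1, 1, 1, 1],
    [1, 1, 1, 0],
]

pooled = mean_pool_without_special_tokens(toy_hidden, toy_mask)
print(pooled)  # [[2.0, 3.0], [5.0, 6.0]]
```

With left-padded batches the two versions would diverge, because the README snippet always zeroes column 0 rather than the first real token.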