Update README.md
Browse files
README.md
CHANGED
@@ -40,7 +40,21 @@ model_output = model(**encoded_input)
|
|
40 |
|
41 |
Sequence embeddings can be produced as follows:
|
42 |
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
### Fine-tune
|
46 |
|
|
|
40 |
|
41 |
Sequence embeddings can be produced as follows:
|
42 |
|
43 |
```python
def sequence_embeddings(encoded_input, model_output):
    """Mean-pool token embeddings into one embedding per sequence.

    Padding, the CLS token (position 0), and each row's final SEP token
    are excluded from the average.

    Args:
        encoded_input: tokenizer output; its ``'attention_mask'`` entry is
            read as a (batch, seq_len) tensor of 0/1 values.
        model_output: model output whose ``last_hidden_state`` is a
            (batch, seq_len, hidden) tensor.

    Returns:
        A (batch, hidden) tensor of mean-pooled sequence embeddings.
    """
    # clone() so the caller's attention_mask is never mutated in place:
    # Tensor.float() returns the *same* tensor when it is already float32,
    # and the zeroing below is an in-place write.
    mask = encoded_input['attention_mask'].float().clone()
    # torch.nonzero lists (row, col) pairs in row-major order, so dict
    # insertion makes the LAST non-zero column win per row — i.e. each
    # batch index maps to the position of its SEP token.
    sep_positions = {int(row): int(col)
                     for row, col in torch.nonzero(mask).cpu().numpy()}
    # make sep token invisible
    for row, col in sep_positions.items():
        mask[row, col] = 0
    mask[:, 0] = 0.0  # make cls token invisible
    mask = mask.unsqueeze(-1).expand(model_output.last_hidden_state.size())
    sum_embeddings = torch.sum(model_output.last_hidden_state * mask, 1)
    # clamp avoids division by zero for rows with no countable tokens
    sum_mask = torch.clamp(mask.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

seq_embeds = sequence_embeddings(encoded_input, model_output)
```
|
58 |
|
59 |
### Fine-tune
|
60 |
|