---
language:
- en
tags:
- feature-extraction
- sentence-similarity
datasets:
- biu-nlp/abstract-sim
widgets:
- sentence-similarity
- feature-extraction
---

A model for mapping abstract sentence descriptions to sentences that fit those descriptions. Trained on Wikipedia. Use `load_finetuned_model` to load the query and sentence encoders, and `encode_batch()` to encode sentences with either model.

**Note**: the method uses a dual-encoder architecture. This is the **sentence encoder**; it should be used alongside the [**query encoder**](https://huggingface.co/biu-nlp/abstract-sim-query). Queries are expected to carry the `<query>: ` prefix, as in the usage example below.

```python
from typing import List

import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel


def load_finetuned_model():
    # Load the two halves of the dual encoder; both share the same tokenizer.
    sentence_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-sentence")
    query_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-query")
    tokenizer = AutoTokenizer.from_pretrained("biu-nlp/abstract-sim-sentence")
    return tokenizer, query_encoder, sentence_encoder


def encode_batch(model, tokenizer, sentences: List[str], device: str):
    inputs = tokenizer(sentences, padding=True, max_length=512, truncation=True,
                       return_tensors="pt", add_special_tokens=True).to(device)
    features = model(**inputs)[0]
    # Mean-pool the token embeddings, skipping the [CLS] position and masking padding.
    mask = inputs["attention_mask"][:, 1:].unsqueeze(-1)
    features = torch.sum(features[:, 1:, :] * mask, dim=1) / torch.clamp(
        torch.sum(inputs["attention_mask"][:, 1:], dim=1, keepdim=True), min=1e-9)
    return features
```
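The snippet above keeps everything on CPU. For GPU inference, note that `encode_batch` only moves the *inputs* to `device`, so the encoders themselves have to be moved first. A minimal sketch, assuming a CUDA device may or may not be available:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer, query_encoder, sentence_encoder = load_finetuned_model()
# Move both encoders to the target device; encode_batch only moves the inputs.
query_encoder = query_encoder.to(device).eval()
sentence_encoder = sentence_encoder.to(device).eval()

with torch.no_grad():  # gradients are not needed for inference
    embeddings = encode_batch(sentence_encoder, tokenizer, ["An example sentence."], device)
```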

Usage example:

```python
tokenizer, query_encoder, sentence_encoder = load_finetuned_model()
relevant_sentences = ["Fingersoft's parent company is the Finger Group.",
                      "WHIRC – a subsidiary company of Wright-Hennepin",
                      "CK Life Sciences International (Holdings) Inc. (), or CK Life Sciences, is a subsidiary of CK Hutchison Holdings",
                      "EM Microelectronic-Marin (subsidiary of The Swatch Group).",
                      "The company is currently a division of the corporate group Jam Industries.",
                      "Volt Technical Resources is a business unit of Volt Workforce Solutions, a subsidiary of Volt Information Sciences (currently trading over-the-counter as VISI.)."
             ]

irrelevant_sentences = ["The second company is deemed to be a subsidiary of the parent company.",
                        "The company has gone through more than one incarnation.",
                        "The company is owned by its employees.",
                        "Larger companies compete for market share by acquiring smaller companies that may own a particular market sector.",
                        "A parent company is a company that owns 51% or more voting stock in another firm (or subsidiary).",
                        "It is a holding company that provides services through its subsidiaries in the following areas: oil and gas, industrial and infrastructure, government and power.",
                        "RXVT Technologies is no longer a subsidiary of the parent company."
                        ]

all_sentences = relevant_sentences + irrelevant_sentences
query = "<query>: A company is a part of a larger company."

embeddings = encode_batch(sentence_encoder, tokenizer, all_sentences, "cpu").detach().cpu().numpy()
query_embedding = encode_batch(query_encoder, tokenizer, [query], "cpu").detach().cpu().numpy()

sims = cosine_similarity(query_embedding, embeddings)[0]
sentences_sims = list(zip(all_sentences, sims))
sentences_sims.sort(key=lambda x: x[1], reverse=True)

for s, sim in sentences_sims:
    print(s, sim)

```

Expected output:

```
WHIRC – a subsidiary company of Wright-Hennepin 0.9396286
EM Microelectronic-Marin (subsidiary of The Swatch Group). 0.93929046
Fingersoft's parent company is the Finger Group. 0.936247
CK Life Sciences International (Holdings) Inc. (), or CK Life Sciences, is a subsidiary of CK Hutchison Holdings 0.9350312
The company is currently a division of the corporate group Jam Industries. 0.9273489
Volt Technical Resources is a business unit of Volt Workforce Solutions, a subsidiary of Volt Information Sciences (currently trading over-the-counter as VISI.). 0.9005086
The second company is deemed to be a subsidiary of the parent company. 0.6723645
It is a holding company that provides services through its subsidiaries in the following areas: oil and gas, industrial and infrastructure, government and power. 0.60081375
A parent company is a company that owns 51% or more voting stock in another firm (or subsidiary). 0.59490484
The company is owned by its employees. 0.55286574
RXVT Technologies is no longer a subsidiary of the parent company. 0.4321953
The company has gone through more than one incarnation. 0.38889483
Larger companies compete for market share by acquiring smaller companies that may own a particular market sector. 0.25472647
```
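
In this example the relevant sentences all score above 0.9 while the irrelevant ones stay below 0.68, so a fixed similarity cutoff separates them cleanly. A small sketch of threshold-based filtering (the 0.8 cutoff is an illustrative assumption, not a value recommended by the model authors):

```python
THRESHOLD = 0.8  # illustrative cutoff; tune on your own data

matches = [s for s, sim in sentences_sims if sim >= THRESHOLD]
for s in matches:
    print(s)
```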