ravfogs commited on
Commit
81883d8
·
1 Parent(s): 753c881

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +37 -0
README.md CHANGED
@@ -18,6 +18,7 @@ A model for mapping abstract sentence descriptions to sentences that fit the des
18
  from transformers import AutoTokenizer, AutoModel
19
  import torch
20
  from typing import List
 
21
 
22
  def load_finetuned_model():
23
 
@@ -36,4 +37,40 @@ def encode_batch(model, tokenizer, sentences: List[str], device: str):
36
  features = torch.sum(features[:,1:,:] * input_ids["attention_mask"][:,1:].unsqueeze(-1), dim=1) / torch.clamp(torch.sum(input_ids["attention_mask"][:,1:], dim=1, keepdims=True), min=1e-9)
37
  return features
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  ```
 
18
  from transformers import AutoTokenizer, AutoModel
19
  import torch
20
  from typing import List
21
+ from sklearn.metrics.pairwise import cosine_similarity
22
 
23
  def load_finetuned_model():
24
 
 
37
  features = torch.sum(features[:,1:,:] * input_ids["attention_mask"][:,1:].unsqueeze(-1), dim=1) / torch.clamp(torch.sum(input_ids["attention_mask"][:,1:], dim=1, keepdims=True), min=1e-9)
38
  return features
39
 
40
+
41
+
42
+ if __name__ == "__main__":
43
+
44
+ tokenizer, query_encoder, sentence_encoder = load_finetuned_model()
45
+ relevant_sentences = ["Fingersoft's parent company is the Finger Group.",
46
+ "WHIRC – a subsidiary company of Wright-Hennepin",
47
+ "CK Life Sciences International (Holdings) Inc. (), or CK Life Sciences, is a subsidiary of CK Hutchison Holdings",
48
+ "EM Microelectronic-Marin (subsidiary of The Swatch Group).",
49
+ "The company is currently a division of the corporate group Jam Industries.",
50
+ "Volt Technical Resources is a business unit of Volt Workforce Solutions, a subsidiary of Volt Information Sciences (currently trading over-the-counter as VISI.)."
51
+ ]
52
+
53
+ irrelevant_sentences = ["The second company is deemed to be a subsidiary of the parent company.",
54
+ "The company has gone through more than one incarnation.",
55
+ "The company is owned by its employees.",
56
+ "Larger companies compete for market share by acquiring smaller companies that may own a particular market sector.",
57
+ "A parent company is a company that owns 51% or more voting stock in another firm (or subsidiary).",
58
+ "It is a holding company that provides services through its subsidiaries in the following areas: oil and gas, industrial and infrastructure, government and power."
59
+ ]
60
+
61
+ all_sentences = relevant_sentences + irrelevant_sentences
62
+ query = "<query>: A company is a part of a larger company."
63
+
64
+ embeddings = encode_batch(sentence_encoder, tokenizer, all_sentences, "cpu").detach().cpu().numpy()
65
+ query_embedding = encode_batch(query_encoder, tokenizer, [query], "cpu").detach().cpu().numpy()
66
+
67
+ sims = cosine_similarity(query_embedding, embeddings)[0]
68
+ sentences_sims = list(zip(all_sentences, sims))
69
+ sentences_sims.sort(key=lambda x: x[1], reverse=True)
70
+
71
+ for s, sim in sentences_sims:
72
+ print(s, sim)
73
+
74
+
75
+
76
  ```