---
language:
- en
tags:
- feature-extraction
- sentence-similarity
datasets:
- biu-nlp/abstract-sim
widgets:
- sentence-similarity
- feature-extraction
---

A model for mapping abstract sentence descriptions to sentences that fit the descriptions. Trained on Wikipedia. Use `load_finetuned_model()` to load the tokenizer together with the query and sentence encoders, and `encode_batch()` to encode a batch of sentences with one of the encoders.

```python
from transformers import AutoTokenizer, AutoModel
import torch
from typing import List
from sklearn.metrics.pairwise import cosine_similarity


def load_finetuned_model():
    """Load the tokenizer and the two fine-tuned encoders.

    Returns:
        A ``(tokenizer, query_encoder, sentence_encoder)`` tuple. The tokenizer
        is loaded from the sentence-encoder checkpoint; the example below uses
        it for both encoders.
    """
    sentence_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-sentence")
    query_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-query")
    tokenizer = AutoTokenizer.from_pretrained("biu-nlp/abstract-sim-sentence")

    return tokenizer, query_encoder, sentence_encoder


def encode_batch(model, tokenizer, sentences: List[str], device: str):
    """Embed ``sentences`` with ``model`` using attention-masked mean pooling.

    Args:
        model: Encoder returned by ``load_finetuned_model()``.
        tokenizer: Tokenizer returned by ``load_finetuned_model()``.
        sentences: Texts to encode; they are padded to a common length and
            truncated to at most 512 tokens.
        device: Torch device string (e.g. ``"cpu"``) on which to run.

    Returns:
        A tensor with one pooled embedding row per input sentence.
    """
    # The inputs are moved to `device` below, so the model has to live there too.
    model.to(device)

    inputs = tokenizer(
        sentences,
        padding=True,
        max_length=512,
        truncation=True,
        return_tensors="pt",
        add_special_tokens=True,
    ).to(device)

    # First model output; presumably the per-token hidden states
    # (batch, tokens, dim) -- TODO confirm against the checkpoint's model class.
    features = model(**inputs)[0]

    # Position 0 is excluded from the average (presumably the tokenizer's
    # leading special token -- TODO confirm).
    token_features = features[:, 1:, :]

    # Cast the integer attention mask to the feature dtype so that the 1e-9
    # floor is applied in floating point instead of relying on how the
    # installed PyTorch promotes integer tensors inside clamp().
    mask = inputs["attention_mask"][:, 1:].unsqueeze(-1).to(token_features.dtype)

    summed = torch.sum(token_features * mask, dim=1)
    # Clamp guards against division by zero when no position past 0 is attended.
    counts = torch.clamp(torch.sum(mask, dim=1), min=1e-9)

    return summed / counts


if __name__ == "__main__":
    tokenizer, query_encoder, sentence_encoder = load_finetuned_model()

    relevant_sentences = [
        "Fingersoft's parent company is the Finger Group.",
        "WHIRC – a subsidiary company of Wright-Hennepin",
        "CK Life Sciences International (Holdings) Inc. (), or CK Life Sciences, is a subsidiary of CK Hutchison Holdings",
        "EM Microelectronic-Marin (subsidiary of The Swatch Group).",
        "The company is currently a division of the corporate group Jam Industries.",
        "Volt Technical Resources is a business unit of Volt Workforce Solutions, a subsidiary of Volt Information Sciences (currently trading over-the-counter as VISI.).",
    ]

    irrelevant_sentences = [
        "The second company is deemed to be a subsidiary of the parent company.",
        "The company has gone through more than one incarnation.",
        "The company is owned by its employees.",
        "Larger companies compete for market share by acquiring smaller companies that may own a particular market sector.",
        "A parent company is a company that owns 51% or more voting stock in another firm (or subsidiary).",
        "It is a holding company that provides services through its subsidiaries in the following areas: oil and gas, industrial and infrastructure, government and power.",
    ]

    all_sentences = relevant_sentences + irrelevant_sentences

    # NOTE(review): the query starts with a bare ": ", which looks like a
    # prefix token was stripped from this text at some point -- confirm the
    # exact query format expected by the query encoder.
    query = ": A company is a part of a larger company."

    # Inference only: skip autograd bookkeeping so the results can be
    # converted to NumPy directly.
    with torch.no_grad():
        embeddings = encode_batch(sentence_encoder, tokenizer, all_sentences, "cpu").cpu().numpy()
        query_embedding = encode_batch(query_encoder, tokenizer, [query], "cpu").cpu().numpy()

    sims = cosine_similarity(query_embedding, embeddings)[0]

    # Rank all candidate sentences from most to least similar to the query.
    sentences_sims = sorted(zip(all_sentences, sims), key=lambda pair: pair[1], reverse=True)

    for s, sim in sentences_sims:
        print(s, sim)
```