Yibin Lei committed on
Commit 30d5d94 · 1 Parent(s): 34a06fb

Update readme

Files changed (1): README.md (+80, -1)
README.md CHANGED
@@ -8652,4 +8652,83 @@ model-index:
 
  <h2 align="center"> LENS Embeddings</h2>
 
- LENS is a model that produces **L**exicon-based **E**mbeddi**N**g**S** (LENS) leveraging large language models. Each dimension of the embeddings is designed to correspond to a token cluster where semantically similar tokens are grouped together. These embeddings have a similar feature size as dense embeddings, with LENS-d8000 offering 8000-dimensional representations.
+ LENS is a model that produces **L**exicon-based **E**mbeddi**N**g**S** (LENS) by leveraging large language models. Each dimension of the embeddings corresponds to a token cluster in which semantically similar tokens are grouped together. The embeddings have a feature size comparable to that of dense embeddings, with LENS-d8000 offering 8000-dimensional representations.
+
+ ## Usage
+ ```bash
+ git clone https://huggingface.co/yibinlei/LENS-d8000
+ cd LENS-d8000
+ ```
+
+ ```python
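+ # Run this from inside the cloned LENS-d8000 directory so that bidirectional_mistral.py and lm_head.pth can be imported and loaded below.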
+ import torch
+ from torch import Tensor
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer
+ from bidirectional_mistral import MistralBiForCausalLM
+
+ def get_detailed_instruct(task_instruction: str, query: str) -> str:
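+     # The '<query>' marker is later used to locate where the query text begins when building the query-only pooling mask.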
+     return f'<instruct>{task_instruction}\n<query>{query}'
+
+ def pooling_func(vecs: Tensor, pooling_mask: Tensor) -> Tensor:
+     # We use max-pooling for LENS.
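+     # log(1 + ReLU(x)) keeps all activations non-negative, so multiplying by the 0/1 mask zeroes out excluded positions before taking the max over the sequence dimension.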
+     return torch.max(torch.log(1 + torch.relu(vecs)) * pooling_mask.unsqueeze(-1), dim=1).values
+
+ # Prepare the data
+ instruction = "Given a web search query, retrieve relevant passages that answer the query."
+ queries = ["what is rba",
+            "what is oilskin fabric"]
+ instructed_queries = [get_detailed_instruct(instruction, query) for query in queries]
+ docs = ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal.",
+         "Today's oilskins (or oilies) typically come in two parts, jackets and trousers. Oilskin jackets are generally similar to common rubberized waterproofs."]
+
+ # Load the model and tokenizer
+ model = MistralBiForCausalLM.from_pretrained("yibinlei/LENS-d8000")
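+ # lm_head.pth is included in the cloned repository; loading it replaces the original LM head so that (per the description above) each output dimension corresponds to a token cluster.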
+ model.lm_head = torch.load('lm_head.pth')
+ tokenizer = AutoTokenizer.from_pretrained("yibinlei/LENS-d8000")
+
+ # Preprocess the data
+ query_max_len, doc_max_len = 512, 512
+ instructed_query_inputs = tokenizer(
+     instructed_queries,
+     padding=True,
+     truncation=True,
+     return_tensors='pt',
+     max_length=query_max_len,
+     add_special_tokens=True
+ )
+ doc_inputs = tokenizer(
+     docs,
+     padding=True,
+     truncation=True,
+     return_tensors='pt',
+     max_length=doc_max_len,
+     add_special_tokens=True
+ )
+ # We perform pooling exclusively on the outputs of the query tokens, excluding outputs from the instruction.
+ query_only_mask = torch.zeros_like(instructed_query_inputs['input_ids'], dtype=instructed_query_inputs['attention_mask'].dtype)
+ special_token_id = tokenizer.convert_tokens_to_ids('<query>')
+ for idx, seq in enumerate(instructed_query_inputs['input_ids']):
+     special_pos = (seq == special_token_id).nonzero()
+     if len(special_pos) > 0:
+         query_start_pos = special_pos[-1].item()
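+         # Same logic as for the documents below: each position's output predicts the next token, and the final EOS output is excluded, so the mask stops two positions before the end.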
+         query_only_mask[idx, query_start_pos:-2] = 1
+     else:
+         raise ValueError("No special token found")
+
+ # Obtain the embeddings
+ with torch.no_grad():
+     instructed_query_outputs = model(**instructed_query_inputs)
+     query_embeddings = pooling_func(instructed_query_outputs, query_only_mask)
+     doc_outputs = model(**doc_inputs)
+     # As the output of each token is used to predict the next token, the pooling mask is shifted left by 1. The output of the final EOS token is also excluded.
+     doc_inputs['attention_mask'][:, -2:] = 0
+     doc_embeddings = pooling_func(doc_outputs, doc_inputs['attention_mask'])
+
+ # Normalize the embeddings
+ query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
+ doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)
+
+ # Compute the similarity
+ similarity = torch.matmul(query_embeddings, doc_embeddings.T)
+ ```
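
Because both embedding matrices are L2-normalized, `similarity` holds cosine similarities, with one row per query and one column per document. The sketch below is only an illustration of how that matrix might be consumed for ranking; it assumes the `queries`, `docs`, and `similarity` variables defined in the snippet above.

```python
import torch

# Rank the example documents for each query by descending cosine similarity.
ranked = torch.argsort(similarity, dim=1, descending=True)
for q_idx, query in enumerate(queries):
    print(f"Query: {query}")
    for rank, d_idx in enumerate(ranked[q_idx].tolist(), start=1):
        print(f"  {rank}. score={similarity[q_idx, d_idx].item():.4f}  {docs[d_idx]}")
```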