Yibin Lei
commited on
Commit
·
30d5d94
1
Parent(s):
34a06fb
Update readme
Browse files
README.md
CHANGED
@@ -8652,4 +8652,83 @@ model-index:
|
|
8652 |
|
8653 |
<h2 align="center"> LENS Embeddings</h2>
|
8654 |
|
8655 |
-
LENS is a model that produces **L**exicon-based **E**mbeddi**N**g**S** (LENS) leveraging large language models. Each dimension of the embeddings is designed to correspond to a token cluster where semantically similar tokens are grouped together. These embeddings have a similar feature size as dense embeddings, with LENS-d8000 offering 8000-dimensional representations.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8652 |
|
8653 |
<h2 align="center"> LENS Embeddings</h2>
|
8654 |
|
8655 |
+
LENS is a model that produces **L**exicon-based **E**mbeddi**N**g**S** (LENS) leveraging large language models. Each dimension of the embeddings is designed to correspond to a token cluster where semantically similar tokens are grouped together. These embeddings have a similar feature size as dense embeddings, with LENS-d8000 offering 8000-dimensional representations.
|
8656 |
+
|
8657 |
+
## Usage
|
8658 |
+
```
|
8659 |
+
git clone https://huggingface.co/yibinlei/LENS-d8000
|
8660 |
+
cd LENS-d8000
|
8661 |
+
```
|
8662 |
+
|
8663 |
+
```python
|
8664 |
+
import torch
|
8665 |
+
from torch import Tensor
|
8666 |
+
import torch.nn.functional as F
|
8667 |
+
from transformers import AutoTokenizer
|
8668 |
+
from bidirectional_mistral import MistralBiForCausalLM
|
8669 |
+
|
8670 |
+
def get_detailed_instruct(task_instruction: str, query: str) -> str:
|
8671 |
+
return f'<instruct>{task_instruction}\n<query>{query}'
|
8672 |
+
|
8673 |
+
def pooling_func(vecs: Tensor, pooling_mask: Tensor) -> Tensor:
|
8674 |
+
# We use max-pooling for LENS.
|
8675 |
+
return torch.max(torch.log(1 + torch.relu(vecs)) * pooling_mask.unsqueeze(-1), dim=1).values
|
8676 |
+
|
8677 |
+
# Prepare the data
|
8678 |
+
instruction = "Given a web search query, retrieve relevant passages that answer the query."
|
8679 |
+
queries = ["what is rba",
|
8680 |
+
"what is oilskin fabric"]
|
8681 |
+
instructed_queries = [get_detailed_instruct(instruction, query) for query in queries]
|
8682 |
+
docs = ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal.",
|
8683 |
+
"Today's oilskins (or oilies) typically come in two parts, jackets and trousers. Oilskin jackets are generally similar to common rubberized waterproofs."]
|
8684 |
+
|
8685 |
+
# Load the model and tokenizer
|
8686 |
+
model = MistralBiForCausalLM.from_pretrained("yibinlei/LENS-d8000")
|
8687 |
+
model.lm_head = torch.load('lm_head.pth')
|
8688 |
+
tokenizer = AutoTokenizer.from_pretrained("yibinlei/LENS-d8000")
|
8689 |
+
|
8690 |
+
# Preprocess the data
|
8691 |
+
query_max_len, doc_max_len = 512, 512
|
8692 |
+
instructed_query_inputs = tokenizer(
|
8693 |
+
instructed_queries,
|
8694 |
+
padding=True,
|
8695 |
+
truncation=True,
|
8696 |
+
return_tensors='pt',
|
8697 |
+
max_length=query_max_len,
|
8698 |
+
add_special_tokens=True
|
8699 |
+
)
|
8700 |
+
doc_inputs = tokenizer(
|
8701 |
+
docs,
|
8702 |
+
padding=True,
|
8703 |
+
truncation=True,
|
8704 |
+
return_tensors='pt',
|
8705 |
+
max_length=doc_max_len,
|
8706 |
+
add_special_tokens=True
|
8707 |
+
)
|
8708 |
+
# We perform pooling exclusively on the outputs of the query tokens, excluding outputs from the instruction.
|
8709 |
+
query_only_mask = torch.zeros_like(instructed_query_inputs['input_ids'], dtype=instructed_query_inputs['attention_mask'].dtype)
|
8710 |
+
special_token_id = tokenizer.convert_tokens_to_ids('<query>')
|
8711 |
+
for idx, seq in enumerate(instructed_query_inputs['input_ids']):
|
8712 |
+
special_pos = (seq == special_token_id).nonzero()
|
8713 |
+
if len(special_pos) > 0:
|
8714 |
+
query_start_pos = special_pos[-1].item()
|
8715 |
+
query_only_mask[idx, query_start_pos:-2] = 1
|
8716 |
+
else:
|
8717 |
+
raise ValueError("No special token found")
|
8718 |
+
|
8719 |
+
# Obtain the embeddings
|
8720 |
+
with torch.no_grad():
|
8721 |
+
instructed_query_outputs = model(**instructed_query_inputs)
|
8722 |
+
query_embeddings = pooling_func(instructed_query_outputs, query_only_mask)
|
8723 |
+
doc_outputs = model(**doc_inputs)
|
8724 |
+
# As the output of each token is used for predicting the next token, the pooling mask is shifted left by 1. The output of the final token EOS token is also excluded.
|
8725 |
+
doc_inputs['attention_mask'][:, -2:] = 0
|
8726 |
+
doc_embeddings = pooling_func(doc_outputs, doc_inputs['attention_mask'])
|
8727 |
+
|
8728 |
+
# Normalize the embeddings
|
8729 |
+
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
|
8730 |
+
doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)
|
8731 |
+
|
8732 |
+
# Compute the similarity
|
8733 |
+
similarity = torch.matmul(query_embeddings, doc_embeddings.T)
|
8734 |
+
```
|