|
import string |
|
from typing import Optional, Union, Tuple, List |
|
from dataclasses import dataclass |
|
from tqdm import tqdm |
|
import warnings |
|
import nltk |
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset |
|
from torch.nn.utils.rnn import pad_sequence |
|
from transformers import AutoTokenizer |
|
from transformers import DebertaV2PreTrainedModel, DebertaV2Model, PretrainedConfig |
|
try: |
|
from transformers.models.deberta_v2.modeling_deberta_v2 import ( |
|
StableDropout, |
|
ContextPooler, |
|
) |
|
except ImportError: |
|
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler |
|
StableDropout = nn.Dropout |
|
from transformers.modeling_outputs import ModelOutput |
|
|
|
|
|
@dataclass |
|
class RankingCompressionOutput(ModelOutput): |
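    """Model output holding per-token compression logits and a per-sequence ranking score."""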
|
|
|
compression_logits: torch.FloatTensor = None |
|
ranking_scores: torch.FloatTensor = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None |
|
|
|
|
|
"""adapted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L1357 |
|
""" |
|
|
|
|
|
class ProvenceConfig(PretrainedConfig): |
|
|
|
model_type = "Provence" |
|
|
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
|
|
|
|
class Provence(DebertaV2PreTrainedModel): |
|
|
|
config_class = ProvenceConfig |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
num_labels = getattr(config, "num_labels", 2) |
|
self.num_labels = num_labels |
|
self.deberta = DebertaV2Model(config) |
|
self.pooler = ContextPooler(config) |
|
output_dim = self.pooler.output_dim |
|
|
|
|
|
self.classifier = nn.Linear(output_dim, num_labels) |
|
drop_out = getattr(config, "cls_dropout", None) |
|
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out |
|
self.dropout = StableDropout(drop_out) |
|
|
|
|
|
token_dropout = drop_out |
|
self.token_dropout = nn.Dropout(token_dropout) |
|
        self.token_classifier = nn.Linear(config.hidden_size, 2)  # per-token keep/drop head
|
self.name = "Provence" |
|
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path) |
|
self.max_len = config.max_position_embeddings |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
) -> RankingCompressionOutput: |
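        """Run DeBERTa over the input and compute both heads: per-token
        compression logits of shape (batch, seq_len, 2) and a scalar ranking
        score per sequence, taken from the pooled representation."""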
|
outputs = self.deberta( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
) |
|
|
|
encoder_layer = outputs[0] |
|
pooled_output = self.pooler(encoder_layer) |
|
pooled_output = self.dropout(pooled_output) |
|
ranking_logits = self.classifier(pooled_output) |
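        # per-token compression logits over the full sequence: (batch, seq_len, 2)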
|
compression_logits = self.token_classifier(self.token_dropout(encoder_layer)) |
|
        ranking_scores = ranking_logits[:, 0].squeeze()  # first logit serves as the relevance score
|
|
|
return RankingCompressionOutput( |
|
compression_logits=compression_logits, |
|
ranking_scores=ranking_scores, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
def process( |
|
self, |
|
question: Union[List[str], str], |
|
context: Union[List[List[str]], str], |
|
title: Optional[Union[List[List[str]], str]] = "first_sentence", |
|
batch_size=32, |
|
threshold=0.1, |
|
always_select_title=False, |
|
reorder=False, |
|
top_k=5, |
|
enable_warnings=True, |
|
): |
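        """Prune (and optionally rerank) contexts for the given question(s).

        question: a query string, or a list of queries.
        context: a context string, or one list of context strings per query.
        title: "first_sentence" to treat each context's first sentence as its
            title, None for no titles, or explicit titles mirroring 'context'.
        threshold: sentences whose mean token-keep probability exceeds this
            value are retained.
        always_select_title: force-keep the title sentence whenever at least
            one other sentence is selected.
        reorder: if True, sort each query's contexts by ranking score and keep
            only the 'top_k' best.

        Returns a dict with "pruned_context" and "reranking_score"; both are
        nested lists, or scalars when 'context' was passed as a single string.
        """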
|
|
|
|
|
        if isinstance(question, str):
            queries = [question]
        else:
            queries = question
        if isinstance(context, str):
            contexts = [[context]]
        else:
            contexts = context
        if isinstance(title, str) and title != "first_sentence":
            titles = [[title]]
        else:
            titles = title
|
        assert (
            titles == "first_sentence"
            or titles is None
            or (isinstance(titles, list) and len(titles) == len(queries))
        ), "Variable 'titles' must be 'first_sentence', None, or a list of the same length as 'queries'"
        if isinstance(titles, list):
            assert all(
                len(titles_item) == len(contexts_item)
                for titles_item, contexts_item in zip(titles, contexts)
            ), "Each list in 'titles' must have the same length as the corresponding list in 'context'"
|
assert len(queries) == len( |
|
contexts |
|
), "Lists 'queries' and 'contexts' must have same lengths" |
|
dataset = TestDataset( |
|
queries=queries, |
|
contexts=contexts, |
|
titles=titles, |
|
tokenizer=self.tokenizer, |
|
max_len=self.max_len, |
|
enable_warnings=enable_warnings, |
|
) |
|
selected_contexts = [ |
|
[{0: contexts[i][j]} for j in range(len(contexts[i]))] |
|
for i in range(len(queries)) |
|
] |
|
reranking_scores = [ |
|
[None for j in range(len(contexts[i]))] for i in range(len(queries)) |
|
] |
|
with torch.no_grad(): |
|
for batch_start in tqdm( |
|
range(0, len(dataset), batch_size), desc="Pruning contexts..." |
|
): |
|
qis = dataset.qis[batch_start : batch_start + batch_size] |
|
cis = dataset.cis[batch_start : batch_start + batch_size] |
|
sis = dataset.sis[batch_start : batch_start + batch_size] |
|
                sent_coords = dataset.sent_coords[batch_start : batch_start + batch_size]
|
ids_list = dataset.ids[batch_start : batch_start + batch_size] |
|
ids = pad_sequence( |
|
ids_list, batch_first=True, padding_value=dataset.pad_idx |
|
).to(self.device) |
|
mask = (ids != dataset.pad_idx).to(self.device) |
|
outputs = self.forward(ids, mask) |
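                # softmax over the two token classes; column 1 is the keep probability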
|
scores = F.softmax(outputs["compression_logits"].cpu(), dim=-1)[:, :, 1] |
|
token_preds = scores > threshold |
|
                reranking_scrs = outputs["ranking_scores"].cpu().numpy()
                if reranking_scrs.ndim == 0:  # single-example batch: restore the batch dim
                    reranking_scrs = reranking_scrs[None]
|
for ( |
|
ids_list_, |
|
token_preds_, |
|
rerank_score, |
|
qi, |
|
ci, |
|
si, |
|
sent_coords_, |
|
) in zip( |
|
ids_list, token_preds, reranking_scrs, qis, cis, sis, sent_coords |
|
): |
|
|
|
selected_mask = sentence_rounding( |
|
token_preds_.cpu().numpy(), |
|
np.array(sent_coords_), |
|
threshold=threshold, |
|
                        always_select_title=always_select_title
                        and si == 0
                        and titles is not None,
|
) |
|
assert len(selected_mask) == len(token_preds_) |
|
selected_contexts[qi][ci][si] = ids_list_[ |
|
selected_mask[: len(ids_list_)] |
|
] |
|
if si == 0: |
|
reranking_scores[qi][ci] = rerank_score |
|
for i in range(len(queries)): |
|
for j in range(len(contexts[i])): |
|
                if not isinstance(selected_contexts[i][j][0], str):  # still token ids, not yet decoded
|
toks = torch.cat( |
|
[ |
|
ids_ |
|
for _, ids_ in sorted( |
|
selected_contexts[i][j].items(), key=lambda x: x[0] |
|
) |
|
] |
|
) |
|
selected_contexts[i][j] = self.tokenizer.decode( |
|
toks, |
|
skip_special_tokens=True, |
|
clean_up_tokenization_spaces=False, |
|
) |
|
else: |
|
selected_contexts[i][j] = selected_contexts[i][j][0] |
|
            if reorder:
                idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
                selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
                reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
|
|
|
        if isinstance(context, str):
|
selected_contexts = selected_contexts[0][0] |
|
reranking_scores = reranking_scores[0][0] |
|
|
|
return { |
|
"pruned_context": selected_contexts, |
|
"reranking_score": reranking_scores |
|
} |
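
# Example usage (a minimal sketch; the checkpoint path below is a placeholder,
# not a name shipped with this file):
#
#   model = Provence.from_pretrained("path/to/provence-checkpoint").eval()
#   out = model.process(
#       question="What does Provence keep after pruning?",
#       context="Title. Relevant sentence. Irrelevant sentence.",
#       threshold=0.1,
#   )
#   print(out["pruned_context"], out["reranking_score"])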
|
|
|
|
|
|
|
|
|
|
|
def sentence_rounding(predictions, chunks, threshold, always_select_title=True): |
|
""" |
|
predictions: a binary vector containing 1 for tokens which were selected and 0s otherwise |
|
chunks: a list of pairs [start, end] of sentence, i.e. sentence is in coordinates predictions[start:end] |
|
the functions |
|
""" |
|
cumulative_sum = np.cumsum(predictions) |
|
chunk_sums = cumulative_sum[chunks[:, 1] - 1] - np.where( |
|
chunks[:, 0] > 0, cumulative_sum[chunks[:, 0] - 1], 0 |
|
) |
|
chunk_lengths = chunks[:, 1] - chunks[:, 0] |
|
chunk_means = chunk_sums / chunk_lengths |
|
    if always_select_title and (chunk_means > threshold).any():
        chunk_means[0] = 1  # keep the title (first chunk) whenever any sentence is kept
|
means = np.hstack((np.zeros(1), chunk_means, np.zeros(1))) |
|
repeats = np.hstack( |
|
([chunks[0][0]], chunk_lengths, [predictions.shape[0] - chunks[-1][1]]) |
|
) |
|
return np.repeat(means, repeats) > threshold |
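
# Illustrative example (toy numbers): with predictions [1, 1, 0, 0, 1, 0] and
# chunks [[0, 2], [2, 6]], the per-sentence means are 1.0 and 0.25, so with
# threshold=0.5 the returned mask is [True, True, False, False, False, False]:
# the first sentence is kept whole and the second dropped whole.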
|
|
|
|
|
def normalize(s: str) -> str: |
|
def white_space_fix(text): |
|
return " ".join(text.split()) |
|
|
|
def remove_punc(text): |
|
exclude = set(string.punctuation) |
|
return "".join(ch for ch in text if ch not in exclude) |
|
|
|
def lower(text): |
|
return text.lower() |
|
|
|
return white_space_fix(remove_punc(lower(s))) |
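
# e.g. normalize("What's   UP?") == "whats up": lowercased, punctuation removed,
# and whitespace collapsed.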
|
|
|
|
|
def sent_split_and_tokenize(text, tokenizer, max_len): |
|
sents_nltk = nltk.sent_tokenize(text) |
|
sents = [] |
|
for j, sent_nltk in enumerate(sents_nltk): |
|
tokinput = (" " if j != 0 else "") + sent_nltk |
|
tok = tokenizer.encode(tokinput, add_special_tokens=False) |
|
ltok = len(tok) |
|
if ltok == 0: |
|
continue |
|
if ltok <= max_len: |
|
sents.append(tok) |
|
else: |
|
for begin in range(0, ltok, max_len): |
|
sents.append(tok[begin : begin + max_len]) |
|
return sents |
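
# Illustrative behavior: NLTK sentence-splits the text, tokenizes each sentence
# (prefixing a space for all but the first), and hard-splits any sentence longer
# than max_len into consecutive max_len-token windows, e.g. 1100 tokens with
# max_len=512 become chunks of 512, 512, and 76 tokens.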
|
|
|
|
|
class TestDataset(Dataset): |
|
def __init__( |
|
self, |
|
queries, |
|
contexts, |
|
tokenizer, |
|
max_len=512, |
|
titles="first_sentence", |
|
enable_warnings=True, |
|
): |
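        """Flattens (query, context) pairs into tokenized model inputs, recording
        sentence coordinates so token predictions can be rounded to whole
        sentences. Contexts that do not fit into max_len are split into blocks."""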
|
self.tokenizer = tokenizer |
|
self.max_len = max_len |
|
self.pad_idx = 0 |
|
self.cls_idx = [1] |
|
self.sep_idx = [2] |
|
self.eos = [2] |
|
|
|
self.nb_spe_tok = len(self.cls_idx) + len(self.sep_idx) |
|
self.enable_warnings = enable_warnings |
|
        self.unusual_query_length = self.max_len // 2  # warn when a query exceeds this length
|
self.unusual_title_len = self.max_len // 2 |
|
self.create_dataset(contexts, queries, titles) |
|
self.len = len(self.cis) |
|
|
|
def create_dataset(self, contexts, queries, titles="first_sentence"): |
|
self.qis = [] |
|
self.cis = [] |
|
self.sis = [] |
|
self.sent_coords = [] |
|
self.cntx_coords = [] |
|
self.ids = [] |
|
if self.enable_warnings: |
|
warnings_dict = { |
|
"zero_len_query": set(), |
|
"too_long_query": set(), |
|
"unusually_long_query": set(), |
|
"unusually_long_title": set(), |
|
"split_context": set(), |
|
} |
|
for i, query in enumerate(queries): |
|
tokenized_query = self.tokenizer.encode( |
|
normalize(query), add_special_tokens=False |
|
) |
|
|
|
query_len = len(tokenized_query) |
|
if query_len == 0: |
|
if self.enable_warnings: |
|
warnings_dict["zero_len_query"].add(i) |
|
continue |
|
elif query_len >= self.max_len - self.nb_spe_tok - 1: |
|
if self.enable_warnings: |
|
warnings_dict["too_long_query"].add(i) |
|
continue |
|
elif query_len >= self.unusual_query_length: |
|
if self.enable_warnings: |
|
warnings_dict["unusually_long_query"].add(i) |
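            # left_0 = length of the "[CLS] query [SEP]" prefix; max_len below is
            # the remaining budget for context tokens (one slot reserved for EOS)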
|
left_0 = len(tokenized_query) + self.nb_spe_tok |
|
tokenized_seq_0 = self.cls_idx + tokenized_query + self.sep_idx |
|
max_len = self.max_len - left_0 - 1 |
|
for j, cntx in enumerate(contexts[i]): |
|
                title = titles[i][j] if isinstance(titles, list) else titles
|
tokenized_sents = sent_split_and_tokenize(cntx, self.tokenizer, max_len) |
|
|
|
if title is not None and title != "first_sentence": |
|
tokenized_title = self.tokenizer.encode( |
|
title, add_special_tokens=False |
|
) |
|
ltok = len(tokenized_title) |
|
if ltok == 0: |
|
pass |
|
elif ltok <= max_len: |
|
tokenized_sents = [tokenized_title] + tokenized_sents |
|
else: |
|
if self.enable_warnings and ltok >= self.unusual_title_len: |
|
warnings_dict["unusually_long_title"].add(i) |
|
tokenized_sents = [ |
|
tokenized_title[begin : begin + max_len] |
|
for begin in range(0, ltok, max_len) |
|
] + tokenized_sents |
|
tokenized_seq = tokenized_seq_0 |
|
left = left_0 |
|
sent_coords = [] |
|
block = 0 |
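                # pack sentences greedily into max_len-sized blocks; when a block
                # overflows, emit it and start a new one with the same query prefix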
|
for idx, tokenized_sent in enumerate(tokenized_sents): |
|
l = len(tokenized_sent) |
|
if left + l <= self.max_len - 1: |
|
sent_coords.append([left, left + l]) |
|
tokenized_seq = tokenized_seq + tokenized_sent |
|
left += l |
|
else: |
|
if self.enable_warnings: |
|
warnings_dict["split_context"].add(i) |
|
if len(tokenized_seq) > left_0: |
|
tokenized_seq = tokenized_seq + self.eos |
|
self.qis.append(i) |
|
self.cis.append(j) |
|
self.sis.append(block) |
|
self.sent_coords.append(sent_coords) |
|
                            self.cntx_coords.append([sent_coords[0][0], sent_coords[-1][1]])
|
self.ids.append(torch.tensor(tokenized_seq)) |
|
tokenized_seq = tokenized_seq_0 + tokenized_sent |
|
sent_coords = [[left_0, left_0 + l]] |
|
left = left_0 + l |
|
block += 1 |
|
if len(tokenized_seq) > left_0: |
|
tokenized_seq = tokenized_seq + self.eos |
|
self.qis.append(i) |
|
self.cis.append(j) |
|
self.sis.append(block) |
|
self.sent_coords.append(sent_coords) |
|
self.cntx_coords.append([sent_coords[0][0], sent_coords[-1][1]]) |
|
self.ids.append(torch.tensor(tokenized_seq)) |
|
if self.enable_warnings: |
|
self.print_warnings(warnings_dict, len(queries)) |
|
|
|
def __len__(self): |
|
return len(self.ids) |
|
|
|
def print_warnings(self, warnings_dict, N): |
|
n = len(warnings_dict["zero_len_query"]) |
|
info = " You can suppress Provence warnings by setting enable_warnings=False." |
|
if n > 0: |
|
ex = list(warnings_dict["zero_len_query"])[:10] |
|
warnings.warn( |
|
f"{n} out of {N} queries have zero length, e.g. at indexes {ex}. " |
|
"These examples will be skipped in context pruning, " |
|
"their contexts will be kept as is." + info |
|
) |
|
n = len(warnings_dict["too_long_query"]) |
|
if n > 0: |
|
ex = list(warnings_dict["too_long_query"])[:10] |
|
warnings.warn( |
|
f"{n} out of {N} queries are too long for context length {self.max_len}, " |
|
f"e.g. at indexes {ex}. These examples will be skipped in context pruning, " |
|
"their contexts will be kept as is." + info |
|
) |
|
n = len(warnings_dict["unusually_long_query"]) |
|
if n > 0: |
|
ex = list(warnings_dict["unusually_long_query"])[:10] |
|
warnings.warn( |
|
f"{n} out of {N} queries are longer than {self.unusual_query_length} tokens, " |
|
f"e.g. at indexes {ex}. These examples will processed as usual in context pruning, " |
|
"but the quality of context pruning could be reduced." + info |
|
) |
|
n = len(warnings_dict["unusually_long_title"]) |
|
if n > 0: |
|
ex = list(warnings_dict["unusually_long_title"])[:10] |
|
warnings.warn( |
|
f"{n} out of {N} titles are longer than {self.unusual_title_length} tokens, " |
|
f"e.g. at indexes {ex}. These examples will processed as usual in context pruning, " |
|
"but the quality of context pruning could be reduced." + info |
|
) |
|
n = len(warnings_dict["split_context"]) |
|
if n > 0: |
|
ex = list(warnings_dict["split_context"])[:10] |
|
warnings.warn( |
|
f"{n} out of {N} contexts were split into several pieces for context pruning, " |
|
f"due to a limited context length of Provence which is equal to {self.max_len}. " |
|
"This could potentially reduce the quality of context pruning. " |
|
"You could consider checking and reducing lengths of contexts, queries, or titles." |
|
+ info |
|
) |
|
|