qminh369 commited on
Commit
26827a2
·
verified ·
1 Parent(s): e9d670f

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +88 -0
  2. core_utils_llmlingua2.py +149 -0
  3. requirements.txt +5 -0
  4. utils_llmlingua2_test.py +0 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ #from llmlingua import PromptCompressor
4
+ from utils_llmlingua2_test import PromptCompressor
5
+ import tiktoken
6
+
7
+ compressors = {
8
+ "xlm-roberta": PromptCompressor(
9
+ #model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
10
+ #model_name="qminh369/token-classification-llmlingua2-xlm-roberta-10k_merge_10_epoch_paper",
11
+ #model_name='qminh369/token-classification-llmlingua2-xlm-roberta-42k_merge_1_epoch',
12
+ model_name='qminh369/token-classification-llmlingua2-xlm-roberta-42k_merge_10_epoch',
13
+ use_llmlingua2=True,
14
+ device_map="cpu"
15
+ )
16
+ }
17
+
18
+ tokenizer = tiktoken.encoding_for_model("gpt-4")
19
+
20
+ def compress(original_prompt, compression_rate, base_model="xlm-roberta", force_tokens = ['. ', ', '], chunk_end_tokens=['.', '\n']):
21
+ if '\\n' in force_tokens:
22
+ idx = force_tokens.index('\\n')
23
+ force_tokens[idx] = '\n'
24
+
25
+ compressor = compressors.get(base_model, compressors["xlm-roberta"])
26
+ results = compressor.compress_prompt_llmlingua2(
27
+ original_prompt,
28
+ rate=compression_rate,
29
+ force_tokens=force_tokens,
30
+ chunk_end_tokens=chunk_end_tokens,
31
+ return_word_label=True,
32
+ drop_consecutive=True,
33
+ force_reserve_digit=True,
34
+ )
35
+
36
+ compressed_prompt = results["compressed_prompt"]
37
+ n_word_compressed = len(tokenizer.encode(compressed_prompt))
38
+
39
+ word_sep = "\t\t|\t\t"
40
+ label_sep = " "
41
+ lines = results["fn_labeled_original_prompt"].split(word_sep)
42
+ preserved_tokens = []
43
+ for line in lines:
44
+ word, label = line.split(label_sep)
45
+ preserved_tokens.append((word, '+') if label == '1' else (word, None))
46
+
47
+ return compressed_prompt, preserved_tokens, n_word_compressed
48
+
49
+ title = "LLMLingua-2"
50
+
51
+ header = """# LLMLingua-2
52
+ """
53
+
54
+ theme = "soft"
55
+ css = """#anno-img .mask {opacity: 0.5; transition: all 0.2s ease-in-out;}
56
+ #anno-img .mask.active {opacity: 0.7}"""
57
+
58
+ original_prompt_text = """"""
59
+
60
+ with gr.Blocks(title=title, css=css) as app:
61
+ gr.Markdown(header)
62
+ with gr.Row():
63
+ with gr.Column(scale=3):
64
+ original_prompt = gr.Textbox(value=original_prompt_text, label="Original Prompt", lines=10, max_lines=10, interactive=True)
65
+ compressed_prompt = gr.Textbox(value='', label="Compressed Prompt", lines=10, max_lines=10, interactive=False)
66
+
67
+ with gr.Column(scale=1):
68
+ base_model = gr.Radio(["xlm-roberta"], label="Base Model", value="xlm-roberta", interactive=True)
69
+ force_tokens = gr.Dropdown(['\\n', '.', '!', '?', ','],
70
+ label="Tokens to Preserve",
71
+ value=['\\n', '.', '!', '?', ','],
72
+ multiselect=True,
73
+ interactive=True)
74
+ compression_rate = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Compression rate", info="after compr. / befor compr.", interactive=True)
75
+ n_word_original = gr.Textbox(lines=1, label="Original (GPT-4 Tokens)", interactive=False, value=len(tokenizer.encode(original_prompt_text)))
76
+ n_word_compressed = gr.Textbox(lines=1, label="Compressed (GPT-4 Tokens)", interactive=False)
77
+ button = gr.Button("⚡Click to Compress")
78
+ with gr.Accordion(label="Compression Details", open=False):
79
+ diff_text = gr.HighlightedText(label="Diff", combine_adjacent=False, show_legend=True, color_map={"+": "green"})
80
+
81
+ original_prompt.change(lambda x: len(tokenizer.encode(x)), inputs=[original_prompt], outputs=[n_word_original])
82
+ original_prompt.change(lambda x: ("", "", []), inputs=[original_prompt], outputs=[compressed_prompt, n_word_compressed, diff_text])
83
+
84
+ button.click(fn=compress,
85
+ inputs=[original_prompt, compression_rate, base_model, force_tokens],
86
+ outputs=[compressed_prompt, diff_text, n_word_compressed])
87
+
88
+ app.queue(max_size=10, api_open=False).launch(show_api=False)
core_utils_llmlingua2.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import string
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+
9
+ class TokenClfDataset(Dataset): # Hàm tạo custom dataset
10
+ def __init__(
11
+ self,
12
+ texts,
13
+ max_len=512, # 256 (phobert) 512 (xlm-roberta)
14
+ tokenizer=None,
15
+ model_name="m_bert",
16
+ ):
17
+ self.len = len(texts)
18
+ self.texts = texts
19
+ self.tokenizer = tokenizer
20
+ self.max_len = max_len
21
+ self.model_name = model_name
22
+ if "m_bert" in model_name:
23
+ self.cls_token = "[CLS]"
24
+ self.sep_token = "[SEP]"
25
+ self.unk_token = "[UNK]"
26
+ self.pad_token = "[PAD]"
27
+ self.mask_token = "[MASK]"
28
+ elif "xlm-roberta-large" in model_name:
29
+ self.bos_token = "<s>"
30
+ self.eos_token = "</s>"
31
+ self.sep_token = "</s>"
32
+ self.cls_token = "<s>"
33
+ self.unk_token = "<unk>"
34
+ self.pad_token = "<pad>"
35
+ self.mask_token = "<mask>"
36
+ elif "xlm-roberta" in model_name:
37
+ self.bos_token = "<s>"
38
+ self.eos_token = "</s>"
39
+ self.sep_token = "</s>"
40
+ self.cls_token = "<s>"
41
+ self.unk_token = "<unk>"
42
+ self.pad_token = "<pad>"
43
+ self.mask_token = "<mask>"
44
+ elif "phobert" in model_name:
45
+ self.bos_token = "<s>"
46
+ self.eos_token = "</s>"
47
+ self.sep_token = "</s>"
48
+ self.cls_token = "<s>"
49
+ self.unk_token = "<unk>"
50
+ self.pad_token = "<pad>"
51
+ self.mask_token = "<mask>"
52
+ #else: raise NotImplementedError()
53
+
54
+ def __getitem__(self, index):
55
+ text = self.texts[index]
56
+ tokenized_text = self.tokenizer.tokenize(text)
57
+
58
+ tokenized_text = (
59
+ [self.cls_token] + tokenized_text + [self.sep_token]
60
+ ) # add special tokens
61
+
62
+ if len(tokenized_text) > self.max_len:
63
+ tokenized_text = tokenized_text[: self.max_len]
64
+ else:
65
+ tokenized_text = tokenized_text + [
66
+ self.pad_token for _ in range(self.max_len - len(tokenized_text))
67
+ ]
68
+
69
+ attn_mask = [1 if tok != self.pad_token else 0 for tok in tokenized_text]
70
+
71
+ ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)
72
+
73
+ return {
74
+ "ids": torch.tensor(ids, dtype=torch.long),
75
+ "mask": torch.tensor(attn_mask, dtype=torch.long),
76
+ }
77
+
78
+ def __len__(self):
79
+ return self.len
80
+
81
+
82
+ def seed_everything(seed: int):
83
+ random.seed(seed)
84
+ os.environ["PYTHONHASHSEED"] = str(seed)
85
+ np.random.seed(seed)
86
+ torch.manual_seed(seed)
87
+ torch.cuda.manual_seed(seed)
88
+ torch.backends.cudnn.deterministic = True
89
+ torch.backends.cudnn.benchmark = False
90
+
91
+
92
+ def is_begin_of_new_word(token, model_name, force_tokens, token_map): # Thêm kí tự bắt đầu vào từ mới
93
+ if "m_bert" in model_name:
94
+ if token.lstrip("##") in force_tokens or token.lstrip("##") in set(
95
+ token_map.values()
96
+ ):
97
+ return True
98
+ return not token.startswith("##")
99
+ elif "xlm-roberta-large" in model_name:
100
+ #print("xlm-roberta-large")
101
+ if (
102
+ token in string.punctuation
103
+ or token in force_tokens
104
+ or token in set(token_map.values())
105
+ ):
106
+ return True
107
+ return token.startswith("▁") # check xem token có bắt đầu bằng kí tự "_" hay ko -> Trả về False
108
+ elif "xlm-roberta" in model_name:
109
+ #print("xlm-roberta-large")
110
+ if (
111
+ token in string.punctuation
112
+ or token in force_tokens
113
+ or token in set(token_map.values())
114
+ ):
115
+ return True
116
+ return token.startswith("▁")
117
+ elif "phobert" in model_name:
118
+ #print("minh phobert")
119
+ #print("xlm-roberta-large")
120
+ if (
121
+ token in string.punctuation # điều kiện hoặc
122
+ or token in force_tokens
123
+ or token in set(token_map.values())
124
+ ):
125
+ return True
126
+ #return token.startswith("▁") #
127
+ #return not token.startswith("▁")
128
+ #return not token.startswith("@@")
129
+ return not token.endswith("@@")
130
+ #return token.startswith("@@")
131
+ #else: raise NotImplementedError()
132
+
133
+ def replace_added_token(token, token_map):
134
+ for ori_token, new_token in token_map.items():
135
+ token = token.replace(new_token, ori_token)
136
+ return token
137
+
138
+ def get_pure_token(token, model_name): # hàm get pure token trả về token gốc (sau khi loại bỏ kí tự đặc biệt subword)
139
+ if "m_bert" in model_name:
140
+ return token.lstrip("##")
141
+ elif "xlm-roberta-large" in model_name:
142
+ return token.lstrip("▁") # bỏ kí tự "_" ở phía bên trái của từ
143
+ elif "xlm-roberta" in model_name:
144
+ return token.lstrip("▁") # bỏ kí tự "_" ở ph��a bên trái của từ
145
+ elif "phobert" in model_name:
146
+ #return token.lstrip("▁")
147
+ #return token.lstrip("@@")
148
+ return token.rstrip("@@")
149
+ # else: raise NotImplementedError()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ accelerate
3
+ tiktoken
4
+ nltk
5
+ transformers
utils_llmlingua2_test.py ADDED
The diff for this file is too large to render. See raw diff