import os
import shutil
from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer, normalizers, pre_tokenizers, processors, decoders
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from utils import batch_dataset_iterator
from core_base_datasets import core_base_datasets
from core_instruct_datasets import core_instruct_datasets
tokenizer_path = '../tokenizer'
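# start from a clean output directory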
if os.path.exists(tokenizer_path):
    shutil.rmtree(tokenizer_path)
os.makedirs(tokenizer_path, exist_ok=True)
#
# special_tokens
#
bos_token = '<|endoftext|>'
eos_token = '<|im_end|>'
pad_token = '<|pad|>'
unk_token = '<|unk|>'
special_tokens = [
bos_token,
eos_token,
pad_token,
unk_token,
'<|im_start|>',
'<|im_sep|>',
'system',
'user',
'assistant',
'',
'',
'',
'',
'',
'',
'',
'',
'',
'',
]
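# pad the special-token list out to 64 entries with <|reserved_i|> placeholders,
# so the special-token ids can stay stable if more control tokens are added later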
for i in range(64 - len(special_tokens)):
    special_tokens.append(f'<|reserved_{i}|>')
#
# BPE Tokenizer
#
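# no <unk> token on the model: byte-level pre-tokenization plus the full byte
# alphabet seeded into the trainer below means every input can be encoded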
bpe = BPE(unk_token=None, byte_fallback=True)
tokenizer = Tokenizer(bpe)
# normalizer: none, the byte-level pipeline works on the raw text
tokenizer.normalizer = None
# pre-tokenizer: GPT-2 style byte-to-unicode mapping, no automatic leading space
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)
# post-processor: keep offsets untrimmed so they include the byte-level space marker
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
# decoder: map the byte-level symbols back to the original text
tokenizer.decoder = decoders.ByteLevel()
#
# BPE Trainer
#
trainer = BpeTrainer(
    vocab_size=131072, # 128 * 1024
    min_frequency=3,
    special_tokens=special_tokens,
    max_token_length=16,
    # seed the vocab with all 256 byte-level symbols so any byte sequence can be
    # encoded even though the model has no unk_token
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)
tokenizer_datasets = core_base_datasets + core_instruct_datasets
tokenizer.train_from_iterator(
(batch_dataset_iterator(n) for n in tokenizer_datasets),
trainer,
)
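# tokenizer.json stores the full pipeline; model.save() additionally writes
# vocab.json and merges.txt for the raw BPE model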
tokenizer.save(os.path.join(tokenizer_path, 'tokenizer.json'))
tokenizer.model.save(tokenizer_path)
#
# PreTrainedTokenizerFast
#
CHAT_TEMPLATE = (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '<|im_sep|>' + message['content'] + '<|im_end|>'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant<|im_sep|>' }}"
"{% endif %}"
)
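# e.g. messages=[{'role': 'user', 'content': 'Hi'}] with add_generation_prompt=True renders as:
# <|im_start|>user<|im_sep|>Hi<|im_end|><|im_start|>assistant<|im_sep|>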
fast_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
chat_template=CHAT_TEMPLATE,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
unk_token=unk_token,
clean_up_tokenization_spaces=False,
)
fast_tokenizer.save_pretrained(tokenizer_path)
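#
# optional sanity check: a minimal sketch that reloads the exported tokenizer and
# renders the chat template; the messages below are illustrative only
#
reloaded = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)
messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello!'},
]
prompt = reloaded.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
print(reloaded(prompt)['input_ids'][:16])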