HF Tokenizer output does not match MistralTokenizer output

#27
Opened by bhuvan-nv

The token IDs produced by the HF tokenizer don't match those produced by MistralTokenizer when using the tool_use chat template. Special tokens such as [AVAILABLE_TOOLS] are not tokenized correctly by the HF tokenizer.

transformers version 4.40.1
mistral_common version 1.0.2

Example:

from transformers import AutoTokenizer
from mistral_common.protocol.instruct.messages import (
    AssistantMessage,
    UserMessage,
)
from mistral_common.protocol.instruct.tool_calls import (
    Tool,
    Function,
)
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.request import ChatCompletionRequest

# Init tokenizers
mistral_tokenizer = MistralTokenizer.v3()
hf_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1")

# Message History
mistral_query = ChatCompletionRequest(
    tools=[
        Tool(
            function=Function(
                name="get_current_weather",
                description="Get the current weather",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location.",
                        },
                    },
                    "required": ["location", "format"],
                },
            )
        )
    ],
    messages=[
        UserMessage(content="What's the weather like today in Paris"),
    ],
    model="test",
)

# Tokens from mistral tokenizer
mistral_encoded = mistral_tokenizer.encode_chat_completion(mistral_query)

# Convert to HF format
hf_messages = mistral_query.model_dump()['messages']
hf_tools = mistral_query.model_dump()['tools']
hf_encoded_text = hf_tokenizer.apply_chat_template(hf_messages, chat_template="tool_use", tools=hf_tools, tokenize=False)
hf_encoded_tokens = hf_tokenizer.apply_chat_template(hf_messages, chat_template="tool_use", tools=hf_tools, tokenize=True)

# Print formatted text
print(mistral_encoded.text)
# Prints:
# <s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_current_weather",▁"description":▁"Get▁the▁current▁weather",▁"parameters":▁{"type":▁"object",▁"properties":▁{"location":▁{"type":▁"string",▁"description":▁"The▁city▁and▁state,▁e.g.▁San▁Francisco,▁CA"},▁"format":▁{"type":▁"string",▁"enum":▁["celsius",▁"fahrenheit"],▁"description":▁"The▁temperature▁unit▁to▁use.▁Infer▁this▁from▁the▁users▁location."}},▁"required":▁["location",▁"format"]}}}][/AVAILABLE_TOOLS][INST]▁What's▁the▁weather▁like▁today▁in▁Paris[/INST]

print(hf_encoded_text)
# Prints:
# <s>[AVAILABLE_TOOLS][{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'format': {'type': 'string', 'enum': ['celsius', 'fahrenheit'], 'description': 'The temperature unit to use. Infer this from the users location.'}}, 'required': ['location', 'format']}}}][/AVAILABLE_TOOLS][INST]What's the weather like today in Paris[/INST]

# Print tokens
print(mistral_encoded.tokens)
# Prints:
# [1, 6, 1501, 7567, 1891, 2032, 1113, 3396, 1316, 1113, 3396, 2032, 10598, 1629, 2032, 1113, 1295, 29498, 3790, 29498, 1537, 1991, 1316, 1113, 7286, 2032, 1113, 2226, 1040, 2636, 8854, 1316, 1113, 12206, 2032, 10598, 1891, 2032, 1113, 3582, 1316, 1113, 11491, 2032, 10598, 3501, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 7286, 2032, 1113, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 8474, 1113, 4530, 2032, 10598, 1891, 2032, 1113, 2195, 1316, 1113, 10825, 2032, 8135, 29485, 1958, 3938, 1316, 1113, 29490, 19425, 13075, 9651, 1113, 7286, 2032, 1113, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 1379, 11549, 1113, 11661, 2032, 8135, 3501, 1316, 1113, 4530, 3010, 1743, 10925, 7, 3, 2592, 29510, 29481, 1040, 8854, 1505, 3922, 1065, 6233, 4]

print(hf_encoded_tokens)
# Prints:
# [1, 29560, 9792, 27531, 5095, 29498, 4725, 3832, 29503, 4096, 19205, 1891, 2637, 1232, 3396, 1415, 1232, 3396, 2637, 12780, 1629, 2637, 1232, 1295, 29498, 3790, 29498, 1537, 1991, 1415, 1232, 7286, 2637, 1232, 2226, 1040, 2636, 8854, 1415, 1232, 12206, 2637, 12780, 1891, 2637, 1232, 3582, 1415, 1232, 11491, 2637, 12780, 3501, 2637, 12780, 1891, 2637, 1232, 2195, 1415, 1232, 7286, 2637, 1232, 1782, 3758, 1072, 2433, 29493, 1085, 29491, 29489, 29491, 4420, 10454, 29493, 10229, 16809, 1232, 4530, 2637, 12780, 1891, 2637, 1232, 2195, 1415, 1232, 10825, 2637, 6704, 29485, 1958, 3938, 1415, 1232, 29490, 19425, 13075, 6575, 1232, 7286, 2637, 1232, 1782, 8409, 5796, 1066, 1706, 29491, 1328, 1410, 1224, 1245, 1040, 6211, 5491, 2583, 11549, 1232, 11661, 2637, 6704, 3501, 1415, 1232, 4530, 2189, 14879, 4096, 29516, 9792, 27531, 5095, 29498, 4725, 3832, 29503, 4096, 17057, 29561, 3963, 29510, 29481, 1040, 8854, 1505, 3922, 1065, 6233, 29560, 29516, 17057, 29561]

assert mistral_encoded.tokens == hf_encoded_tokens
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# AssertionError
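A bare assert only reports that the two lists differ, not where. The small helper below is a hypothetical diagnostic (it is not part of transformers or mistral_common) that reports the first index at which two token-ID lists diverge, with a few IDs of context on each side. With the lists printed above, the divergence is at index 1: the Mistral tokenizer emits ID 6 (the token right after <s>, i.e. [AVAILABLE_TOOLS]), while the HF tokenizer emits 29560, which, judging from the "[INST]" example further down, appears to be a bare "[" piece.

```python
from typing import Optional, Sequence


def first_mismatch(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Return the first index where two token-ID lists differ.

    Returns None if the lists are identical. If one list is a strict
    prefix of the other, returns the length of the shorter list.
    """
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    if len(a) != len(b):
        return min(len(a), len(b))
    return None


def describe_mismatch(a: Sequence[int], b: Sequence[int], context: int = 3) -> str:
    """Summarise where two token-ID lists diverge, showing nearby IDs."""
    i = first_mismatch(a, b)
    if i is None:
        return "identical"
    lo = max(0, i - context)
    hi = i + context + 1
    return f"differ at index {i}: {list(a[lo:hi])} vs {list(b[lo:hi])}"


# Usage with the first few IDs from the report above:
print(describe_mismatch([1, 6, 1501, 7567], [1, 29560, 9792, 27531]))
```

In the original script, `describe_mismatch(mistral_encoded.tokens, hf_encoded_tokens)` could be printed in place of (or before) the assert.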

Hi @bhuvan-nv, the problem here is that special tokens like [AVAILABLE_TOOLS] are not being picked up by the tokenizer and are instead being split into subword pieces. This is strange, because they are in fact in the tokenizer's vocabulary! We're investigating.

This happens without tool usage as well:

>>> mistral_query = ChatCompletionRequest(messages=[UserMessage(content="Hello world")])
>>> mistral_tokenizer.encode_chat_completion(mistral_query)
Tokenized(tokens=[1, 3, 23325, 2294, 4], text='<s>[INST]▁Hello▁world[/INST]')

>>> hf_tokenizer("<s>[INST] Hello world[/INST]")
{'input_ids': [1, 29560, 17057, 29561, 23325, 2294, 29560, 29516, 17057, 29561], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
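The failure above can be mimicked with a toy pre-tokenizer. This is an illustrative sketch only: `SPECIAL_TOKENS`, `toy_pieces` and `toy_tokenize` are hypothetical names, and the regex split is a crude stand-in for a real subword model, not how transformers implements special-token matching. When special tokens are extracted as whole strings first, "[INST]" survives as one unit; when that step is skipped, it falls apart into "[", "INST", "]", the same shape as the HF IDs above (29560, 17057, 29561), and "[/INST]" becomes four pieces (29560, 29516, 17057, 29561).

```python
import re
from typing import List

# Hypothetical set of control tokens, taken from the strings seen in this thread.
SPECIAL_TOKENS = ["[AVAILABLE_TOOLS]", "[/AVAILABLE_TOOLS]", "[INST]", "[/INST]", "<s>"]


def toy_pieces(text: str) -> List[str]:
    """Crude stand-in for a subword model: runs of word characters, or single punctuation marks."""
    return re.findall(r"\w+|[^\w\s]", text)


def toy_tokenize(text: str, match_special: bool = True) -> List[str]:
    """Split text into pieces, optionally extracting special tokens as whole units first."""
    if not match_special:
        return toy_pieces(text)
    # Longest first, so that if one special token were a prefix of another,
    # the longer one would win in the regex alternation.
    ordered = sorted(SPECIAL_TOKENS, key=len, reverse=True)
    alternation = "|".join(re.escape(t) for t in ordered)
    pieces: List[str] = []
    # The capturing group makes re.split keep the matched special tokens.
    for chunk in re.split(f"({alternation})", text):
        if chunk in SPECIAL_TOKENS:
            pieces.append(chunk)
        elif chunk:
            pieces.extend(toy_pieces(chunk))
    return pieces


print(toy_tokenize("[INST] Hello world[/INST]"))
print(toy_tokenize("[INST] Hello world[/INST]", match_special=False))
```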

Yeah, I noticed when I quantized this that it had 2-2.5x the perplexity of wizard-lm-2:8x22b and eurux-nca:8x22b (both built on the same base model).

Hi all, this was actually an issue with the chat template. The default chat template should be correct, but the tool_use template had issues. I fixed it as part of a general improvement to tool use capabilities in transformers. The PR is here: https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1/discussions/33, but note that this is all quite experimental for now!
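For comparison with the two strings printed in the original report, here is a minimal sketch of how the single-turn tool-use prompt text could be rendered so that it matches the mistral_common output. Two differences stand out in those printouts: mistral_common serialises the tools as JSON (double quotes), whereas the HF text contains a Python repr (single quotes), and mistral_common places a space after [AVAILABLE_TOOLS] and [INST]. `render_mixtral_tool_prompt` is a hypothetical helper, not the template from the linked PR, and it assumes the single-turn layout shown above (user turns only; assistant replies and tool results are out of scope).

```python
import json
from typing import Dict, List


def render_mixtral_tool_prompt(messages: List[Dict[str, str]], tools: List[dict]) -> str:
    """Hypothetical renderer mirroring the mistral_common text printed above.

    Tools are serialised with json.dumps (double quotes, ", " and ": "
    separators), and a space follows [AVAILABLE_TOOLS] and [INST].
    """
    parts = ["<s>"]
    if tools:
        parts.append("[AVAILABLE_TOOLS] " + json.dumps(tools) + "[/AVAILABLE_TOOLS]")
    for message in messages:
        if message["role"] != "user":
            raise NotImplementedError("this sketch only handles user turns")
        parts.append("[INST] " + message["content"] + "[/INST]")
    return "".join(parts)


tools = [{"type": "function", "function": {"name": "get_current_weather"}}]
print(render_mixtral_tool_prompt(
    [{"role": "user", "content": "What's the weather like today in Paris"}], tools
))
```

Without tools, the same helper yields "<s>[INST] Hello world[/INST]", which matches the mistral_common text for the "Hello world" example in this thread.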

That seems odd to me. If it were an issue solely in the chat template, calling the tokenizer directly, as in my earlier comment, should still produce the same tokens, right?

@zhuexe The problem is that special tokens are only identified correctly when separated by spaces. The spaces are not included in the final output, so "[INST] Hello" is tokenized as ["[INST]", "Hello"].

Therefore, chat templates need to include this spacing to ensure special tokens are detected correctly. The main template for Mixtral-8x22B did this correctly, and I presume Mistral's own tokenizer handles it correctly too, but the HF tool_use template did not. I fixed it in the PR linked above, in addition to adding support for our secret, not-yet-announced, upcoming tool use API!
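The spacing rule described in this reply can be illustrated with a toy model. This is not the matching logic in transformers or tokenizers; `SPECIALS` and `whitespace_delimited_specials` are hypothetical, and the model simply treats a chunk as a special token only if an entire whitespace-separated chunk equals a control token, dropping the whitespace itself.

```python
from typing import List, Tuple

# Hypothetical control tokens used for the illustration.
SPECIALS = {"[INST]", "[/INST]", "[AVAILABLE_TOOLS]", "[/AVAILABLE_TOOLS]"}


def whitespace_delimited_specials(text: str) -> List[Tuple[str, bool]]:
    """Toy rule: a chunk counts as special only if the whole
    whitespace-separated chunk is a control token.

    Returns (chunk, is_special) pairs; the separating whitespace is dropped.
    """
    return [(chunk, chunk in SPECIALS) for chunk in text.split()]


print(whitespace_delimited_specials("[INST] Hello"))
print(whitespace_delimited_specials("[INST]Hello"))
print(whitespace_delimited_specials("today in Paris[/INST]"))
```

Under this toy rule, a closing "[/INST]" glued directly to "Paris" is also missed, which is consistent with the HF IDs in the report ending in 29560, 29516, 17057, 29561 rather than a single control-token ID.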
