Multi-turn Chat History Management

#37
by Jaykumaran17 - opened

Hey Team,

I'm finding it hard to properly integrate chat history, as shown on the Molmo Chat web page.

It would be great if you could help me out with this :>)

chat_history = []

while True:
    print(f"{color_b}************ Enter Your Input Query: *****************{reset}")
    input_query = input()
    query_text = input_query + "\n"+ f"{chat_history}" if chat_history else input_query


    print("######################")

    if input_query == exit_trigger:
        print(f"{color_r}**************************Exiting Molmo Chat!***************************{reset}")
        break


    inputs = processor.process(images=ip_img, text=query_text)
    # Move inputs to the correct device and create a batch of size 1
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    # Generate output; maximum 500 new tokens; stop generation when <|endoftext|> is generated
    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=500, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )

    # Only get generated tokens; decode them to text
    generated_tokens = output[0, inputs["input_ids"].size(1) :]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Print the generated text
    print(generated_text)


    # format_chat = f"Previous Query: {query_text}\nYour Previous Response: {generated_text}"
    format_chat = {
        "Previous Query": input_query,
        "Previous Model Response": generated_text
    }

    print("----------------------------------------------")
    print(format_chat)
    print("------------------------------------------------")

    chat_history = save_history(chat_history, save_text = format_chat, remove_prev_str = False)



def save_history(chat_history, save_text: dict, remove_prev_str: bool = False, remove_len: int = 1):
    
    chat_history = [save_text]+chat_history
    
    if remove_prev_str:
        for _ in range(remove_len):
            chat_history.pop(0)
    
    return chat_history

Can anyone clarify what I'm doing wrong, or what the right approach is for conversation memory?

The way this is usually done is with a list of JSON-style dicts, one per message, e.g.:
conversation = [{"role": "user", "content": "My message 1"}, {"role": "assistant", "content": "answer 1"}, {"role": "user", "content": "My message 2"}, {"role": "assistant", "content": "answer 2"}...]
Then you use a chat template, which you can add to the tokenizer or processor, to turn that list into a single formatted prompt string (and optionally tokenize it). Like this:
query_text = processor.apply_chat_template(conversation, tokenize=False)
And then you can tokenize the way you did before:
inputs = processor.process(images=ip_img, text=query_text)
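
A minimal sketch of the bookkeeping side of this, assuming the conversation is kept as a plain list of role/content dicts in chronological order. The helper names (add_turn, trim_history), the max_turns limit, and the fake_model_reply stub are all hypothetical; the stub only stands in for the real generation call so the snippet runs on its own.

```python
from typing import Dict, List

Message = Dict[str, str]


def add_turn(conversation: List[Message], role: str, content: str) -> None:
    """Append one message in chronological order (oldest first)."""
    if role not in ("user", "assistant"):
        raise ValueError(f"unexpected role: {role!r}")
    conversation.append({"role": role, "content": content})


def trim_history(conversation: List[Message], max_turns: int) -> List[Message]:
    """Keep only the most recent `max_turns` user/assistant pairs."""
    if max_turns <= 0:
        return []
    return conversation[-2 * max_turns:]


def fake_model_reply(conversation: List[Message]) -> str:
    """Hypothetical stand-in for the real model call; echoes the last user message."""
    return f"answer to: {conversation[-1]['content']}"


conversation: List[Message] = []
for user_message in ["My message 1", "My message 2", "My message 3"]:
    add_turn(conversation, "user", user_message)
    reply = fake_model_reply(conversation)  # replace with real generation
    add_turn(conversation, "assistant", reply)
    # Bound the prompt length by dropping the oldest turns.
    conversation = trim_history(conversation, max_turns=2)

print(conversation)
```

Appending (rather than prepending, as in the original save_history) keeps the turns in the order the model should read them, and trimming whole user/assistant pairs keeps the prompt from growing without bound.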

Your chat template and conversation format have to match, and ideally both also match what the model was tuned on. Since this one is Qwen-based, I would recommend trying the Qwen chat template in case this processor and tokenizer do not have one.
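
To make the template point concrete, here is a hand-written sketch of a ChatML-style renderer, the turn layout used by Qwen chat models. It is an illustration only: render_chatml is a hypothetical helper, it omits any default system prompt, and whether this Molmo checkpoint was actually tuned on this layout is an assumption to verify. If the processor or tokenizer ships its own chat template, that one should take priority.

```python
from typing import Dict, List


def render_chatml(conversation: List[Dict[str, str]],
                  add_generation_prompt: bool = True) -> str:
    """Render role/content messages into a ChatML-style prompt string."""
    parts = []
    for message in conversation:
        parts.append(
            f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        )
    if add_generation_prompt:
        # Open an assistant turn so the model continues with its reply.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


conversation = [
    {"role": "user", "content": "My message 1"},
    {"role": "assistant", "content": "answer 1"},
    {"role": "user", "content": "My message 2"},
]

query_text = render_chatml(conversation)
print(query_text)
```

The resulting query_text would then be passed to processor.process(images=ip_img, text=query_text) as in the reply above.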
