from tokenizers import Tokenizer
import onnxruntime as ort
import numpy as np

# Both artifacts are opened with relative paths, so they are looked up from
# the current working directory at import time.
_TOKENIZER_FILE = "./tokenizer.json"
_MODEL_FILE = "./model.onnx"

# Loaded once and shared by every call to rerank().
reranker_tokenizer = Tokenizer.from_file(_TOKENIZER_FILE)
reranker_session = ort.InferenceSession(_MODEL_FILE)

def rerank(question, passages, normalize_scores=True, max_length=32768):
    """Score and rank passages by their relevance to a question.

    Each passage is paired with the question using the template
    ``"Query: {question}\\nSentence: {passage}"``, tokenized with the
    module-level ``reranker_tokenizer``, and scored in a single batch by the
    module-level ONNX ``reranker_session``.

    Args:
        question: The query text.
        passages: Iterable of passage strings. It is materialized into a list
            once, so generators are accepted.
        normalize_scores: If True, divide every score by the sum of all scores
            so they add up to 1 (skipped when that sum is not positive).
        max_length: Maximum number of tokens kept per query/passage pair;
            longer encodings are cut from the right.

    Returns:
        A list of ``(passage, score)`` tuples sorted by score, highest first.
        An empty list when ``passages`` is empty.
    """
    # Materialize first: the iterable is needed twice (templates and results).
    passages = list(passages)
    if not passages:
        # Nothing to score; also avoids max() over an empty batch when padding.
        return []

    templates = [f"Query: {question}\nSentence: {passage}" for passage in passages]
    encodings = reranker_tokenizer.encode_batch(templates)

    # Use the tokenizer's own attention mask so that any padding configured in
    # tokenizer.json stays masked out instead of being marked as real tokens.
    # NOTE(review): plain slicing may drop trailing special tokens if the
    # tokenizer's post-processor appends any — confirm against the model.
    id_rows = [enc.ids[:max_length] for enc in encodings]
    mask_rows = [enc.attention_mask[:max_length] for enc in encodings]

    input_ids, attention_mask = _pad_batch(id_rows, mask_rows)
    logits = reranker_session.run(
        None, {"input_ids": input_ids, "attention_mask": attention_mask}
    )[0]

    probabilities = _softmax(logits)
    predicted = np.argmax(probabilities, axis=1)
    confidence = np.max(probabilities, axis=1)
    # Relevance score: the winning probability when a non-zero class wins,
    # otherwise 1 - P(class 0). For a two-class head this equals P(class 1).
    scores = np.where(predicted == 0, 1.0 - confidence, confidence).tolist()

    ranked = sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)

    if normalize_scores:
        total = sum(score for _, score in ranked)
        if total > 0:
            ranked = [(passage, score / total) for passage, score in ranked]

    return ranked


def _pad_batch(id_rows, mask_rows, pad_id=0):
    """Right-pad token-id and attention-mask rows to the longest row as int64 arrays."""
    width = max(len(row) for row in id_rows)
    input_ids = np.full((len(id_rows), width), pad_id, dtype=np.int64)
    attention_mask = np.zeros((len(mask_rows), width), dtype=np.int64)
    for i, (ids, mask) in enumerate(zip(id_rows, mask_rows)):
        input_ids[i, :len(ids)] = ids
        attention_mask[i, :len(mask)] = mask
    return input_ids, attention_mask


def _softmax(logits):
    """Row-wise softmax; subtracting the row max keeps np.exp from overflowing."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Example 1: a Portuguese query ("What is the Pantanal?") ranked against
# Portuguese passages.
question = "O que é o Pantanal?"
passages = [
    "É um dos ecossistemas mais ricos em biodiversidade do mundo, abrigando uma grande variedade de espécies animais e vegetais.",
    "Sua beleza natural, com rios e lagos interligados, atrai turistas de todo o mundo.",
    "O Pantanal sofre com impactos ambientais, como a exploração mineral e o desmatamento.",
    "O Pantanal é uma extensa planície alagável localizada na América do Sul, principalmente no Brasil, mas também em partes da Bolívia e Paraguai.",
    "É um local com importância histórica e cultural para as populações locais.",
    "O Pantanal é um importante habitat para diversas espécies de animais, inclusive aves migratórias."
]
ranked_results = rerank(question, passages, normalize_scores=True)
# A bare expression only displays its value in a notebook/REPL; when running
# this as a script, use print(ranked_results) instead.
ranked_results
# Output recorded for this example (scores sum to ~1 because
# normalize_scores=True):
# [('O Pantanal é uma extensa planície alagável localizada na América do Sul, principalmente no Brasil, mas também em partes da Bolívia e Paraguai.',
#   0.7105862286443647),
#  ('O Pantanal é um importante habitat para diversas espécies de animais, inclusive aves migratórias.',
#   0.22660008031497725),
#  ('O Pantanal sofre com impactos ambientais, como a exploração mineral e o desmatamento.',
#   0.043374300040060654),
#  ('É um local com importância histórica e cultural para as populações locais.',
#   0.0070428120274147726),
#  ('É um dos ecossistemas mais ricos em biodiversidade do mundo, abrigando uma grande variedade de espécies animais e vegetais.',
#   0.006359544027065005),
#  ('Sua beleza natural, com rios e lagos interligados, atrai turistas de todo o mundo.',
#   0.006037034946117598)]

# Example 2: an English query ranked against English passages.
question = "What is the speed of light?"
passages = [
    "Isaac Newton's laws of motion and gravity laid the groundwork for classical mechanics.",
    "The theory of relativity, proposed by Albert Einstein, has revolutionized our understanding of space, time, and gravity.",
    "The Earth orbits the Sun at an average distance of about 93 million miles, taking roughly 365.25 days to complete one revolution.",
    "The speed of light in a vacuum is approximately 299,792 kilometers per second (km/s), or about 186,282 miles per second.",
    "Light can be described as both a wave and a particle, a concept known as wave-particle duality."
]
ranked_results = rerank(question, passages, normalize_scores=True)
# A bare expression only displays its value in a notebook/REPL; when running
# this as a script, use print(ranked_results) instead.
ranked_results
# Output recorded for this example (scores sum to ~1 because
# normalize_scores=True):
# [('The speed of light in a vacuum is approximately 299,792 kilometers per second (km/s), or about 186,282 miles per second.',
#   0.5686758878772575),
#  ('The theory of relativity, proposed by Albert Einstein, has revolutionized our understanding of space, time, and gravity.',
#   0.14584055128478327),
#  ('The Earth orbits the Sun at an average distance of about 93 million miles, taking roughly 365.25 days to complete one revolution.',
#   0.13790743024424898),
#  ("Isaac Newton's laws of motion and gravity laid the groundwork for classical mechanics.",
#   0.08071345159269593),
#  ('Light can be described as both a wave and a particle, a concept known as wave-particle duality.',
#   0.06686267900101434)]
# Downloads last month: -
#
# Downloads are not tracked for this model. (How to track)
# Inference Providers (NEW)
# This model is not currently available via any of the supported Inference Providers.
# The model cannot be deployed to the HF Inference API: the model has no library tag.
#
# Model tree for cnmoro/TangledLlama33m-Reranker-EnPt-ONNX
#
# Quantized (1)
# this model