---
base_model: unsloth/meta-llama-3.1-8b-instruct-bnb-4bit
tags:
- text-generation-inference
- transformers
- unsloth
- llama
- trl
- grpo
license: apache-2.0
language:
- en
- tr
datasets:
- umarigan/OpenThoughts-43k-TR
---
# Uploaded model

- **Developed by:** umarigan
- **License:** apache-2.0
- **Finetuned from model:** unsloth/meta-llama-3.1-8b-instruct-bnb-4bit

This Llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.

[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
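The exact GRPO training recipe is not included in this card. As a rough orientation only, the sketch below shows how such a run could be wired up with Unsloth and TRL over the listed dataset; the reward function, hyperparameters, the `train` split name, and the assumption that the data exposes a `prompt` column are placeholders, not the settings actually used for this model.

```python
# Hedged sketch of a GRPO fine-tune with Unsloth + TRL (not the exact recipe used for this model).
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from unsloth import FastLanguageModel

# Load the 4-bit base model and attach LoRA adapters.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/meta-llama-3.1-8b-instruct-bnb-4bit",
    max_seq_length=2048,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Assumption: the dataset has been prepared so each row has a "prompt" column,
# which is what GRPOTrainer expects.
train_ds = load_dataset("umarigan/OpenThoughts-43k-TR", split="train")

def reward_fn(completions, **kwargs):
    # Placeholder reward: 1.0 for any non-empty completion; replace with task-specific rewards.
    return [1.0 if str(c).strip() else 0.0 for c in completions]

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_fn,
    args=GRPOConfig(
        output_dir="llama-3.1-8b-grpo-tr",
        per_device_train_batch_size=4,
        num_generations=4,
        max_completion_length=512,
    ),
    train_dataset=train_ds,
)
trainer.train()
```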
## Eval results

- arc-tr: 57.68%
- truthful_qa-tr: 43.45% (MC1), 22.15% (MC2)

Use the following code to reproduce these results:
```python
import re

import torch
from datasets import load_dataset
from transformers import pipeline

model_id = "umarigan/llama-3.2-8B-R1-Tr"
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# ARC-TR
ds = load_dataset("mukayese/arc-tr", split="test")

def extract_answer(text):
    """Extract the first occurring A-D label from generated text."""
    match = re.search(r'\b([A-D])\b', text, re.IGNORECASE)
    return match.group(1).upper() if match else None

total = 0
correct = 0

for example in ds:
    # Format the question and choices
    question = example["question"]
    choices = "\n".join([f"{label}) {text}" for label, text in
                         zip(example["choices"]["label"], example["choices"]["text"])])

    # Create prompt with explicit instruction
    prompt = f"""Answer this multiple-choice question by providing ONLY the letter corresponding to the correct answer (A, B, C, or D). Do not include any explanation.

Question: {question}
Options:
{choices}
Answer:"""

    # Generate response
    messages = [{"role": "user", "content": prompt}]
    try:
        outputs = pipe(
            messages,
            max_new_tokens=5,  # Limit response length to get just the answer
            do_sample=False,   # Disable sampling for deterministic answers
        )
        response = outputs[0]["generated_text"][-1]["content"]
        predicted = extract_answer(response)
        answer = example["answerKey"]

        # Update counters
        total += 1
        if predicted == answer:
            correct += 1

    except Exception as e:
        print(f"Error processing example: {e}")
        continue

# Print results
print("\nBenchmark Results:")
print(f"Total questions processed: {total}")
print(f"Correct answers: {correct}")
print(f"Accuracy: {correct/total:.2%}" if total > 0 else "No questions processed")

# Output:
# Benchmark Results:
# Total questions processed: 1172
# Correct answers: 676
# Accuracy: 57.68%

# TRUTHFUL-TR
ds2 = load_dataset("mukayese/truthful_qa-tr", split="validation")

def extract_answer(text, valid_labels):
    """Extract the first occurring valid label from generated text.

    Redefines the A-D-only helper above so it accepts an arbitrary label set.
    """
    # Create a regex pattern that matches any of the valid labels
    pattern = r'\b(' + '|'.join(valid_labels) + r')\b'
    match = re.search(pattern, text, re.IGNORECASE)
    return match.group(1).upper() if match else None

def evaluate_mc(example, targets_key="mc1_targets"):
    """Evaluate a single multiple-choice example with a variable number of choices."""
    question = example["question"]
    choices = example[targets_key]["choices"]
    labels = example[targets_key]["labels"]

    # Generate option labels dynamically (A, B, C, ...)
    option_labels = [chr(65 + i) for i in range(len(choices))]

    # Create prompt with explicit instruction
    options_text = "\n".join([f"{label}) {text}" for label, text in zip(option_labels, choices)])
    prompt = f"""Answer this multiple-choice question by selecting the most correct option. Provide only the letter corresponding to your choice ({', '.join(option_labels)}).

Question: {question}
Options:
{options_text}
Answer:"""

    # Generate response
    messages = [{"role": "user", "content": prompt}]
    try:
        outputs = pipe(
            messages,
            max_new_tokens=5,  # Limit response length to get just the answer
            do_sample=False,   # Disable sampling for deterministic answers
        )
        response = outputs[0]["generated_text"][-1]["content"]

        # Extract predicted label
        predicted = extract_answer(response, option_labels)
        if predicted is None:
            return 0  # Count as incorrect if no valid answer

        # Get the first choice marked as correct. Note that mc2_targets can mark
        # several choices as correct, so this scores MC2 strictly against one of them.
        correct_idx = labels.index(1)
        correct_label = option_labels[correct_idx]

        return int(predicted == correct_label)

    except Exception as e:
        print(f"Error processing example: {e}")
        return 0

# Evaluate on both mc1 and mc2 targets
mc1_scores = []
mc2_scores = []

for example in ds2:
    mc1_scores.append(evaluate_mc(example, "mc1_targets"))
    mc2_scores.append(evaluate_mc(example, "mc2_targets"))

# Calculate metrics
def calculate_metrics(scores):
    total = len(scores)
    correct = sum(scores)
    accuracy = correct / total if total > 0 else 0
    return total, correct, accuracy

mc1_total, mc1_correct, mc1_accuracy = calculate_metrics(mc1_scores)
mc2_total, mc2_correct, mc2_accuracy = calculate_metrics(mc2_scores)

# Print results
print("\nBenchmark Results:")
print("MC1 Targets:")
print(f"Total questions: {mc1_total}")
print(f"Correct answers: {mc1_correct}")
print(f"Accuracy: {mc1_accuracy:.2%}")
print("\nMC2 Targets:")
print(f"Total questions: {mc2_total}")
print(f"Correct answers: {mc2_correct}")
print(f"Accuracy: {mc2_accuracy:.2%}")

# Output:
# MC1 Targets:
# Total questions: 817
# Correct answers: 355
# Accuracy: 43.45%
#
# MC2 Targets:
# Total questions: 817
# Correct answers: 181
# Accuracy: 22.15%
```
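
For regular chat use outside the benchmarks, the model can be called through the same `transformers` pipeline; this is a minimal sketch, and the Turkish prompt is only an illustrative example.

```python
# Minimal chat usage sketch; the prompt is illustrative.
import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="umarigan/llama-3.2-8B-R1-Tr",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [{"role": "user", "content": "Türkiye'nin başkenti neresidir? Kısaca açıkla."}]
out = pipe(messages, max_new_tokens=128, do_sample=False)
print(out[0]["generated_text"][-1]["content"])
```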