fused-whisper-llama / README.md
johaness14's picture
Update README.md
c2d6d75 verified
metadata
language:
  - id
  - en
base_model:
  - openai/whisper-small
  - unsloth/Llama-3.2-1B-Instruct
tags:
  - speech to text
  - speech recognition

FusedWhisperLlama Model

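Model ini adalah hasil penggabungan (fusion) Whisper dan LLaMA untuk pipeline speech-to-text-to-LLM: audio ditranskripsi oleh Whisper, lalu hasil transkripsinya diberikan ke LLaMA untuk menghasilkan respons.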

Model Description

  • Model Type: FusedWhisperLlama
  • Language: Indonesian & English
  • Tasks: Speech Recognition, Text Generation
  • Base Models:
    • Whisper: openai/whisper-small
    • LLaMA: unsloth/Llama-3.2-1B-Instruct

Usage

import torch
import torch.nn as nn
import librosa
import numpy as np
import json
import os
from typing import Dict, Any
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor, LlamaConfig, LlamaForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from pathlib import Path

def download_model_files(repo_id: str, local_dir: str):
    """Download the fused checkpoint plus its config and tokenizer files.

    The checkpoint (``pytorch_model.bin``) is fetched first and any failure
    there propagates to the caller. Every config/tokenizer file is then
    fetched on a best-effort basis: a failure is reported with a printed
    warning and the remaining files are still attempted.

    Args:
        repo_id: Hugging Face Hub repository to download from.
        local_dir: Directory the files are written into; it and its
            ``configs`` subdirectory are created when missing.

    Returns:
        ``local_dir`` joined with ``"pytorch_model.bin"``.
    """
    os.makedirs(local_dir, exist_ok=True)
    os.makedirs(os.path.join(local_dir, "configs"), exist_ok=True)

    # The checkpoint is mandatory, so this call is deliberately unguarded.
    print("Downloading model file...")
    hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", local_dir=local_dir)

    print("Downloading config files...")
    whisper_tokenizer_names = (
        "added_tokens.json",
        "merges.txt",
        "normalizer.json",
        "preprocessor_config.json",
        "special_tokens_map.json",
        "tokenizer_config.json",
        "vocab.json",
    )
    llama_tokenizer_names = (
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    )
    remote_names = [
        "config.json",
        "configs/config_whisper.json",
        "configs/config_llama.json",
    ]
    remote_names += [f"configs/tokenizer_whisper/{name}" for name in whisper_tokenizer_names]
    remote_names += [f"configs/tokenizer_llama/{name}" for name in llama_tokenizer_names]

    for remote_name in remote_names:
        try:
            hf_hub_download(repo_id=repo_id, filename=remote_name, local_dir=local_dir)
        except Exception as e:
            # Best effort: a missing auxiliary file must not abort the rest.
            print(f"Warning: Could not download {remote_name}: {e}")
        else:
            print(f"Downloaded {remote_name}")

    return os.path.join(local_dir, "pytorch_model.bin")

class StandaloneFusionInference:
    """Audio -> transcription -> LLM response pipeline from a fused checkpoint.

    The checkpoint is a single file holding three state dicts under the keys
    ``"whisper_model"``, ``"llm_model"`` and ``"fusion_layer"``. Model
    architectures are rebuilt from the JSON configs in ``config_dir`` and the
    weights are then loaded into them.

    Note: the fusion layer is constructed and its weights are loaded, but
    ``generate`` does not route any data through it; the two models are
    connected only through the transcribed text.
    """

    def __init__(self, model_path: str, config_dir: str = None, device: str = None):
        """Build both models, load their weights and move them to ``device``.

        Args:
            model_path: Path to the fused checkpoint file.
            config_dir: Directory containing ``config_whisper.json``,
                ``config_llama.json`` and the ``tokenizer_whisper`` /
                ``tokenizer_llama`` subdirectories. Defaults to a ``configs``
                directory next to ``model_path``.
            device: Torch device string. When ``None``, CUDA is used if
                available, otherwise CPU.
        """
        if config_dir is None:
            config_dir = os.path.join(os.path.dirname(model_path), "configs")

        # Set device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"Using device: {self.device}")

        # Load configs. JSON is UTF-8; being explicit keeps a non-ASCII
        # system prompt readable on platforms whose default encoding differs.
        with open(os.path.join(config_dir, "config_whisper.json"), "r", encoding="utf-8") as f:
            self.whisper_config = json.load(f)
        with open(os.path.join(config_dir, "config_llama.json"), "r", encoding="utf-8") as f:
            self.llama_config = json.load(f)

        print("Loading Whisper model...")
        whisper_config = WhisperConfig(**self.whisper_config["whisper_config"])
        self.whisper = WhisperForConditionalGeneration(whisper_config)
        self.processor = WhisperProcessor.from_pretrained(
            os.path.join(config_dir, "tokenizer_whisper")
        )

        print("Loading LLaMA model...")
        llama_config = LlamaConfig(**self.llama_config["llama_config"])
        self.llm = LlamaForCausalLM(llama_config)

        # Load LLM tokenizer, falling back to the base model's tokenizer when
        # the local copy is missing or unreadable.
        tokenizer_path = os.path.join(config_dir, "tokenizer_llama")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path,
                trust_remote_code=True
            )
            print("Loaded local LLaMA tokenizer")
        except (OSError, ValueError) as e:
            print(f"Warning: Could not load local tokenizer ({e}), using default")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "unsloth/Llama-3.2-1B-Instruct",
                trust_remote_code=True
            )

        # Fusion layer (weights are loaded below; see the class docstring)
        hidden_size = self.whisper.config.d_model
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size)
        )

        print("Loading model weights...")
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from a source you trust.
        # The checkpoint is staged on CPU: the modules are still on CPU at
        # this point, and staging it on the target device would hold every
        # weight there twice once the modules are moved.
        weights = torch.load(model_path, map_location="cpu")
        self.whisper.load_state_dict(weights["whisper_model"])
        self.llm.load_state_dict(weights["llm_model"])
        self.fusion_layer.load_state_dict(weights["fusion_layer"])
        del weights  # release the staging copy before moving to the device

        # Set to eval mode
        self.whisper.eval()
        self.llm.eval()
        self.fusion_layer.eval()

        # Move to device
        self.whisper = self.whisper.to(self.device)
        self.llm = self.llm.to(self.device)
        self.fusion_layer = self.fusion_layer.to(self.device)

        self.system_prompt = self.whisper_config["system_prompt"]
        print("Model loaded successfully!")

    def generate(self, audio_path: str) -> Dict[str, Any]:
        """Transcribe an audio file and generate an LLM reply to it.

        Args:
            audio_path: Path to the audio file; it is loaded as 16 kHz mono.

        Returns:
            A dict with ``"transcription"`` (the Whisper output) and
            ``"response"`` (only the text the LLM generated after the prompt).
        """
        # Load and preprocess the audio
        speech, _ = librosa.load(audio_path, sr=16000, mono=True)
        speech = librosa.util.normalize(speech)

        # Convert the waveform to Whisper input features
        audio_features = self.processor(
            speech,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.to(self.device)

        with torch.no_grad():
            # Get transcription
            transcript_ids = self.whisper.generate(
                audio_features,
                max_length=448,
                num_beams=5,
                temperature=0.0,
                no_repeat_ngram_size=3,
                return_timestamps=False
            )
            transcription = self.processor.batch_decode(
                transcript_ids,
                skip_special_tokens=True,
                normalize=True
            )[0].strip()

            # Prepare the LLM input
            prompt = f"System: {self.system_prompt}\nUser: {transcription}"
            llm_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Generate response
            generated_ids = self.llm.generate(
                **llm_inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True
            )
            # A decoder-only model returns the prompt tokens followed by the
            # new ones; drop the prompt so "response" holds only the reply.
            prompt_length = llm_inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                generated_ids[0][prompt_length:],
                skip_special_tokens=True
            ).strip()

        return {
            "transcription": transcription,
            "response": response
        }

if __name__ == "__main__":
    # Download the model from the Hugging Face Hub
    repo_id = "johaness14/fused-whisper-llama"
    local_dir = "downloaded_model"
    model_path = download_model_files(repo_id, local_dir)

    # Initialize inference. device=None picks CUDA when it is available and
    # falls back to CPU, so the example also runs on machines without a GPU;
    # pass "cuda" or "cpu" explicitly to force a device.
    inference = StandaloneFusionInference(
        model_path,
        config_dir=os.path.join(local_dir, "configs"),
        device=None
    )

    # Run inference (replace the placeholder with a real audio file)
    audio_path = "path/to/your/audio.wav"
    output = inference.generate(audio_path)

    print("\nTranscription:")
    print(output["transcription"])
    print("\nResponse:")
    print(output["response"])

Training Details

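Model ini menggabungkan kemampuan speech recognition dari Whisper dengan kemampuan text generation dari LLaMA menggunakan fusion layer. Catatan: pada contoh kode di bagian Usage, bobot fusion layer ikut dimuat, tetapi `generate` tidak melewatkan data melaluinya; kedua model dihubungkan hanya lewat teks hasil transkripsi.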

Training Procedure

  • Speech Recognition: Menggunakan Whisper small model
  • Text Generation: Menggunakan LLaMA 3.2 1B model
  • Fusion: Custom fusion layer untuk menghubungkan kedua model

Limitations and Biases

  • Model mungkin memiliki bias dari model dasar yang digunakan
  • Performa bergantung pada kualitas audio input
  • Panjang teks yang dapat dihasilkan terbatas (contoh di bagian Usage memakai max_new_tokens=512)