import gradio as gr
import whisper
import torch
import os
from pydub import AudioSegment, silence
from faster_whisper import WhisperModel
import numpy as np
from scipy.io import wavfile
from scipy.signal import correlate
import tempfile
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Display names mapped to model identifiers: the plain size names are loaded
# with openai-whisper, while the Hugging Face repo id is loaded with
# faster-whisper (see transcribe_audio below).
MODELS = {
    "Tiny (Fastest)": "tiny",
    "Base (Faster)": "base",
    "Small (Balanced)": "small",
    "Medium (Accurate)": "medium",
    "Large (Most Accurate)": "large",
    "Faster Whisper Large v3": "Systran/faster-whisper-large-v3"
}

LANGUAGE_NAME_TO_CODE = {
    "Auto Detect": "Auto Detect",
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Maori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Punjabi": "pa",
    "Sinhala": "si",
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian Creole": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Burmese": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    "Javanese": "jw",
    "Sundanese": "su",
}

CODE_TO_LANGUAGE_NAME = {v: k for k, v in LANGUAGE_NAME_TO_CODE.items()}
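

# Constructing a WhisperModel is expensive (device setup plus weights, and a
# large download on first use), so keep one cached faster-whisper instance per
# model name instead of rebuilding it inside every request handler.
_MODEL_CACHE = {}


def get_faster_whisper_model(model_name):
    """Return a cached faster-whisper WhisperModel, creating it on first use."""
    if model_name not in _MODEL_CACHE:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float32" if device == "cuda" else "int8"
        _MODEL_CACHE[model_name] = WhisperModel(model_name, device=device, compute_type=compute_type)
    return _MODEL_CACHE[model_name]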


def convert_to_wav(audio_file):
    """Convert any audio file to WAV format and return the temp file path."""
    audio = AudioSegment.from_file(audio_file)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
        wav_path = temp_wav.name
    # Export after the handle is closed so this also works on Windows.
    audio.export(wav_path, format="wav")
    return wav_path


def resample_audio(audio_segment, target_sample_rate):
    """Resample an audio segment to the target sample rate."""
    return audio_segment.set_frame_rate(target_sample_rate)


def detect_language(audio_file):
    """Detect the language of the audio file."""
    if audio_file is None:
        return "Error: No audio file uploaded."

    try:
        wav_path = convert_to_wav(audio_file)

        model = get_faster_whisper_model(MODELS["Faster Whisper Large v3"])

        # transcribe() identifies the language up front; since the segment
        # generator is never consumed, no full transcription actually runs.
        segments, info = model.transcribe(wav_path, language=None)
        detected_language = CODE_TO_LANGUAGE_NAME.get(info.language, "Unknown Language")

        os.remove(wav_path)

        return f"Detected Language: {detected_language}"
    except Exception as e:
        logger.error(f"Error in detect_language: {str(e)}")
        return f"Error: {str(e)}"


def remove_silence(audio_file, silence_threshold=-40, min_silence_len=500):
    """
    Remove silence from the audio file using energy-based silence detection.

    Args:
        audio_file (str): Path to the input audio file.
        silence_threshold (int): Silence threshold in dBFS. Default is -40.
        min_silence_len (int): Minimum length of silence to remove in milliseconds. Default is 500 ms.

    Returns:
        str: Path to the output audio file with silence removed, or None on failure.
    """
    if audio_file is None:
        return None

    try:
        wav_path = convert_to_wav(audio_file)
        audio = AudioSegment.from_file(wav_path)

        # Find [start_ms, end_ms] ranges that stay below the threshold for at
        # least min_silence_len.
        silent_chunks = silence.detect_silence(
            audio,
            min_silence_len=min_silence_len,
            silence_thresh=silence_threshold
        )

        # Stitch together everything between the silent ranges.
        non_silent_audio = AudioSegment.empty()
        start = 0
        for chunk_start, chunk_end in silent_chunks:
            non_silent_audio += audio[start:chunk_start]
            start = chunk_end
        non_silent_audio += audio[start:]

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_output:
            output_path = temp_output.name
        non_silent_audio.export(output_path, format="wav")

        os.remove(wav_path)

        return output_path
    except Exception as e:
        logger.error(f"Error in remove_silence: {str(e)}")
        # This output feeds a gr.Audio component, which expects a file path;
        # signal failure with None rather than an error string.
        return None


def detect_and_trim_audio(main_audio, target_audio, threshold=0.5):
    """
    Detect the target audio inside the main audio via cross-correlation and
    trim the main audio to the detected segment.

    Args:
        main_audio (str): Path to the main audio file.
        target_audio (str): Path to the target audio file.
        threshold (float): Detection threshold (0 to 1). Higher values mean stricter detection.

    Returns:
        str: Path to the trimmed audio file, or None on failure.
        str: Detected timestamps in the format "start-end" (in seconds).
    """
    if main_audio is None or target_audio is None:
        return None, "Error: Please upload both main and target audio files."

    try:
        main_wav_path = convert_to_wav(main_audio)
        target_wav_path = convert_to_wav(target_audio)

        main_rate, main_data = wavfile.read(main_wav_path)
        target_rate, target_data = wavfile.read(target_wav_path)

        # Cross-correlation assumes both signals share one sample rate.
        if main_rate != target_rate:
            logger.warning(f"Sample rates differ: main_audio={main_rate}, target_audio={target_rate}. Resampling target audio.")
            target_segment = AudioSegment.from_file(target_wav_path)
            target_segment = resample_audio(target_segment, main_rate)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_resampled:
                resampled_path = temp_resampled.name
            target_segment.export(resampled_path, format="wav")
            target_rate, target_data = wavfile.read(resampled_path)

        def to_float_mono(data):
            # Scale integer PCM into [-1, 1]; float WAVs are already in range
            # (np.iinfo would raise on a float dtype).
            if data.dtype.kind in "iu":
                data = data.astype(np.float32) / np.iinfo(data.dtype).max
            else:
                data = data.astype(np.float32)
            # Collapse stereo to mono so the correlation stays one-dimensional.
            return data.mean(axis=1) if data.ndim > 1 else data

        main_data = to_float_mono(main_data)
        target_data = to_float_mono(target_data)

        # Slide the target across the main audio; the peak of the
        # cross-correlation marks the best alignment.
        correlation = np.abs(correlate(main_data, target_data, mode='valid'))
        peak_index = int(np.argmax(correlation))
        peak_value = correlation[peak_index]

        # The peak is by definition the maximum of `correlation`, so comparing
        # it against np.max(correlation) can never fail. Compare it to the
        # target's own energy instead: on an exact occurrence the correlation
        # peak equals that energy, so the threshold scales against it.
        target_energy = np.sum(target_data ** 2)
        if target_energy == 0 or peak_value < threshold * target_energy:
            return None, "Error: Target audio not detected in the main audio."

        start_time = peak_index / main_rate
        end_time = (peak_index + len(target_data)) / main_rate

        main_audio_segment = AudioSegment.from_file(main_wav_path)
        start_ms = int(start_time * 1000)
        end_ms = int(end_time * 1000)
        trimmed_audio = main_audio_segment[start_ms:end_ms]

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_output:
            output_path = temp_output.name
        trimmed_audio.export(output_path, format="wav")

        timestamps_str = f"{start_time:.2f}-{end_time:.2f}"

        os.remove(main_wav_path)
        os.remove(target_wav_path)
        if 'resampled_path' in locals():
            os.remove(resampled_path)

        return output_path, timestamps_str
    except Exception as e:
        logger.error(f"Error in detect_and_trim_audio: {str(e)}")
        return None, f"Error: {str(e)}"


def transcribe_audio(audio_file, language="Auto Detect", model_size="Faster Whisper Large v3"):
    """Transcribe the audio file with the selected model and language."""
    if audio_file is None:
        return "Error: No audio file uploaded."

    try:
        wav_path = convert_to_wav(audio_file)

        # Whisper expects 16 kHz mono input.
        audio = AudioSegment.from_file(wav_path)
        audio = audio.set_frame_rate(16000).set_channels(1)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_processed:
            processed_audio_path = temp_processed.name
        audio.export(processed_audio_path, format="wav")

        if model_size == "Faster Whisper Large v3":
            model = get_faster_whisper_model(MODELS[model_size])

            # Honor the language selection; None lets the model auto-detect.
            language_code = None if language == "Auto Detect" else LANGUAGE_NAME_TO_CODE.get(language)
            segments, info = model.transcribe(
                processed_audio_path,
                task="transcribe",
                language=language_code,
                word_timestamps=True,
                repetition_penalty=1.1,
                temperature=[0.0, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1.0],
            )
            transcription = " ".join(segment.text for segment in segments)
            detected_language = CODE_TO_LANGUAGE_NAME.get(info.language, "Unknown Language")
        else:
            model = whisper.load_model(MODELS[model_size])

            if language == "Auto Detect":
                result = model.transcribe(processed_audio_path, fp16=False)
                detected_language_code = result.get("language", "unknown")
                detected_language = CODE_TO_LANGUAGE_NAME.get(detected_language_code, "Unknown Language")
            else:
                language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")
                result = model.transcribe(processed_audio_path, language=language_code, fp16=False)
                detected_language = language

            transcription = result["text"]

        os.remove(processed_audio_path)
        os.remove(wav_path)

        return f"Detected Language: {detected_language}\n\nTranscription:\n{transcription}"
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}")
        return f"Error: {str(e)}"


with gr.Blocks() as demo:
    gr.Markdown("# Audio Processing Tool")

    with gr.Tab("Detect Language"):
        gr.Markdown("Upload an audio file to detect its language.")
        detect_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        detect_language_output = gr.Textbox(label="Detected Language")
        detect_button = gr.Button("Detect Language")

    with gr.Tab("Transcribe Audio"):
        gr.Markdown("Upload an audio file, select a language (or choose 'Auto Detect'), and choose a model for transcription.")
        transcribe_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        language_dropdown = gr.Dropdown(
            choices=list(LANGUAGE_NAME_TO_CODE.keys()),
            label="Select Language",
            value="Auto Detect"
        )
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            label="Select Model",
            value="Faster Whisper Large v3",
            interactive=True
        )
        transcribe_output = gr.Textbox(label="Transcription and Detected Language")
        transcribe_button = gr.Button("Transcribe Audio")

    with gr.Tab("Remove Silence"):
        gr.Markdown("Upload an audio file to remove silence.")
        silence_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        silence_threshold_slider = gr.Slider(
            minimum=-60, maximum=-20, value=-40, step=1,
            label="Silence Threshold (dBFS)",
            info="Lower (more negative) values treat only very quiet audio as silence."
        )
        min_silence_len_slider = gr.Slider(
            minimum=100, maximum=2000, value=500, step=100,
            label="Minimum Silence Length (ms)",
            info="Minimum duration of silence to remove."
        )
        silence_output = gr.Audio(label="Processed Audio (Silence Removed)", type="filepath")
        silence_button = gr.Button("Remove Silence")

    with gr.Tab("Detect and Trim Audio"):
        gr.Markdown("Upload a main audio file and a target audio file. The app locates the target audio inside the main audio and trims the main audio to that segment.")
        main_audio_input = gr.Audio(type="filepath", label="Upload Main Audio File")
        target_audio_input = gr.Audio(type="filepath", label="Upload Target Audio File")
        threshold_slider = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.5, step=0.1,
            label="Detection Threshold",
            info="Higher values mean stricter detection."
        )
        trimmed_audio_output = gr.Audio(label="Trimmed Audio", type="filepath")
        timestamps_output = gr.Textbox(label="Detected Timestamps (in seconds)")
        detect_trim_button = gr.Button("Detect and Trim")

    detect_button.click(detect_language, inputs=detect_audio_input, outputs=detect_language_output)
    transcribe_button.click(
        transcribe_audio,
        inputs=[transcribe_audio_input, language_dropdown, model_dropdown],
        outputs=transcribe_output
    )
    silence_button.click(
        remove_silence,
        inputs=[silence_audio_input, silence_threshold_slider, min_silence_len_slider],
        outputs=silence_output
    )
    detect_trim_button.click(
        detect_and_trim_audio,
        inputs=[main_audio_input, target_audio_input, threshold_slider],
        outputs=[trimmed_audio_output, timestamps_output]
    )


if __name__ == "__main__":
    demo.launch()