bencser committed
Commit 660f424 · verified · 1 Parent(s): 5774bea

Create app.py

Files changed (1)
  1. app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
+ import gradio as gr
+ import whisper
+ import yt_dlp
+ import os
+ import traceback
+ from pydub import AudioSegment
+ from threading import Thread
+ from queue import Queue
+
+ # Global variable to store the selected model
+ selected_model = None
+
+ def load_whisper_model(model_name):
+     global selected_model
+     selected_model = whisper.load_model(model_name)
+     return f"Loaded {model_name} model"
+
+ def chunk_audio(audio_file, chunk_size_ms=30000):
+     audio = AudioSegment.from_file(audio_file)
+     chunks = [audio[i:i+chunk_size_ms] for i in range(0, len(audio), chunk_size_ms)]
+     return chunks
+
+ def stream_transcription(audio_file):
+     segment_queue = Queue()
+
+     def transcribe_worker():
+         try:
+             chunks = chunk_audio(audio_file)
+             for i, chunk in enumerate(chunks):
+                 chunk_file = f"temp_chunk_{i}.wav"
+                 chunk.export(chunk_file, format="wav")
+                 result = selected_model.transcribe(chunk_file)
+                 os.remove(chunk_file)
+                 for segment in result['segments']:
+                     segment_text = f"[{segment['start'] + i*30:.2f}s -> {segment['end'] + i*30:.2f}s] {segment['text']}\n"
+                     segment_queue.put(segment_text)
+             segment_queue.put(None)  # Signal end of transcription
+         except Exception as e:
+             segment_queue.put(f"Error: {str(e)}")
+             segment_queue.put(None)
+
+     Thread(target=transcribe_worker).start()
+
+     full_transcript = ""
+     while True:
+         segment_text = segment_queue.get()
+         if segment_text is None:
+             break
+         if segment_text.startswith("Error"):
+             yield segment_text
+             break
+         full_transcript += segment_text
+         yield full_transcript
+
+ def download_youtube_audio(youtube_url):
+     ydl_opts = {
+         'format': 'bestaudio/best',
+         'postprocessors': [{
+             'key': 'FFmpegExtractAudio',
+             'preferredcodec': 'mp3',
+             'preferredquality': '192',
+         }],
+         'outtmpl': 'temp_audio.%(ext)s',
+     }
+     with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+         ydl.download([youtube_url])
+     return "temp_audio.mp3"
+
+ def process_input(model, input_type, youtube_url=None, audio_file=None):
+     try:
+         yield "Loading Whisper model..."
+         load_whisper_model(model)
+         yield f"Loaded {model} model. "
+
+         if input_type == "YouTube URL":
+             if youtube_url:
+                 yield "Downloading audio from YouTube..."
+                 audio_file = download_youtube_audio(youtube_url)
+                 yield "Download complete. Starting transcription...\n"
+             else:
+                 yield "Please provide a valid YouTube URL."
+                 return
+         elif input_type == "Audio File":
+             if not audio_file:
+                 yield "Please upload an audio file."
+                 return
+             else:
+                 yield "Starting transcription...\n"
+
+         yield from stream_transcription(audio_file)
+     except Exception as e:
+         error_msg = f"An error occurred: {str(e)}\n"
+         error_msg += traceback.format_exc()
+         print(error_msg)
+         yield f"Error: {str(e)}"
+     finally:
+         if input_type == "YouTube URL" and audio_file and os.path.exists(audio_file):
+             os.remove(audio_file)
+ # Define the Gradio interface
+ with gr.Blocks() as iface:
+     gr.Markdown("# Whisper Transcription App")
+     gr.Markdown("Transcribe YouTube videos or audio files using OpenAI's Whisper model. Large files and long videos can take a very long time to process.")
+
+     with gr.Row():
+         with gr.Column():
+             model = gr.Radio(
+                 choices=["tiny", "base", "small", "medium", "large"],
+                 label="Whisper Model",
+                 value="base"
+             )
+             gr.Markdown("""
+             - tiny: very fast, less accurate
+             - base: medium speed and accuracy
+             - small: balanced speed and accuracy
+             - medium: more accurate, slower
+             - large: most accurate, very slow
+             """)
+
+             input_type = gr.Radio(["YouTube URL", "Audio File"], label="Input Type")
+             youtube_url = gr.Textbox(label="YouTube URL")
+             audio_file = gr.Audio(label="Audio File", type="filepath")
+
+             with gr.Row():
+                 submit_button = gr.Button("Submit")
+                 clear_button = gr.Button("Clear")
+
+         with gr.Column():
+             output = gr.Textbox(label="Transcription", lines=25)
+
+     submit_button.click(
+         fn=process_input,
+         inputs=[model, input_type, youtube_url, audio_file],
+         outputs=output,
+         api_name="transcribe"
+     )
+
+     def clear_outputs():
+         return {youtube_url: "", audio_file: None, output: ""}
+
+     clear_button.click(
+         fn=clear_outputs,
+         inputs=[],
+         outputs=[youtube_url, audio_file, output],
+         api_name="clear"
+     )
+
+ # Launch the interface
+ iface.queue().launch(share=True)
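
Because the submit handler is registered with api_name="transcribe", the app can also be driven programmatically. Below is a minimal usage sketch (not part of the commit) using gradio_client; the Space id and video URL are placeholders, and process_input is a generator, so predict returns its final yielded transcript.

    # Minimal client sketch, assuming the app is deployed as a Gradio Space.
    # "bencser/whisper-transcription" and the video URL are placeholders.
    from gradio_client import Client

    client = Client("bencser/whisper-transcription")  # hypothetical Space id
    result = client.predict(
        "base",                                    # model
        "YouTube URL",                             # input_type
        "https://www.youtube.com/watch?v=VIDEO_ID",  # youtube_url (placeholder)
        None,                                      # audio_file
        api_name="/transcribe",
    )
    print(result)  # final transcript yielded by the generator endpoint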