Emmanuel08 committed on
Commit 283bd52 · verified · 1 Parent(s): b264b8d

This is the optimised code

Files changed (1)
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
import os
import gc
import torch
import librosa
import numpy as np
import gradio as gr
from transformers import (AutoProcessor, AutoModelForCTC,
                          AutoModelForTokenClassification, AutoTokenizer)
from speechbrain.inference.VAD import VAD

# 🔧 Check for CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 🛠 Load Voice Activity Detection (VAD) model (loaded here, but not used elsewhere in this script)
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty", savedir="vad_model")

# Function to clean up memory
def clean_up_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# 🎙 Load Wav2Vec2 ASR model
asr_model_name = "facebook/wav2vec2-large-960h"
processor = AutoProcessor.from_pretrained(asr_model_name)
w2v2_model = AutoModelForCTC.from_pretrained(asr_model_name).to(device)
w2v2_model.eval()

# ✍ Load model for punctuation restoration
recap_model_name = "kredor/punctuate-all"
recap_tokenizer = AutoTokenizer.from_pretrained(recap_model_name)
recap_model = AutoModelForTokenClassification.from_pretrained(recap_model_name).to(device)
recap_model.eval()

# 📌 Function to add punctuation
def recap_sentence(string):
    words = string.split()
    if not words:
        return string

    # Tokenize word-by-word so each prediction can be mapped back onto a whole word
    tokens = recap_tokenizer(words, is_split_into_words=True, return_tensors="pt",
                             padding=True, truncation=True).to(device)
    with torch.no_grad():
        predictions = recap_model(**tokens).logits

    predicted_ids = torch.argmax(predictions, dim=-1)[0]
    word_ids = tokens.word_ids(0)
    punctuated_text = []

    for idx, word in enumerate(words):
        # Use the prediction for the word's last sub-token
        positions = [pos for pos, wid in enumerate(word_ids) if wid == idx]
        if not positions:
            punctuated_text.append(word)
            continue
        label = recap_model.config.id2label[predicted_ids[positions[-1]].item()]
        # "0" (or "O") is assumed to be the model's "no punctuation" label
        punctuated_text.append(word if label in ("0", "O") else word + label)

    return " ".join(punctuated_text)

# 🎧 Function for chunk-based streaming transcription
def transcribe_audio_stream(audio_file, chunk_size=2.0):
    audio, sr = librosa.load(audio_file, sr=16000)
    duration = librosa.get_duration(y=audio, sr=sr)
    transcriptions = []

    for start in np.arange(0, duration, chunk_size):
        end = min(start + chunk_size, duration)
        chunk = audio[int(start * sr):int(end * sr)]

        input_values = processor(chunk, return_tensors="pt", sampling_rate=16000).input_values.to(w2v2_model.device)

        with torch.no_grad():
            logits = w2v2_model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        transcriptions.append(transcription)

    return " ".join(transcriptions)

# 🎙 Handle both live audio & file uploads
def return_prediction_w2v2(file_or_mic):
    if not file_or_mic:
        # Nothing recorded/uploaded: return empty text and no download file
        return "", None

    # Transcribe file
    transcription = transcribe_audio_stream(file_or_mic)

    # Add punctuation
    recap_result = recap_sentence(transcription)

    # Save result to file
    download_path = "transcription.txt"
    with open(download_path, "w") as f:
        f.write(recap_result)

    clean_up_memory()
    return recap_result, download_path

# 🖥 Gradio Interface
mic_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=[gr.Textbox(label="Real-Time Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=True
)

file_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=[gr.Textbox(label="File Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=False
)

# 🎛 Combine into a Gradio app
with gr.Blocks() as transcriber_app:
    gr.Markdown("<h2>CCI Real-Time Sermon Transcription</h2>")
    gr.TabbedInterface([mic_transcribe, file_transcribe],
                       ["Real-Time (Microphone)", "Upload Audio"])

# 🚀 Run the Gradio app
if __name__ == "__main__":
    transcriber_app.launch()
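
For a quick check outside the Gradio UI, the same pipeline can be driven directly from Python. The snippet below is a minimal sketch, assuming the file above is saved as app.py and that "sample.wav" is a hypothetical local recording (any format librosa can read works, since transcribe_audio_stream resamples to 16 kHz):

# Minimal usage sketch (assumes app.py above and a hypothetical "sample.wav").
# Importing app loads the ASR, punctuation, and VAD models, but does not start
# the Gradio server because launch() is guarded by __name__ == "__main__".
from app import return_prediction_w2v2

text, transcript_path = return_prediction_w2v2("sample.wav")
print(text)                           # punctuated transcription
print("Saved to:", transcript_path)   # transcription.txt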