import asyncio
import base64
import json
from pathlib import Path

import gradio as gr
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    AsyncStreamHandler,
    Stream,
    get_twilio_turn_credentials,
    wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent

load_dotenv()

cur_dir = Path(__file__).parent

SAMPLE_RATE = 24000


class OpenAIHandler(AsyncStreamHandler):
    """Bridge a fastrtc audio stream to the OpenAI Realtime API.

    Incoming frames are base64-encoded and appended to the realtime session's
    input audio buffer. Audio deltas returned by the model are decoded to
    int16 samples and queued for playback. Completed audio transcripts are
    queued as ``AdditionalOutputs`` so the UI can display them.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        # Open realtime connection. None before start_up() connects and again
        # once the session has ended, so receive() drops frames instead of
        # writing to a closed connection.
        self.connection = None
        # Created in start_up(); declared here so the attribute always exists.
        self.client = None
        # Filled by start_up(), drained by emit(). Items are either
        # (sample_rate, int16 array of shape (1, n)) audio tuples or
        # AdditionalOutputs wrapping a transcript-done event.
        self.output_queue = asyncio.Queue()

    def copy(self):
        """Return a fresh handler with its own connection and output queue."""
        return OpenAIHandler()

    async def start_up(
        self,
    ):
        """Open the realtime session and pump its events into ``output_queue``.

        Runs until the connection's event stream ends. Whatever the reason the
        session ends (normal close, error, cancellation), ``self.connection``
        is reset to None afterwards.
        """
        self.client = openai.AsyncOpenAI()
        try:
            async with self.client.beta.realtime.connect(
                model="gpt-4o-mini-realtime-preview-2024-12-17"
            ) as conn:
                # Use server-side voice activity detection for turn taking.
                await conn.session.update(
                    session={"turn_detection": {"type": "server_vad"}}
                )
                self.connection = conn
                async for event in self.connection:
                    if event.type == "response.audio_transcript.done":
                        await self.output_queue.put(AdditionalOutputs(event))
                    elif event.type == "response.audio.delta":
                        # The delta is base64 text; decode it as int16 samples
                        # and shape it as a single-channel row for playback.
                        audio = np.frombuffer(
                            base64.b64decode(event.delta), dtype=np.int16
                        ).reshape(1, -1)
                        await self.output_queue.put(
                            (self.output_sample_rate, audio)
                        )
        finally:
            # The context manager has exited, so the connection is no longer
            # usable; make receive() fall back to dropping frames.
            self.connection = None

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one input audio frame to the realtime input buffer.

        Frames that arrive while no connection is open are dropped.

        Args:
            frame: ``(sample_rate, samples)``. Only the samples are sent; they
                are squeezed and their raw bytes base64-encoded.
        """
        if not self.connection:
            return
        _, array = frame
        array = array.squeeze()
        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
        await self.connection.input_audio_buffer.append(audio=audio_message)  # type: ignore

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return the next queued audio chunk or transcript output, if any."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        """Close the realtime connection, if one is open, and forget it."""
        # Clear the attribute first so it is reset even if close() raises, and
        # so receive() stops using the connection while it is being closed.
        connection, self.connection = self.connection, None
        if connection:
            await connection.close()


def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    """Append the assistant's finished transcript to the chat history.

    Args:
        chatbot: Chat history in gradio "messages" format; mutated in place.
        response: Transcript-done event produced by ``OpenAIHandler``.

    Returns:
        The same ``chatbot`` list, now with the new assistant message.
    """
    chatbot.append({"role": "assistant", "content": response.transcript})
    return chatbot


chatbot = gr.Chatbot(type="messages")
# NOTE(review): not referenced anywhere else in this module; kept in case
# external code imports it.
latest_message = gr.Textbox(type="text", visible=False)

# TURN credentials, concurrency and time limits are only applied when
# get_space() reports that the app is running on a hosted Space.
stream = Stream(
    OpenAIHandler(),
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
)

app = FastAPI()

stream.mount(app)


@app.get("/")
async def _():
    """Serve index.html with the RTC configuration injected as JSON."""
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = (cur_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)


@app.get("/outputs")
def _(webrtc_id: str):
    """Stream assistant transcripts for one WebRTC session as server-sent events."""

    async def output_stream():
        # Each additional output wraps a transcript-done event as its first arg.
        async for output in stream.output_stream(webrtc_id):
            s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
            yield f"event: output\ndata: {s}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")


if __name__ == "__main__":
    import os

    # MODE=UI launches the gradio UI, MODE=PHONE starts fastphone, and any
    # other value serves the FastAPI app with uvicorn.
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)