hysts HF staff commited on
Commit
8c42f72
·
1 Parent(s): d39aaee

Fix type annotation

Browse files
Files changed (1) hide show
  1. utils_ai_gradio.py +9 -7
utils_ai_gradio.py CHANGED
@@ -1,14 +1,15 @@
1
  import gradio as gr
2
 
 
3
  def get_app(
4
  models: list[str],
5
  default_model: str,
6
  dropdown_label: str = "Select Hyperbolic Model",
7
- choices: list[str] = None,
8
- **kwargs,
9
  ) -> gr.Blocks:
10
  display_choices = choices if choices is not None else models
11
-
12
  def update_model(new_model: str) -> list[gr.Column]:
13
  if choices is not None:
14
  idx = display_choices.index(new_model)
@@ -17,16 +18,17 @@ def get_app(
17
 
18
  with gr.Blocks(fill_height=True) as demo:
19
  model = gr.Dropdown(
20
- label=dropdown_label,
21
  choices=display_choices,
22
- value=choices[models.index(default_model)] if choices else default_model
23
  )
24
 
25
  columns = []
26
  for model_name in models:
27
  with gr.Column(visible=model_name == default_model) as column:
28
- load_kwargs = {k: v for k, v in kwargs.items() if k not in ['src', 'choices']}
29
  from ai_gradio.providers import registry
 
30
  gr.load(name=model_name, src=registry, **load_kwargs)
31
  columns.append(column)
32
 
@@ -41,4 +43,4 @@ def get_app(
41
  for fn in demo.fns.values():
42
  fn.api_name = False
43
 
44
- return demo
 
1
  import gradio as gr
2
 
3
+
4
  def get_app(
5
  models: list[str],
6
  default_model: str,
7
  dropdown_label: str = "Select Hyperbolic Model",
8
+ choices: list[str] | None = None,
9
+ **kwargs, # noqa: ANN003
10
  ) -> gr.Blocks:
11
  display_choices = choices if choices is not None else models
12
+
13
  def update_model(new_model: str) -> list[gr.Column]:
14
  if choices is not None:
15
  idx = display_choices.index(new_model)
 
18
 
19
  with gr.Blocks(fill_height=True) as demo:
20
  model = gr.Dropdown(
21
+ label=dropdown_label,
22
  choices=display_choices,
23
+ value=choices[models.index(default_model)] if choices else default_model,
24
  )
25
 
26
  columns = []
27
  for model_name in models:
28
  with gr.Column(visible=model_name == default_model) as column:
29
+ load_kwargs = {k: v for k, v in kwargs.items() if k not in ["src", "choices"]}
30
  from ai_gradio.providers import registry
31
+
32
  gr.load(name=model_name, src=registry, **load_kwargs)
33
  columns.append(column)
34
 
 
43
  for fn in demo.fns.values():
44
  fn.api_name = False
45
 
46
+ return demo