# Page-header residue captured with the source (Spaces status: Runtime error) - not part of the program.
import streamlit as st | |
import io | |
import os | |
import yaml | |
import pyarrow | |
import tokenizers | |
# Opt in to parallel tokenization via the environment variable the
# tokenizers library consults.
os.environ.update(TOKENIZERS_PARALLELISM="true")

# Configure the page to use the wide layout.
st.set_page_config(layout="wide")
def from_library():
    """Import the ``retro_reader`` package on demand.

    Returns:
        A ``(constants_module, RetroReader)`` tuple, in that order.
    """
    from retro_reader import constants as constants_module
    from retro_reader import RetroReader as reader_cls

    return constants_module, reader_cls


# Bind the library's constants module and reader class to module-level names.
C, RetroReader = from_library()
# Custom hashers for types Streamlit's legacy cache cannot hash on its own; see
# https://stackoverflow.com/questions/70274841/streamlit-unhashable-typeerror-when-i-use-st-cache
def _hash_as_none(_):
    """Hash any instance to the constant ``None``."""
    return None


def _hash_as_zero(_):
    """Hash any instance to the constant ``0``."""
    return 0


my_hash_func = {
    io.TextIOWrapper: _hash_as_none,
    pyarrow.lib.Buffer: _hash_as_zero,
    tokenizers.Tokenizer: _hash_as_none,
    tokenizers.AddedToken: _hash_as_none,
}
# @st.cache(hash_funcs=my_hash_func, allow_output_mutation=True) | |
# def load_ko_roberta_large_model(): | |
# config_file = "configs/inference_ko_roberta_large.yaml" | |
# return RetroReader.load(config_file=config_file) | |
# @st.cache(hash_funcs=my_hash_func, allow_output_mutation=True) | |
# def load_ko_electra_small_model(): | |
# config_file = "configs/inference_ko_electra_small.yaml" | |
# return RetroReader.load(config_file=config_file) | |
# @st.cache(hash_funcs=my_hash_func, allow_output_mutation=True) | |
# def load_en_electra_large_model(): | |
# config_file = "configs/inference_en_electra_large.yaml" | |
# return RetroReader.load(config_file=config_file) | |
def _cached_model_loader(loader):
    """Wrap *loader* with whichever Streamlit cache this version provides.

    Streamlit re-executes the whole script on every interaction, so an
    uncached loader rebuilds the model on each rerun (and once more for every
    additional call site). Caching makes all callers share one instance.
    """
    if hasattr(st, "cache_resource"):
        # Newer Streamlit: cache intended for shared, unserializable
        # resources such as models. The spinner is disabled so the rendered
        # page stays the same as before.
        return st.cache_resource(show_spinner=False)(loader)
    # Older Streamlit: fall back to the legacy cache the file used originally.
    return st.cache(
        hash_funcs=my_hash_func,
        allow_output_mutation=True,
        show_spinner=False,
    )(loader)


@_cached_model_loader
def load_vi_electra_base_model():
    """Load the reader described by ``configs/inference_vi_electra_base.yaml``.

    The result is cached by Streamlit, so repeated calls (including those
    made on script reruns) return the already-loaded instance instead of
    loading the model again.

    Returns:
        Whatever ``RetroReader.load`` returns for that config file.
    """
    config_file = "configs/inference_vi_electra_base.yaml"
    return RetroReader.load(config_file=config_file)
# Registry of loaded readers keyed by model name. The entry is built when the
# module is executed, so the loader runs at that point.
# Entries for "klue/roberta-large" and
# "monologg/koelectra-small-v3-discriminator" are currently disabled.
# NOTE(review): the key names an English ELECTRA-large checkpoint while the
# loader reads the "vi_electra_base" config - confirm which is intended.
RETRO_READER_HOST = {}
RETRO_READER_HOST["google/electra-large-discriminator"] = load_vi_electra_base_model()
def main():
    """Render the Retrospective Reader demo page.

    Shows a form with a query field and a context field, both pre-filled from
    the ``EN_*`` attributes of the constants module ``C``. On submit, the
    reader is called with the two texts and the page displays the answer
    stored under ``"id-01"`` plus the score returned alongside it.
    """
    st.title("Retrospective Reader Demo")

    # The model selectbox is disabled in this version; one fixed loader is used.
    reader = load_vi_electra_base_model()
    prefix = "EN"
    context_box_height = 300

    # The sidebar sliders (null_score_diff_threshold, rear_threshold,
    # n_best_size, beta1, beta2, best_cof) and the return_submodule_outputs
    # checkbox are disabled as well, so the reader's attributes are left as
    # loaded and submodule outputs are never requested.
    want_submodule_outputs = False

    def constant(suffix):
        # Fetch the attribute named "<prefix>_<suffix>" from the constants module.
        return getattr(C, f"{prefix}_{suffix}")

    st.markdown("## Demonstration")
    with st.form(key="my_form"):
        query = st.text_input(
            label="Type your query",
            value=constant("EXAMPLE_QUERY"),
            max_chars=None,
            help=constant("QUERY_HELP_TEXT"),
        )
        context = st.text_area(
            label="Type your context",
            value=constant("EXAMPLE_CONTEXTS"),
            height=context_box_height,
            max_chars=None,
            help=constant("CONTEXT_HELP_TEXT"),
        )
        submitted = st.form_submit_button(label="Submit")

    # Nothing more to render until the form has been submitted.
    if not submitted:
        return

    with st.spinner("Please wait.."):
        outputs = reader(
            query=query,
            context=context,
            return_submodule_outputs=want_submodule_outputs,
        )

    # First element maps example ids to answers; second is the score payload.
    answer = outputs[0]["id-01"] or "No answer"
    score = outputs[1]

    st.markdown("## Results")
    st.write(answer)
    st.markdown("### Rear Verification Score")
    st.json(score)
    # Display of the extra submodule outputs (score_ext, nbest_preds,
    # score_diff taken from outputs[2:]) is disabled together with the
    # checkbox above.
# Run the demo when this file is executed as a script.
if __name__ == "__main__":
    main()