import json
from typing import List

import cv2
import os

import numpy as np
import gradio as gr
import supervision as sv
from inference_sdk import (
    InferenceHTTPClient,
    InferenceConfiguration,
    VisualisationResponseFormat
)


def read_json_file(file_path: str) -> dict:
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 explicitly (the encoding JSON text uses), so
    loading does not depend on the platform's default locale encoding.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def split_and_strip(text: str) -> List[str]:
    """Split comma-separated ``text`` into trimmed, non-empty items.

    Blank entries (from an empty string, doubled or trailing commas, or
    whitespace-only fragments) are dropped, e.g.
    ``"Lexus, Honda,"`` -> ``["Lexus", "Honda"]``.
    """
    return [part for part in (raw.strip() for raw in text.split(',')) if part]


# Introductory text rendered at the top of the Gradio app.
MARKDOWN = """
# WORKFLOWS 🛠
Define complex ML pipelines in JSON and execute them, running multiple models and
fusing their outputs seamlessly.
Use a self-hosted Inference HTTP [container](https://inference.khulnasoft.com/inference_helpers/inference_cli/#inference-server-start)
or run against the KhulnaSoft [API](https://detect.khulnasoft.com/docs)
to get results without writing a single line of code.
"""

# LICENSE PLATES WORKFLOW

# Workflow diagram shown at the top of the "License Plates" tab.
LICENSE_PLATES_MARKDOWN = (
    "\n"
    "![license-plates-detection-workflow](\n"
    "https://media.khulnasoft.com/inference/license-plates-detection-workflow.png)\n"
)
# Example image URLs offered through gr.Examples in the "License Plates" tab.
LICENSE_PLATES_EXAMPLES = [
    "https://media.khulnasoft.com/inference/license_plate_1.jpg",
    "https://media.khulnasoft.com/inference/license_plate_2.jpg",
]
# Workflow specification loaded at import time; a missing file fails startup.
LICENSE_PLATES_SPECIFICATION_PATH = 'configs/license_plates.json'
LICENSE_PLATES_SPECIFICATION = read_json_file(LICENSE_PLATES_SPECIFICATION_PATH)
# The specification pretty-printed inside a fenced ```json block for gr.Markdown.
LICENSE_PLATES_SPECIFICATION_STRING = "\n".join(
    ["", "```json", json.dumps(LICENSE_PLATES_SPECIFICATION, indent=4), "```", ""]
)

# CAR BRAND WORKFLOW

# Workflow diagram shown at the top of the "Car Brands" tab.
# NOTE(review): the ".png.png" double extension looks like a typo — confirm
# the asset URL before changing it.
CAR_BRANDS_MARKDOWN = (
    "\n"
    "![car-brand-workflow](\n"
    "https://media.khulnasoft.com/inference/car-brand-workflow.png.png)\n"
)
# Example rows for gr.Examples: [comma-separated brands, image URL], matching
# the (input_text, input_image) order of the car-brands inference function.
CAR_BRANDS_EXAMPLES = [
    [
        "Lexus, Honda, Seat",
        "https://media.khulnasoft.com/inference/multiple_cars_1.jpg",
    ],
    [
        "Volkswagen, Renault, Mercedes",
        "https://media.khulnasoft.com/inference/multiple_cars_2.jpg",
    ],
]
# Workflow specification loaded at import time; a missing file fails startup.
CAR_BRANDS_SPECIFICATION_PATH = 'configs/car_brands.json'
CAR_BRANDS_SPECIFICATION = read_json_file(CAR_BRANDS_SPECIFICATION_PATH)
# The specification pretty-printed inside a fenced ```json block for gr.Markdown.
CAR_BRANDS_SPECIFICATION_STRING = "\n".join(
    ["", "```json", json.dumps(CAR_BRANDS_SPECIFICATION, indent=4), "```", ""]
)

# Connection settings for the Inference HTTP server, read from the environment.
API_URL = os.getenv('API_URL')
API_KEY = os.getenv('API_KEY')
print("API_URL", API_URL)

# Fail fast at startup and name exactly which variable(s) are unset.
_missing_env_vars = [
    name for name, value in (("API_URL", API_URL), ("API_KEY", API_KEY))
    if value is None
]
if _missing_env_vars:
    raise ValueError(
        "API_URL and API_KEY environment variables are required "
        f"(missing: {', '.join(_missing_env_vars)})"
    )


# Shared client used by every inference handler; visualisations are requested
# in NumPy format.
CLIENT = InferenceHTTPClient(api_url=API_URL, api_key=API_KEY)

CLIENT.configure(InferenceConfiguration(
    output_visualisation_format=VisualisationResponseFormat.NUMPY))


def annotate_image(image: np.ndarray, detections: sv.Detections) -> np.ndarray:
    """Draw bounding boxes and labels for ``detections`` onto ``image``.

    The line thickness and text scale are derived from the image resolution.
    The input is treated as RGB: it is converted to BGR for drawing and the
    annotated result is converted back to RGB before being returned.
    """
    height, width, _ = image.shape
    resolution_wh = (width, height)
    thickness = sv.calculate_dynamic_line_thickness(resolution_wh=resolution_wh)
    scale = sv.calculate_dynamic_text_scale(resolution_wh=resolution_wh)

    # Presumably the round-trip exists because supervision draws its colours in
    # OpenCV's BGR channel order — TODO confirm against the supervision version.
    canvas = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # Boxes first, then labels, so labels are drawn on top of the boxes.
    for annotator in (
        sv.BoundingBoxAnnotator(thickness=thickness),
        sv.LabelAnnotator(text_scale=scale, text_thickness=thickness),
    ):
        canvas = annotator.annotate(canvas, detections)
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)


def inference_license_plates(input_image: np.ndarray) -> np.ndarray:
    """Run the license-plate workflow on ``input_image`` and annotate the result.

    Args:
        input_image: Image from the ``gr.Image(type='numpy')`` input; it is
            passed on to ``annotate_image``, which treats it as RGB.

    Returns:
        The image annotated with the recognised plate text, or ``input_image``
        unchanged when the workflow returns no detections.

    Raises:
        gr.Error: If no image was provided (e.g. Submit clicked with an empty
            input). Gradio shows the message to the user.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    result = CLIENT.infer_from_workflow(
        specification=LICENSE_PLATES_SPECIFICATION["specification"],
        images={"image": input_image},
    )
    detections = sv.Detections.from_inference(result)
    if len(detections) == 0:
        return input_image

    plates = result["recognised_plates"]
    # The workflow may return a single value instead of a list; wrap it so the
    # labels are always a list (presumably one entry per detection — TODO
    # confirm the lengths always match).
    detections['class_name'] = plates if isinstance(plates, list) else [plates]
    return annotate_image(input_image, detections)


def inference_car_brands(input_text: str, input_image: np.ndarray) -> np.ndarray:
    """Detect cars in ``input_image`` and label each with a brand from ``input_text``.

    Args:
        input_text: Comma-separated candidate brands, passed to the workflow as
            the ``car_types`` parameter.
        input_image: Image from the ``gr.Image(type='numpy')`` input; it is
            passed on to ``annotate_image``, which treats it as RGB.

    Returns:
        The annotated image, or ``input_image`` unchanged when the workflow
        returns no detections.

    Raises:
        gr.Error: If no image or no non-empty brand was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    classes = split_and_strip(input_text)
    # `any` is False for [] and for lists holding only empty strings.
    if not any(classes):
        raise gr.Error("Please enter at least one car brand.")

    result = CLIENT.infer_from_workflow(
        specification=CAR_BRANDS_SPECIFICATION["specification"],
        images={"image": input_image},
        parameters={"car_types": classes}
    )

    detections = sv.Detections.from_inference(result)
    if len(detections) == 0:
        return input_image

    # result["clip"] presumably holds one row of brand scores per detection
    # (TODO confirm). A single detection may come back as a flat vector, so
    # promote to 2-D and take the best-scoring brand per row.
    scores = np.atleast_2d(np.asarray(result["clip"]))
    class_ids = np.argmax(scores, axis=1)

    # The Detections field is `class_id` (singular). The annotators read it for
    # colour lookup, so it must hold the chosen brand indices.
    detections.class_id = class_ids
    detections['class_name'] = [classes[class_id] for class_id in class_ids]

    return annotate_image(input_image, detections)


# Gradio UI. Each workflow gets a tab containing:
#   - its diagram,
#   - its JSON specification (collapsed),
#   - input and output images,
#   - a Submit button,
#   - cached examples.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Tab(label="License Plates"):
        gr.Markdown(LICENSE_PLATES_MARKDOWN)
        with gr.Accordion("Configuration JSON", open=False):
            gr.Markdown(LICENSE_PLATES_SPECIFICATION_STRING)
        with gr.Row():
            license_plates_input_image_component = gr.Image(
                type='numpy',
                label='Input Image'
            )
            license_plates_output_image_component = gr.Image(
                type='numpy',
                label='Output Image'
            )
        with gr.Row():
            license_plates_submit_button_component = gr.Button('Submit')
        # cache_examples=True has Gradio run `fn` on the examples and serve the
        # stored outputs when an example is selected.
        gr.Examples(
            fn=inference_license_plates,
            examples=LICENSE_PLATES_EXAMPLES,
            inputs=license_plates_input_image_component,
            outputs=license_plates_output_image_component,
            cache_examples=True
        )
    with gr.Tab(label="Car Brands"):
        gr.Markdown(CAR_BRANDS_MARKDOWN)
        with gr.Accordion("Configuration JSON", open=False):
            gr.Markdown(CAR_BRANDS_SPECIFICATION_STRING)
        with gr.Row():
            with gr.Column():
                car_brands_input_image_component = gr.Image(
                    type='numpy',
                    label='Input Image'
                )
                car_brands_input_text = gr.Textbox(
                    label='Car Brands',
                    placeholder='Enter car brands separated by comma'
                )
            car_brands_output_image_component = gr.Image(
                type='numpy',
                label='Output Image'
            )
        with gr.Row():
            car_brands_submit_button_component = gr.Button('Submit')
        # Inputs follow the (input_text, input_image) parameter order of
        # inference_car_brands.
        gr.Examples(
            fn=inference_car_brands,
            examples=CAR_BRANDS_EXAMPLES,
            inputs=[car_brands_input_text, car_brands_input_image_component],
            outputs=car_brands_output_image_component,
            cache_examples=True
        )

    # Submit-button wiring for both tabs. Event listeners belong to the Blocks
    # app as a whole, so they are registered here rather than inside one tab.
    license_plates_submit_button_component.click(
        fn=inference_license_plates,
        inputs=license_plates_input_image_component,
        outputs=license_plates_output_image_component
    )
    car_brands_submit_button_component.click(
        fn=inference_car_brands,
        inputs=[car_brands_input_text, car_brands_input_image_component],
        outputs=car_brands_output_image_component
    )

# show_error=True surfaces handler exceptions (including gr.Error) in the UI.
demo.launch(debug=False, show_error=True)