Spaces: Runtime error
LE Quoc Dat committed
Commit · 4e763c9
1 Parent(s): 9098515
wip
Browse files:
- app.py +15 -25
- llm_utils.py +40 -0
- models.yaml +10 -0
- templates/index.html +39 -0
- test_llm_utils.py +40 -0
app.py CHANGED
@@ -1,9 +1,10 @@
 from flask import Flask, request, jsonify, render_template, make_response, send_from_directory
-import anthropic
+from litellm import completion
 import os
 import json
 from datetime import datetime
 import base64
+from llm_utils import generate_completion
 
 app = Flask(__name__)
 
@@ -66,27 +67,16 @@ def open_pdf(filename):
 def generate_flashcard():
     data = request.json
     prompt = data['prompt']
-    api_key = request.headers.get('X-API-Key')
     mode = data.get('mode', 'flashcard')
 
-    client = anthropic.Anthropic(api_key=api_key)
-
     try:
-
-        message = client.messages.create(
-            model=model,
-            max_tokens=1024,
-            messages=[
-                {"role": "user", "content": prompt}
-            ]
-        )
-
-        content = message.content[0].text
+        # Use llm_utils to generate completion
+        content = generate_completion(prompt)
         print(prompt)
         print(content)
 
         if mode == 'language':
-            #
+            # Parse language learning format
            lines = content.split('\n')
            word = ''
            translation = ''
@@ -106,8 +96,10 @@ def generate_flashcard():
                 'translation': translation,
                 'answer': answer
             }
-
+            return jsonify({'flashcard': flashcard})
+
         elif mode == 'flashcard' or 'flashcard' in prompt.lower():
+            # Parse flashcard format
             flashcards = []
             current_question = ''
             current_answer = ''
@@ -124,17 +116,15 @@ def generate_flashcard():
             if current_question and current_answer:
                 flashcards.append({'question': current_question, 'answer': current_answer})
 
-
+            return jsonify({'flashcards': flashcards})
+
         elif mode == 'explain' or 'explain' in prompt.lower():
-            #
-
+            # Return explanation format
+            return jsonify({'explanation': content})
+
         else:
-
-
-            # Set cookie with the API key
-            response.set_cookie('last_working_api_key', api_key, secure=True, httponly=True, samesite='Strict')
-
-            return response
+            return jsonify({'error': 'Invalid mode'}), 400
+
     except Exception as e:
         return jsonify({'error': str(e)}), 500
 if __name__ == '__main__':
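
To make the new control flow concrete, here is a small client-side sketch (not part of the commit). It assumes the handler is mounted at a /generate_flashcard route on the default Flask port, which this hunk does not show, and that the X-API-Key header is no longer needed since the key now comes from llm_utils.

# Hypothetical request against the rewritten endpoint; the route path,
# host, and port are assumptions, not taken from the diff.
import requests

resp = requests.post(
    "http://localhost:5000/generate_flashcard",
    json={"prompt": "Create flashcards about photosynthesis.", "mode": "flashcard"},
)
# On success this mode returns {'flashcards': [...]}; the other modes return
# {'flashcard': ...} or {'explanation': ...}, and failures come back as {'error': ...}.
print(resp.status_code, resp.json())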
llm_utils.py ADDED
@@ -0,0 +1,40 @@
+from litellm import completion
+import yaml
+import os
+
+def load_model_config():
+    with open('models.yaml', 'r') as file:
+        return yaml.safe_load(file)
+
+def generate_completion(prompt: str, api_key: str = None) -> str:
+    """
+    Generate completion using LiteLLM with the configured model
+
+    Args:
+        prompt (str): The input prompt
+        api_key (str, optional): Override API key. If not provided, will use environment variable
+
+    Returns:
+        str: The generated completion text
+    """
+    config = load_model_config()
+
+    # Get the first environment variable and its models
+    first_env_var = list(config['models'][0].keys())[0]
+    model_name = config['models'][0][first_env_var][0]
+
+    # If no API key provided, get from environment
+    if api_key is None:
+        api_key = os.getenv(first_env_var)
+        if not api_key:
+            raise ValueError(f"Please set {first_env_var} environment variable")
+
+    messages = [{"role": "user", "content": prompt}]
+
+    response = completion(
+        model=model_name,
+        messages=messages,
+        api_key=api_key
+    )
+
+    return response.choices[0].message.content
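
A quick usage sketch for the new helper (not part of the commit). It assumes the process runs from the repository root so models.yaml is found, and that GEMINI_API_KEY, the first entry in models.yaml, is exported.

# Minimal usage of llm_utils.generate_completion, assuming GEMINI_API_KEY is set
# and models.yaml sits in the working directory.
from llm_utils import generate_completion

text = generate_completion("Summarize photosynthesis in one sentence.")
print(text)

# Passing an explicit key skips the environment lookup:
# text = generate_completion("Same prompt", api_key="your-key-here")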
models.yaml ADDED
@@ -0,0 +1,10 @@
+models:
+  - GEMINI_API_KEY:
+      - "gemini/gemini-exp-1206"
+
+  - OPENROUTER_API_KEY:
+      - "openrouter/google/gemini-exp-1206:free"
+      - "openrouter/anthropic/claude-3-haiku-20240307"
+      - "openrouter/anthropic/claude-3-sonnet-20240229"
+
+
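
Because `models` is a list of single-key mappings, the selection logic in llm_utils reduces to two index lookups. The sketch below (an illustration, not part of the commit) shows what those lookups return for this file.

# How generate_completion walks this config: [0] picks the first provider
# block (GEMINI_API_KEY) and [0] again picks its first model string.
import yaml

with open('models.yaml', 'r') as file:
    config = yaml.safe_load(file)

first_env_var = list(config['models'][0].keys())[0]   # 'GEMINI_API_KEY'
model_name = config['models'][0][first_env_var][0]    # 'gemini/gemini-exp-1206'
print(first_env_var, model_name)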
templates/index.html CHANGED
@@ -9,6 +9,45 @@
     <script src="https://cdnjs.cloudflare.com/ajax/libs/jszip/3.1.5/jszip.min.js"></script>
     <script src="https://cdn.jsdelivr.net/npm/epubjs/dist/epub.min.js"></script>
     <link rel="stylesheet" href="/static/css/styles.css">
+    <style>
+        .api-settings {
+            margin-bottom: 15px;
+            padding: 10px;
+            background: #f5f5f5;
+            border-radius: 5px;
+        }
+
+        .model-group {
+            margin-bottom: 10px;
+        }
+
+        #model-select {
+            width: 100%;
+            padding: 8px;
+            margin-bottom: 10px;
+            border: 1px solid #ddd;
+            border-radius: 4px;
+        }
+
+        #custom-model-inputs {
+            margin-top: 10px;
+        }
+
+        #custom-model-inputs input {
+            width: 100%;
+            padding: 8px;
+            margin-bottom: 5px;
+            border: 1px solid #ddd;
+            border-radius: 4px;
+        }
+
+        .default-api-key input {
+            width: 100%;
+            padding: 8px;
+            border: 1px solid #ddd;
+            border-radius: 4px;
+        }
+    </style>
 </head>
 
 <body>
test_llm_utils.py ADDED
@@ -0,0 +1,40 @@
+import os
+import yaml
+from llm_utils import generate_completion
+
+def test_generate_completion():
+    # Load config to get the environment variable name
+    with open('models.yaml', 'r') as file:
+        config = yaml.safe_load(file)
+
+    # Get the first environment variable name
+    env_var_name = list(config['models'][0].keys())[0]
+
+    # Get API key from environment variable
+    api_key = os.getenv(env_var_name)
+    if not api_key:
+        raise ValueError(f"Please set {env_var_name} environment variable")
+
+    # Test prompt
+    test_prompt = "What is 2+2? Answer in one word."
+
+    try:
+        # Test with explicit API key
+        response = generate_completion(test_prompt, api_key)
+        print(f"Test prompt: {test_prompt}")
+        print(f"Response with explicit API key: {response}")
+        assert isinstance(response, str)
+        assert len(response) > 0
+
+        # Test with environment variable
+        response = generate_completion(test_prompt)
+        print(f"Response with environment variable: {response}")
+        assert isinstance(response, str)
+        assert len(response) > 0
+
+        print("Test passed successfully!")
+    except Exception as e:
+        print(f"Test failed with error: {str(e)}")
+
+if __name__ == "__main__":
+    test_generate_completion()
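
The test above needs a real API key and makes a live model call. If an offline variant is wanted, one option is to stub the litellm call as it is imported inside llm_utils; below is a sketch using pytest's monkeypatch fixture (an illustration, not part of the commit, assuming pytest is installed and the test runs from the repo root so models.yaml resolves).

# Offline sketch: fake only the attributes llm_utils reads from the response.
import llm_utils

class _FakeMessage:
    content = "Four"

class _FakeChoice:
    message = _FakeMessage()

class _FakeResponse:
    choices = [_FakeChoice()]

def test_generate_completion_offline(monkeypatch):
    # Patch the 'completion' name bound inside llm_utils, not litellm itself.
    monkeypatch.setattr(llm_utils, "completion", lambda **kwargs: _FakeResponse())
    # GEMINI_API_KEY is the first env var listed in models.yaml.
    monkeypatch.setenv("GEMINI_API_KEY", "dummy-key")
    assert llm_utils.generate_completion("What is 2+2?") == "Four"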