r0ymanesco
commited on
Commit
·
cec0bf3
1
Parent(s):
dd93f3b
Update README.md
Browse files
README.md
CHANGED
@@ -11,27 +11,7 @@ To use notdiamond-0001, format your queries using the following prompt with your
|
|
11 |
``` python
|
12 |
query = "Can you write a function that counts from 1 to 10?"
|
13 |
|
14 |
-
formatted_prompt = f"""
|
15 |
-
In general, the following types of queries should get sent to GPT-3.5:
|
16 |
-
Explanation
|
17 |
-
Summarization
|
18 |
-
Writing
|
19 |
-
Informal conversation
|
20 |
-
History, government, economics, literature, social studies
|
21 |
-
Simple math questions
|
22 |
-
Simple coding questions
|
23 |
-
Simple science questions
|
24 |
-
|
25 |
-
In general, the following types of queries should get sent to GPT-4:
|
26 |
-
Advanced coding questions
|
27 |
-
Advanced math questions
|
28 |
-
Advanced science questions
|
29 |
-
Legal questions
|
30 |
-
Medical questions
|
31 |
-
Sensitive/inappropriate queries
|
32 |
-
|
33 |
-
Your job is to determine whether the following query should be sent to GPT-3.5 or GPT-4.
|
34 |
-
|
35 |
Query:
|
36 |
{query}"""
|
37 |
```
|
@@ -45,9 +25,8 @@ You can then determine the model to call as follows
|
|
45 |
id2label = {0: 'gpt-3.5', 1: 'gpt-4'}
|
46 |
tokenizer = AutoTokenizer.from_pretrained("notdiamond/notdiamond-0001")
|
47 |
model = AutoModelForSequenceClassification.from_pretrained("notdiamond/notdiamond-0001")
|
48 |
-
max_length = self._get_max_length(model)
|
49 |
|
50 |
-
inputs = tokenizer(formatted_prompt, truncation=True, max_length=
|
51 |
logits = model(**inputs).logits
|
52 |
model_id = logits.argmax().item()
|
53 |
model_to_call = id2label[model_id]
|
|
|
11 |
``` python
|
12 |
query = "Can you write a function that counts from 1 to 10?"
|
13 |
|
14 |
+
formatted_prompt = f"""Determine whether the following query should be sent to GPT-3.5 or GPT-4.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
Query:
|
16 |
{query}"""
|
17 |
```
|
|
|
25 |
id2label = {0: 'gpt-3.5', 1: 'gpt-4'}
|
26 |
tokenizer = AutoTokenizer.from_pretrained("notdiamond/notdiamond-0001")
|
27 |
model = AutoModelForSequenceClassification.from_pretrained("notdiamond/notdiamond-0001")
|
|
|
28 |
|
29 |
+
inputs = tokenizer(formatted_prompt, truncation=True, max_length=512, return_tensors="pt")
|
30 |
logits = model(**inputs).logits
|
31 |
model_id = logits.argmax().item()
|
32 |
model_to_call = id2label[model_id]
|