UnityPaul commited on
Commit
833095d
·
verified ·
1 Parent(s): bbe03ab

Upload RunPhi15.cs

Browse files
Files changed (1) hide show
  1. RunPhi15.cs +43 -56
RunPhi15.cs CHANGED
@@ -3,18 +3,18 @@ using System.Collections.Generic;
3
  using UnityEngine;
4
  using Unity.Sentis;
5
  using System.IO;
6
- using Newtonsoft.Json;
7
  using System.Text;
 
8
 
9
  /*
10
- * Phi 1.5 Inference Code
11
- * =======================
12
  *
13
  * Put this script on the Main Camera
14
  *
15
  * In Assets/StreamingAssets put:
16
  *
17
- * phi15.sentis
18
  * vocab.json
19
  * merges.txt
20
  *
@@ -24,26 +24,24 @@ using System.Text;
24
  */
25
 
26
 
27
- public class RunPhi15 : MonoBehaviour
28
  {
 
 
29
  const BackendType backend = BackendType.GPUCompute;
30
- //string outputString = "Question: \"What is the capital of France?\"\n Correct answer: \"";
31
- //string outputString = "The human asked, \"What is your favourite animal?\" so the wise man answered correctly, \"";
32
- string outputString = "Once upon a time, there were three";
33
- //string outputString = "// Javascript function for prime numbers";
34
 
35
  // This is how many tokens you want. It can be adjusted.
36
  const int maxTokens = 100;
37
 
38
  //Make this smaller for more randomness
39
- const float predictability = 5;
40
 
41
  //Special tokens
42
  const int END_OF_TEXT = 50256;
43
 
44
- Ops ops;
45
- ITensorAllocator allocator;
46
-
47
  //Store the vocabulary
48
  string[] tokens;
49
 
@@ -60,7 +58,7 @@ public class RunPhi15 : MonoBehaviour
60
 
61
 
62
  //stop after this many tokens
63
- const int stopAfter = 200;
64
 
65
  int totalTokens = 0;
66
 
@@ -69,23 +67,28 @@ public class RunPhi15 : MonoBehaviour
69
 
70
  void Start()
71
  {
72
- allocator = new TensorCachingAllocator();
73
- ops = WorkerFactory.CreateOps(backend, allocator);
74
-
75
  SetupWhiteSpaceShifts();
76
 
77
  LoadVocabulary();
78
-
79
- Model model = ModelLoader.Load(Application.streamingAssetsPath + "/phi15.sentis");
80
- engine = WorkerFactory.CreateWorker(backend, model);
81
 
82
- GO(outputString);
83
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- public void GO(string text)
86
- {
87
- outputString = text;
88
  DecodePrompt(outputString);
 
89
  runInference = true;
90
  }
91
 
@@ -101,18 +104,18 @@ public class RunPhi15 : MonoBehaviour
101
  void RunInference()
102
  {
103
  using var tokensSoFar = new TensorInt(new TensorShape(1, maxTokens), outputTokens);
 
104
 
105
- engine.Execute(tokensSoFar);
106
 
107
- var tokensOut = engine.PeekOutput() as TensorFloat;
 
108
 
109
- using var row = ops.Slice(tokensOut, new[] { currentToken }, new[] { currentToken + 1 }, new[] { 1 }, new[] { 1 });
110
- using var rowB = ops.Mul(predictability, row);
111
- using var probs = ops.Softmax(rowB, 2);
112
- probs.MakeReadable();
113
 
114
- int ID = SelectRandomToken(probs.ToReadOnlyArray());
115
 
 
116
  if (currentToken >= maxTokens - 1)
117
  {
118
  for (int i = 0; i < maxTokens - 1; i++) outputTokens[i] = outputTokens[i + 1];
@@ -134,6 +137,7 @@ public class RunPhi15 : MonoBehaviour
134
  else outputString += GetUnicodeText(tokens[ID]);
135
 
136
  Debug.Log(outputString);
 
137
  }
138
 
139
  void DecodePrompt(string text)
@@ -147,10 +151,9 @@ public class RunPhi15 : MonoBehaviour
147
  currentToken = inputTokens.Count - 1;
148
  }
149
 
150
-
151
  void LoadVocabulary()
152
  {
153
- var jsonText = File.ReadAllText(Application.streamingAssetsPath + "/vocab.json");
154
  vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);
155
  tokens = new string[vocab.Count];
156
  foreach (var item in vocab)
@@ -158,23 +161,7 @@ public class RunPhi15 : MonoBehaviour
158
  tokens[item.Value] = item.Key;
159
  }
160
 
161
- merges = File.ReadAllLines(Application.streamingAssetsPath + "/merges.txt");
162
- }
163
-
164
-
165
- int SelectRandomToken(float[] probs)
166
- {
167
- float p = UnityEngine.Random.Range(0, 1f);
168
- float t = 0;
169
- for (int i = 0; i < probs.Length; i++)
170
- {
171
- t += probs[i];
172
- if (p < t)
173
- {
174
- return i;
175
- }
176
- }
177
- return probs.Length - 1;
178
  }
179
 
180
  // Translates encoded special characters to Unicode
@@ -215,7 +202,7 @@ public class RunPhi15 : MonoBehaviour
215
  for (int i = 0, n = 0; i < 256; i++)
216
  {
217
  encodedCharacters[i] = i;
218
- if (IsWhiteSpace((char)i))
219
  {
220
  encodedCharacters[i] = n + 256;
221
  whiteSpaceCharacters[n++] = i;
@@ -223,9 +210,10 @@ public class RunPhi15 : MonoBehaviour
223
  }
224
  }
225
 
226
- bool IsWhiteSpace(char c)
227
  {
228
- return !(('!' <= c && c <= '~') || ('�' <= c && c <= '�') || ('�' <= c && c <= '�'));
 
229
  }
230
 
231
  List<int> GetTokens(string text)
@@ -276,7 +264,6 @@ public class RunPhi15 : MonoBehaviour
276
  private void OnDestroy()
277
  {
278
  engine?.Dispose();
279
- ops?.Dispose();
280
- allocator?.Dispose();
281
  }
 
282
  }
 
3
  using UnityEngine;
4
  using Unity.Sentis;
5
  using System.IO;
 
6
  using System.Text;
7
+ using FF = Unity.Sentis.Functional;
8
 
9
  /*
10
+ * Phi1.5 Inference Code
11
+ * ===========================
12
  *
13
  * Put this script on the Main Camera
14
  *
15
  * In Assets/StreamingAssets put:
16
  *
17
+ * phi15.sentis (or put in asset folder)
18
  * vocab.json
19
  * merges.txt
20
  *
 
24
  */
25
 
26
 
27
+ public class RunPhi15: MonoBehaviour
28
  {
29
+ //Drop the tinystories.sentis or onnx file on here if using an asset:
30
+ //public ModelAsset asset;
31
  const BackendType backend = BackendType.GPUCompute;
32
+
33
+ //string outputString = "Once upon a time, there were three bears";
34
+ string outputString = "One day an alien came down from Mars. It saw a chicken";
 
35
 
36
  // This is how many tokens you want. It can be adjusted.
37
  const int maxTokens = 100;
38
 
39
  //Make this smaller for more randomness
40
+ const float predictability = 5f;
41
 
42
  //Special tokens
43
  const int END_OF_TEXT = 50256;
44
 
 
 
 
45
  //Store the vocabulary
46
  string[] tokens;
47
 
 
58
 
59
 
60
  //stop after this many tokens
61
+ const int stopAfter = 100;
62
 
63
  int totalTokens = 0;
64
 
 
67
 
68
  void Start()
69
  {
 
 
 
70
  SetupWhiteSpaceShifts();
71
 
72
  LoadVocabulary();
 
 
 
73
 
74
+ var model1 = ModelLoader.Load(Path.Join(Application.streamingAssetsPath , "phi15.sentis"));
75
+
76
+ int outputIndex = model1.outputs.Count - 1;
77
+ //var model1 = ModelLoader.Load(asset);
78
+ //Create a new model to select the random token:
79
+ var model2 = FF.Compile(
80
+ (input, currentToken) =>
81
+ {
82
+ var row = FF.Select(model1.Forward(input)[outputIndex], 1, currentToken);
83
+ return FF.Multinomial(predictability * row, 1);
84
+ },
85
+ (model1.inputs[0], InputDef.Int(new TensorShape()))
86
+ );
87
+
88
+ engine = WorkerFactory.CreateWorker(backend, model2);
89
 
 
 
 
90
  DecodePrompt(outputString);
91
+
92
  runInference = true;
93
  }
94
 
 
104
  void RunInference()
105
  {
106
  using var tokensSoFar = new TensorInt(new TensorShape(1, maxTokens), outputTokens);
107
+ using var index = new TensorInt(currentToken);
108
 
109
+ engine.Execute(new Dictionary<string, Tensor> { {"input_0", tokensSoFar }, { "input_1", index }});
110
 
111
+ var probs = engine.PeekOutput() as TensorInt;
112
+ //Debug.Log(probs.shape);
113
 
114
+ probs.CompleteOperationsAndDownload();
 
 
 
115
 
116
+ int ID = probs[0];
117
 
118
+ //shift window down if got to the end
119
  if (currentToken >= maxTokens - 1)
120
  {
121
  for (int i = 0; i < maxTokens - 1; i++) outputTokens[i] = outputTokens[i + 1];
 
137
  else outputString += GetUnicodeText(tokens[ID]);
138
 
139
  Debug.Log(outputString);
140
+
141
  }
142
 
143
  void DecodePrompt(string text)
 
151
  currentToken = inputTokens.Count - 1;
152
  }
153
 
 
154
  void LoadVocabulary()
155
  {
156
+ var jsonText = File.ReadAllText(Path.Join(Application.streamingAssetsPath , "vocab.json"));
157
  vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);
158
  tokens = new string[vocab.Count];
159
  foreach (var item in vocab)
 
161
  tokens[item.Value] = item.Key;
162
  }
163
 
164
+ merges = File.ReadAllLines(Path.Join(Application.streamingAssetsPath , "merges.txt"));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  }
166
 
167
  // Translates encoded special characters to Unicode
 
202
  for (int i = 0, n = 0; i < 256; i++)
203
  {
204
  encodedCharacters[i] = i;
205
+ if (IsWhiteSpace(i))
206
  {
207
  encodedCharacters[i] = n + 256;
208
  whiteSpaceCharacters[n++] = i;
 
210
  }
211
  }
212
 
213
+ bool IsWhiteSpace(int i)
214
  {
215
+ //returns true if it is a whitespace character
216
+ return i <= 32 || (i >= 127 && i <= 160) || i == 173;
217
  }
218
 
219
  List<int> GetTokens(string text)
 
264
  private void OnDestroy()
265
  {
266
  engine?.Dispose();
 
 
267
  }
268
+
269
  }