KevSun commited on
Commit
809dc9a
·
verified ·
1 Parent(s): 17e5236

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -2
README.md CHANGED
@@ -92,7 +92,8 @@ def predict_personality(model, encoded_input):
92
  model.eval() # Set the model to evaluation mode
93
  with torch.no_grad():
94
  outputs = model(**encoded_input)
95
- return outputs.logits.squeeze()
 
96
 
97
  def print_predictions(predictions, trait_names):
98
  for trait, score in zip(trait_names, predictions):
@@ -117,10 +118,11 @@ def main():
117
  predictions = predict_personality(model, encoded_input)
118
 
119
  trait_names = ["Agreeableness", "Openness", "Conscientiousness", "Extraversion", "Neuroticism"]
120
- print_predictions(predictions.numpy(), trait_names)
121
 
122
  if __name__ == "__main__":
123
  main()
 
124
  ```
125
  ```bash
126
  python script_name.py --input "Your text here"
 
92
  model.eval() # Set the model to evaluation mode
93
  with torch.no_grad():
94
  outputs = model(**encoded_input)
95
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
96
+ return predictions[0].tolist()
97
 
98
  def print_predictions(predictions, trait_names):
99
  for trait, score in zip(trait_names, predictions):
 
118
  predictions = predict_personality(model, encoded_input)
119
 
120
  trait_names = ["Agreeableness", "Openness", "Conscientiousness", "Extraversion", "Neuroticism"]
121
+ print_predictions(predictions, trait_names)
122
 
123
  if __name__ == "__main__":
124
  main()
125
+
126
  ```
127
  ```bash
128
  python script_name.py --input "Your text here"