# Scraped from a web file viewer: 458-byte file at commit 35b487a.
import functools

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Load and cache the tokenizer/model pair for *model_name*.

    Loading weights means disk or network I/O plus deserialization, so it is
    done once per checkpoint and reused across ``predict`` calls. The cache is
    bounded so that switching between many checkpoints cannot keep every
    model resident in memory.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # Make inference mode explicit (disables dropout); harmless if the model
    # is already in eval mode.
    model.eval()
    return tokenizer, model


def predict(text, model_name="username/model_name"):
    """Classify *text* and return the predicted class id(s).

    Args:
        text: A single string, or a list of strings classified as one batch.
        model_name: Hugging Face Hub id or local path of the checkpoint.
            Defaults to the original hard-coded placeholder.

    Returns:
        An ``int`` class id when *text* is a single string, otherwise a
        ``list[int]`` with one id per input, in input order.
    """
    tokenizer, model = _load_model(model_name)
    # padding: lets strings of different lengths share one batch tensor.
    # truncation: clips over-long inputs to the model's max length instead of
    # failing inside the forward pass.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Sequence-classification logits are (batch, num_labels); take argmax over
    # the label axis only. A bare argmax() flattens the tensor and returns a
    # wrong index whenever the batch has more than one row.
    predicted = logits.argmax(dim=-1).tolist()
    return predicted[0] if isinstance(text, str) else predicted