Anon4445 commited on
Commit
69fc6bf
·
1 Parent(s): 9dbbf32

create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import BertTokenizer, BertForSequenceClassification
4
+
5
+ def load_model():
6
+ tokenizer = BertTokenizer.from_pretrained("BERT_GED")
7
+ model = BertForSequenceClassification.from_pretrained("BERT_GED")
8
+ return model, tokenizer
9
+
10
+ def predict(model, tokenizer, sentence):
11
+ # Tokenize sentence
12
+ encoded_dict = tokenizer.encode_plus(
13
+ sentence,
14
+ add_special_tokens=True,
15
+ max_length=64,
16
+ padding="max_length",
17
+ truncation=True,
18
+ return_attention_mask=True,
19
+ return_tensors='pt',
20
+ )
21
+ input_ids = encoded_dict['input_ids']
22
+ attention_mask = encoded_dict['attention_mask']
23
+
24
+ # Model inference
25
+ with torch.no_grad():
26
+ outputs = model(input_ids, attention_mask=attention_mask)
27
+
28
+ logits = outputs.logits
29
+ index = torch.argmax(logits, -1).item() # Get the predicted class (0 or 1)
30
+
31
+ if index == 1:
32
+ return "perfect"
33
+ else:
34
+ return "not right!!"
35
+
36
+ def main():
37
+ st.title("Grammatical Correctness Predictor")
38
+ sentence = st.text_area("Sentence to analyze:")
39
+
40
+ if st.button("Analyze"):
41
+ if sentence:
42
+ model, tokenizer = load_model()
43
+ prediction = predict(model, tokenizer, sentence)
44
+ st.write(f'"{sentence}" is grammatically {prediction}')
45
+ else:
46
+ st.warning("Please enter a sentence.")
47
+
48
+ if __name__ == "__main__":
49
+ main()