daniild71r commited on
Commit
b9f69ed
·
1 Parent(s): f37dabc

app created

Browse files
app.py CHANGED
@@ -1,4 +1,92 @@
 
1
  import streamlit as st
 
 
2
 
3
- st.markdown('### You are gay, man.')
4
- st.markdown('<img src=\'https://sun9-76.userapi.com/impg/Wsd9lR42hY-8Hl_u5sAuCuxAJZ_OoXZMz8XbGA/jGaZN7X3UrU.jpg?size=1368x1080&quality=96&sign=c64d55f33cb98694f6514df99f81c732&type=album\' width=\'30%\'>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
  import streamlit as st
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
4
+ from tokenizers import Tokenizer
5
 
6
+
7
+ def fake_hash(x):
8
+ return 0
9
+
10
+
11
+ @st.cache(hash_funcs={Tokenizer: fake_hash}, suppress_st_warning=True, allow_output_mutation=True)
12
+ def initialize():
13
+ model_name = 'distilbert-base-cased'
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ model = AutoModelForSequenceClassification.from_pretrained('./final_model')
16
+
17
+ the_pipeline = TextClassificationPipeline(
18
+ model=model,
19
+ tokenizer=tokenizer,
20
+ return_all_scores=True,
21
+ device=-1
22
+ )
23
+
24
+ cat_mapping_file = open('cat_mapping.json', 'r')
25
+ cat_name_mapping_file = open('cat_name_mapping.json', 'r')
26
+ cat_mapping = json.load(cat_mapping_file)
27
+ cat_name_mapping = json.load(cat_name_mapping_file)
28
+
29
+ return the_pipeline, cat_mapping, cat_name_mapping
30
+
31
+
32
+ def get_top(the_pipeline, cat_mapping, title, summary, thresh=0.95):
33
+ if title == '' or summary == '':
34
+ return 'Not enough data to compute.'
35
+
36
+ question = title + ' || ' + summary
37
+ if len(question) > 4000:
38
+ return 'Your input is supsiciously long, try something shorter.'
39
+
40
+ try:
41
+ result = the_pipeline(question)[0]
42
+ result.sort(key=lambda x: -x['score'])
43
+
44
+ current_sum = 0
45
+ scores = []
46
+
47
+ for score in result:
48
+ scores.append(score)
49
+ current_sum += score['score']
50
+ if current_sum >= thresh:
51
+ break
52
+
53
+ for i in range(len(result)):
54
+ result[i]['label'] = cat_mapping[result[i]['label'][6:]]
55
+
56
+ return scores
57
+
58
+ except BaseException:
59
+ return 'Something unexpected happened, I\'m sorry. Try again.'
60
+
61
+
62
+ st.markdown('## Welcome to the CS article classification page!')
63
+ st.markdown('### What\'s below is pretty much self-explanatory.')
64
+
65
+ img_source = 'https://sun9-55.userapi.com/impg/azBQ_VTvbgEVonbL9hhFEpwyKAhjAtpVl4H2GQ/I4Vq0H6c3UM.jpg'
66
+ img_params = 'size=1200x900&quality=96&sign=f42419d9cdbf6fe55016fb002e4e85ae&type=album'
67
+ st.markdown(
68
+ f'<img src="{img_source}?{img_params}" width="70%"><br>',
69
+ unsafe_allow_html=True
70
+ )
71
+
72
+ title = st.text_input(
73
+ 'Please, insert the title of the CS article you are interested in.',
74
+ placeholder='The title (e. g. Incorporating alien technologies in CV)'
75
+ )
76
+
77
+ summary = st.text_area(
78
+ 'Now, please, insert the summary of the CS article you are interested in.',
79
+ height=250, placeholder='The summary itself.'
80
+ )
81
+
82
+ the_pipeline, cat_mapping, cat_name_mapping = initialize()
83
+ scores = get_top(the_pipeline, cat_mapping, title, summary)
84
+
85
+ if isinstance(scores, str):
86
+ st.markdown(scores)
87
+ else:
88
+ for score in scores:
89
+ percent = round(score['score'] * 100, 2)
90
+ category_short = score['label']
91
+ category_full = cat_name_mapping[category_short]
92
+ st.markdown(f'I\'m {percent}\% certain that the article is from the {category_short} category, which is "{category_full}"')
cat_mapping.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "cs.AI",
3
+ "1": "cs.AR",
4
+ "2": "cs.CC",
5
+ "3": "cs.CE",
6
+ "4": "cs.CG",
7
+ "5": "cs.CL",
8
+ "6": "cs.CR",
9
+ "7": "cs.CV",
10
+ "8": "cs.CY",
11
+ "9": "cs.DB",
12
+ "10": "cs.DC",
13
+ "11": "cs.DL",
14
+ "12": "cs.DM",
15
+ "13": "cs.DS",
16
+ "14": "cs.ET",
17
+ "15": "cs.FL",
18
+ "16": "cs.GL",
19
+ "17": "cs.GR",
20
+ "18": "cs.GT",
21
+ "19": "cs.HC",
22
+ "20": "cs.IR",
23
+ "21": "cs.IT",
24
+ "22": "cs.LG",
25
+ "23": "cs.LO",
26
+ "24": "cs.MA",
27
+ "25": "cs.MM",
28
+ "26": "cs.MS",
29
+ "27": "cs.NA",
30
+ "28": "cs.NE",
31
+ "29": "cs.NI",
32
+ "30": "cs.OH",
33
+ "31": "cs.OS",
34
+ "32": "cs.PF",
35
+ "33": "cs.PL",
36
+ "34": "cs.RO",
37
+ "35": "cs.SC",
38
+ "36": "cs.SD",
39
+ "37": "cs.SE",
40
+ "38": "cs.SI",
41
+ "39": "cs.SY"
42
+ }
cat_name_mapping.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cs.AI": "Artificial Intelligence",
3
+ "cs.AR": "Hardware Architecture",
4
+ "cs.CC": "Computational Complexity",
5
+ "cs.CE": "Computational Engineering, Finance, and Science",
6
+ "cs.CG": "Computational Geometry",
7
+ "cs.CL": "Computation and Language",
8
+ "cs.CR": "Cryptography and Security",
9
+ "cs.CV": "Computer Vision and Pattern Recognition",
10
+ "cs.CY": "Computers and Society",
11
+ "cs.DB": "Databases",
12
+ "cs.DC": "Distributed, Parallel, and Cluster Computing",
13
+ "cs.DL": "Digital Libraries",
14
+ "cs.DM": "Discrete Mathematics",
15
+ "cs.DS": "Data Structures and Algorithms",
16
+ "cs.ET": "Emerging Technologies",
17
+ "cs.FL": "Formal Languages and Automata Theory",
18
+ "cs.GL": "General Literature",
19
+ "cs.GR": "Graphics",
20
+ "cs.GT": "Computer Science and Game Theory",
21
+ "cs.HC": "Human-Computer Interaction",
22
+ "cs.IR": "Information Retrieval",
23
+ "cs.IT": "Information Theory",
24
+ "cs.LG": "Machine Learning",
25
+ "cs.LO": "Logic in Computer Science",
26
+ "cs.MA": "Multiagent Systems",
27
+ "cs.MM": "Multimedia",
28
+ "cs.MS": "Mathematical Software",
29
+ "cs.NA": "Numerical Analysis",
30
+ "cs.NE": "Neural and Evolutionary Computing",
31
+ "cs.NI": "Networking and Internet Architecture",
32
+ "cs.OH": "Other Computer Science",
33
+ "cs.OS": "Operating Systems",
34
+ "cs.PF": "Performance",
35
+ "cs.PL": "Programming Languages",
36
+ "cs.RO": "Robotics",
37
+ "cs.SC": "Symbolic Computation",
38
+ "cs.SD": "Sound",
39
+ "cs.SE": "Software Engineering",
40
+ "cs.SI": "Social and Information Networks",
41
+ "cs.SY": "Systems and Control"
42
+ }
final_model/config.json ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "gelu",
3
+ "architectures": [
4
+ "DistilBertForSequenceClassification"
5
+ ],
6
+ "attention_dropout": 0.1,
7
+ "dim": 768,
8
+ "dropout": 0.1,
9
+ "hidden_dim": 3072,
10
+ "id2label": {
11
+ "0": "LABEL_0",
12
+ "1": "LABEL_1",
13
+ "2": "LABEL_2",
14
+ "3": "LABEL_3",
15
+ "4": "LABEL_4",
16
+ "5": "LABEL_5",
17
+ "6": "LABEL_6",
18
+ "7": "LABEL_7",
19
+ "8": "LABEL_8",
20
+ "9": "LABEL_9",
21
+ "10": "LABEL_10",
22
+ "11": "LABEL_11",
23
+ "12": "LABEL_12",
24
+ "13": "LABEL_13",
25
+ "14": "LABEL_14",
26
+ "15": "LABEL_15",
27
+ "16": "LABEL_16",
28
+ "17": "LABEL_17",
29
+ "18": "LABEL_18",
30
+ "19": "LABEL_19",
31
+ "20": "LABEL_20",
32
+ "21": "LABEL_21",
33
+ "22": "LABEL_22",
34
+ "23": "LABEL_23",
35
+ "24": "LABEL_24",
36
+ "25": "LABEL_25",
37
+ "26": "LABEL_26",
38
+ "27": "LABEL_27",
39
+ "28": "LABEL_28",
40
+ "29": "LABEL_29",
41
+ "30": "LABEL_30",
42
+ "31": "LABEL_31",
43
+ "32": "LABEL_32",
44
+ "33": "LABEL_33",
45
+ "34": "LABEL_34",
46
+ "35": "LABEL_35",
47
+ "36": "LABEL_36",
48
+ "37": "LABEL_37",
49
+ "38": "LABEL_38",
50
+ "39": "LABEL_39"
51
+ },
52
+ "initializer_range": 0.02,
53
+ "label2id": {
54
+ "LABEL_0": 0,
55
+ "LABEL_1": 1,
56
+ "LABEL_10": 10,
57
+ "LABEL_11": 11,
58
+ "LABEL_12": 12,
59
+ "LABEL_13": 13,
60
+ "LABEL_14": 14,
61
+ "LABEL_15": 15,
62
+ "LABEL_16": 16,
63
+ "LABEL_17": 17,
64
+ "LABEL_18": 18,
65
+ "LABEL_19": 19,
66
+ "LABEL_2": 2,
67
+ "LABEL_20": 20,
68
+ "LABEL_21": 21,
69
+ "LABEL_22": 22,
70
+ "LABEL_23": 23,
71
+ "LABEL_24": 24,
72
+ "LABEL_25": 25,
73
+ "LABEL_26": 26,
74
+ "LABEL_27": 27,
75
+ "LABEL_28": 28,
76
+ "LABEL_29": 29,
77
+ "LABEL_3": 3,
78
+ "LABEL_30": 30,
79
+ "LABEL_31": 31,
80
+ "LABEL_32": 32,
81
+ "LABEL_33": 33,
82
+ "LABEL_34": 34,
83
+ "LABEL_35": 35,
84
+ "LABEL_36": 36,
85
+ "LABEL_37": 37,
86
+ "LABEL_38": 38,
87
+ "LABEL_39": 39,
88
+ "LABEL_4": 4,
89
+ "LABEL_5": 5,
90
+ "LABEL_6": 6,
91
+ "LABEL_7": 7,
92
+ "LABEL_8": 8,
93
+ "LABEL_9": 9
94
+ },
95
+ "max_position_embeddings": 512,
96
+ "model_type": "distilbert",
97
+ "n_heads": 12,
98
+ "n_layers": 6,
99
+ "output_past": true,
100
+ "pad_token_id": 0,
101
+ "problem_type": "single_label_classification",
102
+ "qa_dropout": 0.1,
103
+ "seq_classif_dropout": 0.2,
104
+ "sinusoidal_pos_embds": false,
105
+ "tie_weights_": true,
106
+ "torch_dtype": "float32",
107
+ "transformers_version": "4.14.0",
108
+ "vocab_size": 28996
109
+ }
final_model/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba0c18f88b4a29acdd7ff9db7f997edd994d454382b0eda2c134b2b5a6022cff
3
+ size 263289073
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ transformers