sanshizhang commited on
Commit
694248a
·
verified ·
1 Parent(s): d18b1c0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +80 -0
README.md CHANGED
@@ -1,3 +1,83 @@
1
  ---
2
  license: openrail
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: openrail
3
  ---
4
+ 验证集准确度: 0.9382193411826961
5
+ 验证集分类报告:
6
+ precision recall f1-score support
7
+
8
+ negative 0.93 0.95 0.94 3785
9
+ positive 0.95 0.96 0.95 6919
10
+ neutral 0.93 0.89 0.91 4414
11
+
12
+ accuracy 0.94 15118
13
+ macro avg 0.94 0.93 0.93 15118
14
+ weighted avg 0.94 0.94 0.94 15118
15
+
16
+ 大概使用了10w+的数据做了一个基金方面的中文情感分析模型,暂时测试下来还可以,负面方面的文本是有专人处理过的,中性的可能不准确。
17
+ 0: 'negative', 1: 'positive', 2: 'neutral'
18
+
19
+
20
+ 测试代码如下:
21
+ import sys
22
+ import re
23
+ import torch
24
+ from transformers import BertTokenizer, BertForSequenceClassification
25
+ from torch.nn.functional import softmax
26
+
27
+ # 设定使用CPU或CUDA
28
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
29
+
30
+ # 载入预先保存的模型和分词器
31
+ model = BertForSequenceClassification.from_pretrained('sanshizhang/Chinese-Sentiment-Analysis-Fund-Direction')
32
+ tokenizer = BertTokenizer.from_pretrained('sanshizhang/Chinese-Sentiment-Analysis-Fund-Direction')
33
+
34
+ # 确保模型在正确的设备上
35
+ model = model.to(device)
36
+ model.eval() # 把模型设置为评估模式
37
+
38
+ # 函数定义:进行预测并返回预测概率
39
+ def predict_sentiment(text):
40
+ # 编码文本数据
41
+ encoding = tokenizer.encode_plus(
42
+ text,
43
+ max_length=512,
44
+ add_special_tokens=True,
45
+ return_token_type_ids=False,
46
+ padding='max_length', # 修改此处
47
+ return_attention_mask=True,
48
+ return_tensors='pt',
49
+ truncation=True
50
+ )
51
+
52
+ # ... 其他代码不变
53
+
54
+ # 取出输入对应的编码
55
+ input_ids = encoding['input_ids'].to(device)
56
+ attention_mask = encoding['attention_mask'].to(device)
57
+
58
+ # 不计算梯度
59
+ with torch.no_grad():
60
+ # 产生情感预测的logits
61
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
62
+
63
+ # 使用softmax将logits转换为概率
64
+ probs = softmax(outputs.logits, dim=1)
65
+
66
+ # 返回概率和预测的类别
67
+ return probs, torch.argmax(probs, dim=1).cpu().numpy()[0]
68
+
69
+ # 从命令行参数获取文本,合并并清理特殊字符
70
+ arguments = sys.argv[1:] # 忽略脚本名称
71
+ text = ' '.join(arguments) # 合并为单一字符串
72
+ text = re.sub(r"[^\u4e00-\u9fff\d.a-zA-Z%+\-。!?,、;:()【】《》“”‘’]", '', text) # 去除特殊字符
73
+
74
+ # print(f"传过来的文本是: {text}")
75
+ # 进行预测
76
+ probabilities, prediction = predict_sentiment(text)
77
+
78
+ sentiment_labels = {0: 'negative', 1: 'positive', 2: 'neutral'}
79
+
80
+ # 打印出预测的情感及其概率
81
+ predicted_sentiment = sentiment_labels[prediction]
82
+ print(f"Predicted sentiment: {predicted_sentiment},Probability:{probabilities[0][prediction].item()}")
83
+ # print(f"Probability: {probabilities[0][prediction].item()}")