Update README.md
---
license: openrail
---

Validation accuracy: 0.9382193411826961

Validation classification report:

```
              precision    recall  f1-score   support

    negative       0.93      0.95      0.94      3785
    positive       0.95      0.96      0.95      6919
     neutral       0.93      0.89      0.91      4414

    accuracy                           0.94     15118
   macro avg       0.94      0.93      0.93     15118
weighted avg       0.94      0.94      0.94     15118
```
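
For reference, the report above follows the format of scikit-learn's `classification_report`; here is a minimal sketch of how such numbers are typically computed on a held-out set (the `val_labels` / `val_preds` values are placeholders, not the actual validation data):

```python
# Minimal sketch, assuming scikit-learn; val_labels / val_preds are
# placeholder class ids standing in for the real validation split.
from sklearn.metrics import accuracy_score, classification_report

label_names = ['negative', 'positive', 'neutral']  # id order 0 / 1 / 2

val_labels = [0, 1, 2, 1, 0]  # placeholder ground-truth labels
val_preds  = [0, 1, 1, 1, 0]  # placeholder model predictions

print('Validation accuracy:', accuracy_score(val_labels, val_preds))
print(classification_report(val_labels, val_preds, target_names=label_names))
```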

This is a Chinese sentiment analysis model for the fund domain, trained on roughly 100k+ samples. Initial testing looks reasonable: the negative-class texts were curated by dedicated annotators, while the neutral class may be less accurate.

Label mapping: 0: 'negative', 1: 'positive', 2: 'neutral'
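
For quick experiments, the checkpoint should also load through the generic `transformers` pipeline API (a minimal sketch; this wrapper is an assumption, not part of the original test script):

```python
# Minimal sketch, assuming the standard transformers text-classification
# pipeline; the label string it returns comes from the model config and may
# be a generic LABEL_0 / LABEL_1 / LABEL_2 matching the mapping above.
from transformers import pipeline

classifier = pipeline(
    'text-classification',
    model='sanshizhang/Chinese-Sentiment-Analysis-Fund-Direction',
)

print(classifier('这只基金今年的收益很不错'))
```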

The test code is as follows:

```python
import sys
import re
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.nn.functional import softmax

# Use CUDA if available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pretrained model and tokenizer
model = BertForSequenceClassification.from_pretrained('sanshizhang/Chinese-Sentiment-Analysis-Fund-Direction')
tokenizer = BertTokenizer.from_pretrained('sanshizhang/Chinese-Sentiment-Analysis-Fund-Direction')

# Make sure the model is on the right device
model = model.to(device)
model.eval()  # put the model in evaluation mode

# Run a prediction and return the class probabilities
def predict_sentiment(text):
    # Encode the input text
    encoding = tokenizer.encode_plus(
        text,
        max_length=512,
        add_special_tokens=True,
        return_token_type_ids=False,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
        truncation=True
    )

    # Move the encoded inputs to the device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # No gradient computation needed for inference
    with torch.no_grad():
        # Produce the sentiment logits
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Convert the logits to probabilities with softmax
    probs = softmax(outputs.logits, dim=1)

    # Return the probabilities and the predicted class id
    return probs, torch.argmax(probs, dim=1).cpu().numpy()[0]

# Read the text from the command-line arguments, join it, and strip special characters
arguments = sys.argv[1:]  # skip the script name
text = ' '.join(arguments)  # join into a single string
text = re.sub(r"[^\u4e00-\u9fff\d.a-zA-Z%+\-。!?,、;:()【】《》“”‘’]", '', text)  # drop special characters

# print(f"Input text: {text}")
# Run the prediction
probabilities, prediction = predict_sentiment(text)

sentiment_labels = {0: 'negative', 1: 'positive', 2: 'neutral'}

# Print the predicted sentiment and its probability
predicted_sentiment = sentiment_labels[prediction]
print(f"Predicted sentiment: {predicted_sentiment}, Probability: {probabilities[0][prediction].item()}")
```
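
Saved as, say, `predict.py` (the filename is illustrative), the script takes the text to classify as command-line arguments, e.g. `python predict.py 这只基金最近表现不错`. For scoring many texts at once, a batched variant of the same inference logic could look like this sketch (it reuses `model`, `tokenizer`, `device`, `softmax`, and `sentiment_labels` from the script above and is not part of the original):

```python
# Hypothetical batched-inference sketch; assumes model, tokenizer, device,
# softmax, torch and sentiment_labels are defined as in the script above.
def predict_sentiment_batch(texts):
    # Tokenize the whole batch, padding to the longest text in it
    encoding = tokenizer(
        texts,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    probs = softmax(logits, dim=1)
    return probs, torch.argmax(probs, dim=1).cpu().tolist()

batch_probs, batch_preds = predict_sentiment_batch(['基金净值大涨', '基金净值大跌'])
print([sentiment_labels[p] for p in batch_preds])
```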