Commit 0dc9888
Duplicate from linhj07/chatgpt-on-wechat
Co-authored-by: Huijie Lin <[email protected]>

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +34 -0
- .github/ISSUE_TEMPLATE.md +31 -0
- .github/workflows/deploy-image.yml +59 -0
- .gitignore +14 -0
- Dockerfile +3 -0
- LICENSE +19 -0
- README.md +13 -0
- app.py +82 -0
- bot/baidu/baidu_unit_bot.py +28 -0
- bot/bot.py +17 -0
- bot/bot_factory.py +32 -0
- bot/chatgpt/chat_gpt_bot.py +156 -0
- bot/chatgpt/chat_gpt_session.py +79 -0
- bot/openai/open_ai_bot.py +109 -0
- bot/openai/open_ai_image.py +38 -0
- bot/openai/open_ai_session.py +67 -0
- bot/session_manager.py +85 -0
- bridge/bridge.py +50 -0
- bridge/context.py +57 -0
- bridge/reply.py +22 -0
- channel/channel.py +41 -0
- channel/channel_factory.py +23 -0
- channel/chat_channel.py +316 -0
- channel/chat_message.py +83 -0
- channel/terminal/terminal_channel.py +31 -0
- channel/wechat/wechat_channel.py +194 -0
- channel/wechat/wechat_message.py +57 -0
- channel/wechat/wechaty_channel.py +125 -0
- channel/wechat/wechaty_message.py +85 -0
- channel/wechatmp/README.md +46 -0
- channel/wechatmp/receive.py +42 -0
- channel/wechatmp/reply.py +52 -0
- channel/wechatmp/wechatmp_channel.py +234 -0
- common/const.py +5 -0
- common/dequeue.py +33 -0
- common/expired_dict.py +42 -0
- common/log.py +20 -0
- common/singleton.py +9 -0
- common/sorted_dict.py +65 -0
- common/time_check.py +38 -0
- common/tmp_dir.py +20 -0
- common/token_bucket.py +45 -0
- config-template.json +18 -0
- config.json +18 -0
- config.py +198 -0
- docker/Dockerfile.alpine +39 -0
- docker/Dockerfile.debian +40 -0
- docker/Dockerfile.debian.latest +33 -0
- docker/Dockerfile.latest +29 -0
- docker/build.alpine.sh +16 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE.md
ADDED
@@ -0,0 +1,31 @@
+### Preliminary checks
+
+1. The network can reach the openai API
+2. Python is installed: version between 3.7 and 3.10
+3. `git pull` has been run to fetch the latest code
+4. `pip3 install -r requirements.txt` has been run and dependencies are satisfied
+5. For extended features, `pip3 install -r requirements-optional.txt` has been run and dependencies are satisfied
+6. No similar problem found in existing issues
+7. No similar problem in the [FAQs](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)
+
+
+### Problem description
+
+> Brief description, screenshots, reproduction steps, etc.; may also be a feature request or idea
+
+
+
+
+### Terminal log (if there is an error)
+
+```
+[Paste the terminal log here; it can be found in the `run.log` file in the main directory]
+```
+
+
+
+### Environment
+
+- Operating system (Mac/Windows/Linux):
+- Python version (run `python3 -V`):
+- pip version (required for dependency issues; run `pip3 -V`):
.github/workflows/deploy-image.yml
ADDED
@@ -0,0 +1,59 @@
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+# GitHub recommends pinning actions to a commit SHA.
+# To get a newer version, you will need to update the SHA.
+# You can also reference a tag or branch, but the action may change without warning.
+
+name: Create and publish a Docker image
+
+on:
+  push:
+    branches: ['master']
+  create:
+env:
+  REGISTRY: ghcr.io
+  IMAGE_NAME: ${{ github.repository }}
+
+jobs:
+  build-and-push-image:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      packages: write
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+
+      - name: Log in to the Container registry
+        uses: docker/login-action@v2
+        with:
+          registry: ${{ env.REGISTRY }}
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta
+        uses: docker/metadata-action@v4
+        with:
+          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+
+      - name: Build and push Docker image
+        uses: docker/build-push-action@v3
+        with:
+          context: .
+          push: true
+          file: ./docker/Dockerfile.latest
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+
+      - uses: actions/delete-package-versions@v4
+        with:
+          package-name: 'chatgpt-on-wechat'
+          package-type: 'container'
+          min-versions-to-keep: 10
+          delete-only-untagged-versions: 'true'
+          token: ${{ secrets.GITHUB_TOKEN }}
.gitignore
ADDED
@@ -0,0 +1,14 @@
+.DS_Store
+.idea
+.wechaty/
+__pycache__/
+venv*
+*.pyc
+config.json
+QR.png
+nohup.out
+tmp
+plugins.json
+itchat.pkl
+*.log
+user_datas.pkl
Dockerfile
ADDED
@@ -0,0 +1,3 @@
+FROM ghcr.io/zhayujie/chatgpt-on-wechat:latest
+
+ENTRYPOINT ["/entrypoint.sh"]
LICENSE
ADDED
@@ -0,0 +1,19 @@
+Copyright (c) 2022 zhayujie
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: wechat-bot
+emoji: 👀
+colorFrom: red
+colorTo: yellow
+sdk: gradio
+sdk_version: 3.19.1
+app_file: app.py
+pinned: false
+duplicated_from: linhj07/chatgpt-on-wechat
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,82 @@
+# encoding:utf-8
+
+import os
+from config import conf, load_config
+from channel import channel_factory
+from common.log import logger
+from plugins import *
+import signal
+import sys
+import config
+import gradio as gr
+from io import BytesIO
+from PIL import Image
+from concurrent.futures import ThreadPoolExecutor
+thread_pool = ThreadPoolExecutor(max_workers=8)
+
+def getImage(bytes):
+    bytes_stream = BytesIO(bytes)
+    image = Image.open(bytes_stream)
+    return image
+
+def getLoginUrl():
+    # load config
+    config.load_config()
+    # create channel
+    bot = channel_factory.create_channel("wx")
+    thread_pool.submit(bot.startup)
+    while (True):
+        if bot.getQrCode():
+            return getImage(bot.getQrCode())
+
+def sigterm_handler_wrap(_signo):
+    old_handler = signal.getsignal(_signo)
+    def func(_signo, _stack_frame):
+        logger.info("signal {} received, exiting...".format(_signo))
+        conf().save_user_datas()
+        return old_handler(_signo, _stack_frame)
+    signal.signal(_signo, func)
+
+def run():
+    try:
+        # load config
+        load_config()
+        # ctrl + c
+        sigterm_handler_wrap(signal.SIGINT)
+        # kill signal
+        sigterm_handler_wrap(signal.SIGTERM)
+
+        # create channel
+        channel_name = conf().get('channel_type', 'wx')
+        if channel_name == 'wxy':
+            os.environ['WECHATY_LOG'] = "warn"
+            # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
+
+        channel = channel_factory.create_channel(channel_name)
+        if channel_name in ['wx', 'wxy', 'wechatmp']:
+            PluginManager().load_plugins()
+
+        # startup channel
+        channel.startup()
+    except Exception as e:
+        logger.error("App startup failed!")
+        logger.exception(e)
+
+if __name__ == '__main__':
+    # run()
+    try:
+
+        with gr.Blocks() as demo:
+            with gr.Row():
+                with gr.Column():
+                    btn = gr.Button(value="生成二维码")  # button label: "Generate QR code"
+                with gr.Column():
+                    outputs = [gr.Pil()]
+            btn.click(getLoginUrl, outputs=outputs)
+
+        demo.launch()
+
+
+    except Exception as e:
+        logger.error("App startup failed!")
+        logger.exception(e)
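The `sigterm_handler_wrap` function above chains a cleanup step in front of whatever handler was previously installed. A minimal, self-contained sketch of that chaining pattern (the `SIGUSR1` signal and callback names here are illustrative, not from the repo; POSIX-only):

```python
import signal

calls = []

def wrap_with_cleanup(signo, cleanup):
    """Install a handler that runs `cleanup` first, then chains to the old handler."""
    old_handler = signal.getsignal(signo)
    def handler(s, frame):
        cleanup()                      # e.g. conf().save_user_datas() in app.py
        if callable(old_handler):      # preserve the previous behaviour, as app.py does
            return old_handler(s, frame)
    signal.signal(signo, handler)
    return handler

# Install a base handler, then wrap it; invoking the wrapper runs both in order.
signal.signal(signal.SIGUSR1, lambda s, f: calls.append("old"))
handler = wrap_with_cleanup(signal.SIGUSR1, lambda: calls.append("cleanup"))
handler(signal.SIGUSR1, None)
print(calls)
```

Returning the old handler's result means default dispositions (like process exit on SIGTERM) still take effect after the user data is saved.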
bot/baidu/baidu_unit_bot.py
ADDED
@@ -0,0 +1,28 @@
+# encoding:utf-8
+
+import requests
+from bot.bot import Bot
+from bridge.reply import Reply, ReplyType
+
+
+# Baidu UNIT dialogue API (works, but relatively weak)
+class BaiduUnitBot(Bot):
+    def reply(self, query, context=None):
+        token = self.get_token()
+        url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token
+        post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}"
+        print(post_data)
+        headers = {'content-type': 'application/x-www-form-urlencoded'}
+        response = requests.post(url, data=post_data.encode(), headers=headers)
+        if response:
+            reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
+            return reply
+
+    def get_token(self):
+        access_key = 'YOUR_ACCESS_KEY'
+        secret_key = 'YOUR_SECRET_KEY'
+        host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key
+        response = requests.get(host)
+        if response:
+            print(response.json())
+            return response.json()['access_token']
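Building `post_data` by string concatenation, as `reply` does above, produces invalid JSON whenever `query` contains quotes or newlines. An equivalent, safer construction (same field names and the same hard-coded IDs as the source; only the helper name is new) uses `json.dumps`, which escapes the query automatically:

```python
import json

def build_unit_payload(query):
    # Same structure as the hand-built string in BaiduUnitBot.reply, but
    # json.dumps handles escaping of quotes/newlines inside `query`.
    return json.dumps({
        "version": "3.0",
        "service_id": "S73177",
        "session_id": "",
        "log_id": "7758521",
        "skill_ids": ["1221886"],
        "request": {
            "terminal_id": "88888",
            "query": query,
            "hyper_params": {"chat_custom_bot_profile": 1},
        },
    }, ensure_ascii=False)

payload = build_unit_payload('he said "hi"')
print(json.loads(payload)["request"]["query"])  # round-trips intact
```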
bot/bot.py
ADDED
@@ -0,0 +1,17 @@
+"""
+Auto-reply chat robot abstract class
+"""
+
+
+from bridge.context import Context
+from bridge.reply import Reply
+
+
+class Bot(object):
+    def reply(self, query, context: Context = None) -> Reply:
+        """
+        bot auto-reply content
+        :param query: received message
+        :return: reply content
+        """
+        raise NotImplementedError
bot/bot_factory.py
ADDED
@@ -0,0 +1,32 @@
+"""
+bot factory
+"""
+from common import const
+
+
+def create_bot(bot_type):
+    """
+    create a bot_type instance
+    :param bot_type: bot type code
+    :return: bot instance
+    """
+    if bot_type == const.BAIDU:
+        # Baidu UNIT dialogue API
+        from bot.baidu.baidu_unit_bot import BaiduUnitBot
+        return BaiduUnitBot()
+
+    elif bot_type == const.CHATGPT:
+        # ChatGPT web API
+        from bot.chatgpt.chat_gpt_bot import ChatGPTBot
+        return ChatGPTBot()
+
+    elif bot_type == const.OPEN_AI:
+        # Official OpenAI chat model API
+        from bot.openai.open_ai_bot import OpenAIBot
+        return OpenAIBot()
+
+    elif bot_type == const.CHATGPTONAZURE:
+        # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
+        from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
+        return AzureChatGPTBot()
+    raise RuntimeError
bot/chatgpt/chat_gpt_bot.py
ADDED
@@ -0,0 +1,156 @@
+# encoding:utf-8
+
+from bot.bot import Bot
+from bot.chatgpt.chat_gpt_session import ChatGPTSession
+from bot.openai.open_ai_image import OpenAIImage
+from bot.session_manager import Session, SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from config import conf, load_config
+from common.log import logger
+from common.token_bucket import TokenBucket
+from common.expired_dict import ExpiredDict
+import openai
+import openai.error
+import time
+
+# OpenAI chat completion API (works)
+class ChatGPTBot(Bot, OpenAIImage):
+    def __init__(self):
+        super().__init__()
+        # set the default api_key
+        openai.api_key = conf().get('open_ai_api_key')
+        if conf().get('open_ai_api_base'):
+            openai.api_base = conf().get('open_ai_api_base')
+        proxy = conf().get('proxy')
+        if proxy:
+            openai.proxy = proxy
+        if conf().get('rate_limit_chatgpt'):
+            self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
+
+        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context.type == ContextType.TEXT:
+            logger.info("[CHATGPT] query={}".format(query))
+
+            session_id = context['session_id']
+            reply = None
+            clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
+            if query in clear_memory_commands:
+                self.sessions.clear_session(session_id)
+                reply = Reply(ReplyType.INFO, '记忆已清除')
+            elif query == '#清除所有':
+                self.sessions.clear_all_session()
+                reply = Reply(ReplyType.INFO, '所有人记忆已清除')
+            elif query == '#更新配置':
+                load_config()
+                reply = Reply(ReplyType.INFO, '配置已更新')
+            if reply:
+                return reply
+            session = self.sessions.session_query(query, session_id)
+            logger.debug("[CHATGPT] session query={}".format(session.messages))
+
+            api_key = context.get('openai_api_key')
+
+            # if context.get('stream'):
+            #     # reply in stream
+            #     return self.reply_text_stream(query, new_query, session_id)
+
+            reply_content = self.reply_text(session, session_id, api_key, 0)
+            logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
+            if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
+                reply = Reply(ReplyType.ERROR, reply_content['content'])
+            elif reply_content["completion_tokens"] > 0:
+                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+                reply = Reply(ReplyType.TEXT, reply_content["content"])
+            else:
+                reply = Reply(ReplyType.ERROR, reply_content['content'])
+                logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
+            return reply
+
+        elif context.type == ContextType.IMAGE_CREATE:
+            ok, retstring = self.create_img(query, 0)
+            reply = None
+            if ok:
+                reply = Reply(ReplyType.IMAGE_URL, retstring)
+            else:
+                reply = Reply(ReplyType.ERROR, retstring)
+            return reply
+        else:
+            reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
+            return reply
+
+    def compose_args(self):
+        return {
+            "model": conf().get("model") or "gpt-3.5-turbo",  # name of the chat model
+            "temperature": conf().get('temperature', 0.9),  # in [0,1]; higher values give less deterministic replies
+            # "max_tokens": 4096,  # maximum length of the reply
+            "top_p": 1,
+            "frequency_penalty": conf().get('frequency_penalty', 0.0),  # in [-2,2]; higher values favor more varied content
+            "presence_penalty": conf().get('presence_penalty', 0.0),  # in [-2,2]; higher values favor more varied content
+            "request_timeout": conf().get('request_timeout', 60),  # request timeout; the openai default is 600, hard questions usually need longer
+            "timeout": conf().get('request_timeout', 120),  # retry window; retries happen automatically within this time
+        }
+
+    def reply_text(self, session: ChatGPTSession, session_id, api_key, retry_count=0) -> dict:
+        '''
+        call openai's ChatCompletion to get the answer
+        :param session: a conversation session
+        :param session_id: session id
+        :param retry_count: retry count
+        :return: {}
+        '''
+        try:
+            if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
+                raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
+            # if api_key == None, the default openai.api_key will be used
+            response = openai.ChatCompletion.create(
+                api_key=api_key, messages=session.messages, **self.compose_args()
+            )
+            # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
+            return {"total_tokens": response["usage"]["total_tokens"],
+                    "completion_tokens": response["usage"]["completion_tokens"],
+                    "content": response.choices[0]['message']['content']}
+        except Exception as e:
+            need_retry = retry_count < 2
+            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+            if isinstance(e, openai.error.RateLimitError):
+                logger.warn("[CHATGPT] RateLimitError: {}".format(e))
+                result['content'] = "提问太快啦,请休息一下再问我吧"
+                if need_retry:
+                    time.sleep(5)
+            elif isinstance(e, openai.error.Timeout):
+                logger.warn("[CHATGPT] Timeout: {}".format(e))
+                result['content'] = "我没有收到你的消息"
+                if need_retry:
+                    time.sleep(5)
+            elif isinstance(e, openai.error.APIConnectionError):
+                logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
+                need_retry = False
+                result['content'] = "我连接不到你的网络"
+            else:
+                logger.warn("[CHATGPT] Exception: {}".format(e))
+                need_retry = False
+                self.sessions.clear_session(session_id)
+
+            if need_retry:
+                logger.warn("[CHATGPT] retry attempt {}".format(retry_count + 1))
+                return self.reply_text(session, session_id, api_key, retry_count + 1)
+            else:
+                return result
+
+
+class AzureChatGPTBot(ChatGPTBot):
+    def __init__(self):
+        super().__init__()
+        openai.api_type = "azure"
+        openai.api_version = "2023-03-15-preview"
+
+    def compose_args(self):
+        args = super().compose_args()
+        args["engine"] = args["model"]
+        del (args["model"])
+        return args
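`ChatGPTBot` gates requests through `self.tb4chatgpt.get_token()`. The `common/token_bucket.py` source is not among the 50 files shown in this view, so the following is only a minimal sketch of a token-bucket limiter consistent with how it is used here (per-minute rate and capacity are assumptions, not the repo's actual implementation):

```python
import threading
import time

class TokenBucketSketch:
    """Minimal token bucket: refills `rate_per_minute` tokens per minute, up to `capacity`."""
    def __init__(self, rate_per_minute, capacity=None):
        self.capacity = capacity or rate_per_minute
        self.tokens = self.capacity                    # start full
        self.fill_interval = 60.0 / rate_per_minute    # seconds per token
        self.last_fill = time.monotonic()
        self.lock = threading.Lock()                   # get_token may be called from worker threads

    def get_token(self):
        """Consume a token and return True if one is available, else False."""
        with self.lock:
            now = time.monotonic()
            refill = int((now - self.last_fill) / self.fill_interval)
            if refill:
                self.tokens = min(self.capacity, self.tokens + refill)
                self.last_fill = now
            if self.tokens > 0:
                self.tokens -= 1
                return True
            return False

tb = TokenBucketSketch(rate_per_minute=2, capacity=2)
results = [tb.get_token(), tb.get_token(), tb.get_token()]
print(results)  # two tokens available at start; the third call is rejected
```

When `get_token()` returns `False`, `reply_text` raises a `RateLimitError` itself, so rate-limit rejections flow through the same retry/error path as genuine API 429s.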
bot/chatgpt/chat_gpt_session.py
ADDED
@@ -0,0 +1,79 @@
+from bot.session_manager import Session
+from common.log import logger
+'''
+    e.g.  [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Who won the world series in 2020?"},
+        {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
+        {"role": "user", "content": "Where was it played?"}
+    ]
+'''
+class ChatGPTSession(Session):
+    def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
+        super().__init__(session_id, system_prompt)
+        self.model = model
+        self.reset()
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
+        precise = True
+        try:
+            cur_tokens = num_tokens_from_messages(self.messages, self.model)
+        except Exception as e:
+            precise = False
+            if cur_tokens is None:
+                raise e
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+        while cur_tokens > max_tokens:
+            if len(self.messages) > 2:
+                self.messages.pop(1)
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+                self.messages.pop(1)
+                if precise:
+                    cur_tokens = num_tokens_from_messages(self.messages, self.model)
+                else:
+                    cur_tokens = cur_tokens - max_tokens
+                break
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+                break
+            else:
+                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
+                break
+            if precise:
+                cur_tokens = num_tokens_from_messages(self.messages, self.model)
+            else:
+                cur_tokens = cur_tokens - max_tokens
+        return cur_tokens
+
+
+# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+def num_tokens_from_messages(messages, model):
+    """Returns the number of tokens used by a list of messages."""
+    import tiktoken
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        logger.debug("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model == "gpt-3.5-turbo":
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
+    elif model == "gpt-4":
+        return num_tokens_from_messages(messages, model="gpt-4-0314")
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif model == "gpt-4-0314":
+        tokens_per_message = 3
+        tokens_per_name = 1
+    else:
+        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value))
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
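The trimming loop in `discard_exceeding` always evicts the oldest non-system message (index 1) until the session fits the token budget. Its effect can be illustrated with a simplified, self-contained version that uses a hypothetical one-token-per-character counter in place of `num_tokens_from_messages`:

```python
def count_tokens(messages):
    # Hypothetical stand-in for num_tokens_from_messages: 1 token per character.
    return sum(len(m["content"]) for m in messages)

def trim_to_budget(messages, max_tokens):
    # Mirror the core of ChatGPTSession.discard_exceeding: keep the system
    # prompt (index 0) and drop the oldest other message until under budget.
    while count_tokens(messages) > max_tokens and len(messages) > 2:
        messages.pop(1)
    return messages

msgs = [
    {"role": "system", "content": "sys"},
    {"role": "user", "content": "aaaa"},
    {"role": "assistant", "content": "bbbb"},
    {"role": "user", "content": "cc"},
]
trim_to_budget(msgs, max_tokens=10)  # 13 tokens -> drops the oldest user turn
print([m["content"] for m in msgs])
```

The real method is more careful: it re-counts precisely with tiktoken when it can, falls back to an estimate when counting raises, and warns instead of dropping when a single user message alone exceeds the budget.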
bot/openai/open_ai_bot.py
ADDED
@@ -0,0 +1,109 @@
+# encoding:utf-8
+
+from bot.bot import Bot
+from bot.openai.open_ai_image import OpenAIImage
+from bot.openai.open_ai_session import OpenAISession
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from config import conf
+from common.log import logger
+import openai
+import openai.error
+import time
+
+user_session = dict()
+
+# OpenAI completion API (works)
+class OpenAIBot(Bot, OpenAIImage):
+    def __init__(self):
+        super().__init__()
+        openai.api_key = conf().get('open_ai_api_key')
+        if conf().get('open_ai_api_base'):
+            openai.api_base = conf().get('open_ai_api_base')
+        proxy = conf().get('proxy')
+        if proxy:
+            openai.proxy = proxy
+
+        self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context and context.type:
+            if context.type == ContextType.TEXT:
+                logger.info("[OPEN_AI] query={}".format(query))
+                session_id = context['session_id']
+                reply = None
+                if query == '#清除记忆':
+                    self.sessions.clear_session(session_id)
+                    reply = Reply(ReplyType.INFO, '记忆已清除')
+                elif query == '#清除所有':
+                    self.sessions.clear_all_session()
+                    reply = Reply(ReplyType.INFO, '所有人记忆已清除')
+                else:
+                    session = self.sessions.session_query(query, session_id)
+                    new_query = str(session)
+                    logger.debug("[OPEN_AI] session query={}".format(new_query))
+
+                    total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0)
+                    logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens))
+
+                    if total_tokens == 0:
+                        reply = Reply(ReplyType.ERROR, reply_content)
+                    else:
+                        self.sessions.session_reply(reply_content, session_id, total_tokens)
+                        reply = Reply(ReplyType.TEXT, reply_content)
+                return reply
+            elif context.type == ContextType.IMAGE_CREATE:
+                ok, retstring = self.create_img(query, 0)
+                reply = None
+                if ok:
+                    reply = Reply(ReplyType.IMAGE_URL, retstring)
+                else:
+                    reply = Reply(ReplyType.ERROR, retstring)
+                return reply
+
+    def reply_text(self, query, session_id, retry_count=0):
+        try:
+            response = openai.Completion.create(
+                model=conf().get("model") or "text-davinci-003",  # name of the completion model
+                prompt=query,
+                temperature=0.9,  # in [0,1]; higher values give less deterministic replies
+                max_tokens=1200,  # maximum length of the reply
+                top_p=1,
+                frequency_penalty=0.0,  # in [-2,2]; higher values favor more varied content
+                presence_penalty=0.0,  # in [-2,2]; higher values favor more varied content
+                stop=["\n\n\n"]
+            )
+            res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
+            total_tokens = response["usage"]["total_tokens"]
+            completion_tokens = response["usage"]["completion_tokens"]
+            logger.info("[OPEN_AI] reply={}".format(res_content))
+            return total_tokens, completion_tokens, res_content
+        except Exception as e:
+            need_retry = retry_count < 2
+            result = [0, 0, "我现在有点累了,等会再来吧"]
+            if isinstance(e, openai.error.RateLimitError):
+                logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
+                result[2] = "提问太快啦,请休息一下再问我吧"
+                if need_retry:
+                    time.sleep(5)
+            elif isinstance(e, openai.error.Timeout):
+                logger.warn("[OPEN_AI] Timeout: {}".format(e))
+                result[2] = "我没有收到你的消息"
+                if need_retry:
+                    time.sleep(5)
+            elif isinstance(e, openai.error.APIConnectionError):
+                logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
+                need_retry = False
+                result[2] = "我连接不到你的网络"
+            else:
+                logger.warn("[OPEN_AI] Exception: {}".format(e))
+                need_retry = False
+                self.sessions.clear_session(session_id)
+
+            if need_retry:
+                logger.warn("[OPEN_AI] retry attempt {}".format(retry_count + 1))
+                return self.reply_text(query, session_id, retry_count + 1)
+            else:
+                return result
bot/openai/open_ai_image.py
ADDED
@@ -0,0 +1,38 @@
import time
import openai
import openai.error
from common.token_bucket import TokenBucket
from common.log import logger
from config import conf

# Wrapper around OpenAI's image generation (DALL·E) API
class OpenAIImage(object):
    def __init__(self):
        openai.api_key = conf().get('open_ai_api_key')
        if conf().get('rate_limit_dalle'):
            self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))

    def create_img(self, query, retry_count=0):
        try:
            if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
                return False, "请求太快了,请休息一下再问我吧"
            logger.info("[OPEN_AI] image_query={}".format(query))
            response = openai.Image.create(
                prompt=query,    # image description
                n=1,             # number of images generated per request
                size="256x256"   # image size; one of 256x256, 512x512, 1024x1024
            )
            image_url = response['data'][0]['url']
            logger.info("[OPEN_AI] image_url={}".format(image_url))
            return True, image_url
        except openai.error.RateLimitError as e:
            logger.warn(e)
            if retry_count < 1:
                time.sleep(5)
                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
                return self.create_img(query, retry_count + 1)
            else:
                return False, "提问太快啦,请休息一下再问我吧"
        except Exception as e:
            logger.exception(e)
            return False, str(e)
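The rate-limit gate at the top of `create_img` relies on the project's `common.token_bucket.TokenBucket`. A minimal standalone sketch of the same gating idea (no refill logic, names hypothetical) shows how the `get_token()` check rejects excess requests:

```python
class SimpleTokenBucket:
    # Minimal sketch of the token-bucket gate used by OpenAIImage;
    # the project's TokenBucket also refills tokens over time.
    def __init__(self, capacity):
        self.capacity = capacity
        self.tokens = capacity

    def get_token(self):
        if self.tokens > 0:
            self.tokens -= 1
            return True
        return False

bucket = SimpleTokenBucket(2)
results = [bucket.get_token() for _ in range(3)]
print(results)  # → [True, True, False]: the third request is rejected
```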
bot/openai/open_ai_session.py
ADDED
@@ -0,0 +1,67 @@
from bot.session_manager import Session
from common.log import logger


class OpenAISession(Session):
    def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
        super().__init__(session_id, system_prompt)
        self.model = model
        self.reset()

    def __str__(self):
        # Build the input prompt for the completion model, e.g.:
        #   Q: xxx
        #   A: xxx
        #   Q: xxx
        prompt = ""
        for item in self.messages:
            if item['role'] == 'system':
                prompt += item['content'] + "<|endoftext|>\n\n\n"
            elif item['role'] == 'user':
                prompt += "Q: " + item['content'] + "\n"
            elif item['role'] == 'assistant':
                prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n"

        if len(self.messages) > 0 and self.messages[-1]['role'] == 'user':
            prompt += "A: "
        return prompt

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        precise = True
        try:
            cur_tokens = num_tokens_from_string(str(self), self.model)
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise e
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            if len(self.messages) > 1:
                self.messages.pop(0)
            elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
                self.messages.pop(0)
                if precise:
                    cur_tokens = num_tokens_from_string(str(self), self.model)
                else:
                    cur_tokens = len(str(self))
                break
            elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
                logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
                break
            if precise:
                cur_tokens = num_tokens_from_string(str(self), self.model)
            else:
                cur_tokens = len(str(self))
        return cur_tokens


# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_string(string: str, model: str) -> int:
    """Returns the number of tokens in a text string."""
    import tiktoken
    encoding = tiktoken.encoding_for_model(model)
    num_tokens = len(encoding.encode(string, disallowed_special=()))
    return num_tokens
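The Q/A prompt layout built by `OpenAISession.__str__` can be exercised standalone. The sketch below replays the same formatting loop on a hardcoded message list (the messages themselves are made up for illustration):

```python
# Standalone replay of the prompt-building loop in OpenAISession.__str__.
messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'hello'},
    {'role': 'assistant', 'content': 'hi there'},
    {'role': 'user', 'content': 'how are you?'},
]

prompt = ""
for item in messages:
    if item['role'] == 'system':
        prompt += item['content'] + "<|endoftext|>\n\n\n"
    elif item['role'] == 'user':
        prompt += "Q: " + item['content'] + "\n"
    elif item['role'] == 'assistant':
        prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n"
if messages and messages[-1]['role'] == 'user':
    prompt += "A: "  # leave an open answer slot for the completion model
print(prompt)
```

The trailing `"A: "` is what makes the completion model continue the dialogue as the assistant.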
bot/session_manager.py
ADDED
@@ -0,0 +1,85 @@
from common.expired_dict import ExpiredDict
from common.log import logger
from config import conf


class Session(object):
    def __init__(self, session_id, system_prompt=None):
        self.session_id = session_id
        self.messages = []
        if system_prompt is None:
            self.system_prompt = conf().get("character_desc", "")
        else:
            self.system_prompt = system_prompt

    # Reset the session so it contains only the system prompt
    def reset(self):
        system_item = {'role': 'system', 'content': self.system_prompt}
        self.messages = [system_item]

    def set_system_prompt(self, system_prompt):
        self.system_prompt = system_prompt
        self.reset()

    def add_query(self, query):
        user_item = {'role': 'user', 'content': query}
        self.messages.append(user_item)

    def add_reply(self, reply):
        assistant_item = {'role': 'assistant', 'content': reply}
        self.messages.append(assistant_item)

    def discard_exceeding(self, max_tokens=None, cur_tokens=None):
        raise NotImplementedError


class SessionManager(object):
    def __init__(self, sessioncls, **session_args):
        if conf().get('expires_in_seconds'):
            sessions = ExpiredDict(conf().get('expires_in_seconds'))
        else:
            sessions = dict()
        self.sessions = sessions
        self.sessioncls = sessioncls
        self.session_args = session_args

    def build_session(self, session_id, system_prompt=None):
        '''
        If session_id is not in sessions, create a new session and add it.
        If system_prompt is not None, update the session's system_prompt and reset the session.
        '''
        if session_id not in self.sessions:
            self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
        elif system_prompt is not None:  # a new system_prompt arrived: update it and reset the session
            self.sessions[session_id].set_system_prompt(system_prompt)
        session = self.sessions[session_id]
        return session

    def session_query(self, query, session_id):
        session = self.build_session(session_id)
        session.add_query(query)
        try:
            max_tokens = conf().get("conversation_max_tokens", 1000)
            total_tokens = session.discard_exceeding(max_tokens, None)
            logger.debug("prompt tokens used={}".format(total_tokens))
        except Exception as e:
            logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
        return session

    def session_reply(self, reply, session_id, total_tokens=None):
        session = self.build_session(session_id)
        session.add_reply(reply)
        try:
            max_tokens = conf().get("conversation_max_tokens", 1000)
            tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
            logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
        except Exception as e:
            logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
        return session

    def clear_session(self, session_id):
        if session_id in self.sessions:
            del self.sessions[session_id]

    def clear_all_session(self):
        self.sessions.clear()
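The `build_session` contract (same id returns the same session; a new system prompt resets the conversation) can be sketched outside the project with a small stand-in class (`DemoSession` and `build_session` below are hypothetical re-implementations, not the project's own):

```python
class DemoSession:
    # Just enough of Session to demonstrate build_session's behavior.
    def __init__(self, session_id, system_prompt=None):
        self.session_id = session_id
        self.system_prompt = system_prompt or "default"
        self.messages = [{'role': 'system', 'content': self.system_prompt}]

    def set_system_prompt(self, system_prompt):
        self.system_prompt = system_prompt
        self.messages = [{'role': 'system', 'content': system_prompt}]

    def add_query(self, q):
        self.messages.append({'role': 'user', 'content': q})

sessions = {}
def build_session(session_id, system_prompt=None):
    if session_id not in sessions:
        sessions[session_id] = DemoSession(session_id, system_prompt)
    elif system_prompt is not None:  # a new prompt resets the conversation
        sessions[session_id].set_system_prompt(system_prompt)
    return sessions[session_id]

s = build_session("u1")
s.add_query("hi")
s2 = build_session("u1")             # same id returns the same session object
assert s is s2 and len(s.messages) == 2
build_session("u1", "new persona")   # a new system prompt resets the history
print(len(sessions["u1"].messages))  # → 1
```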
bridge/bridge.py
ADDED
@@ -0,0 +1,50 @@
from bridge.context import Context
from bridge.reply import Reply
from common.log import logger
from bot import bot_factory
from common.singleton import singleton
from voice import voice_factory
from config import conf
from common import const


@singleton
class Bridge(object):
    def __init__(self):
        self.btype = {
            "chat": const.CHATGPT,
            "voice_to_text": conf().get("voice_to_text", "openai"),
            "text_to_voice": conf().get("text_to_voice", "google")
        }
        model_type = conf().get("model")
        if model_type in ["text-davinci-003"]:
            self.btype['chat'] = const.OPEN_AI
        if conf().get("use_azure_chatgpt"):
            self.btype['chat'] = const.CHATGPTONAZURE
        self.bots = {}

    def get_bot(self, typename):
        if self.bots.get(typename) is None:
            logger.info("create bot {} for {}".format(self.btype[typename], typename))
            if typename == "text_to_voice":
                self.bots[typename] = voice_factory.create_voice(self.btype[typename])
            elif typename == "voice_to_text":
                self.bots[typename] = voice_factory.create_voice(self.btype[typename])
            elif typename == "chat":
                self.bots[typename] = bot_factory.create_bot(self.btype[typename])
        return self.bots[typename]

    def get_bot_type(self, typename):
        return self.btype[typename]

    def fetch_reply_content(self, query, context: Context) -> Reply:
        return self.get_bot("chat").reply(query, context)

    def fetch_voice_to_text(self, voiceFile) -> Reply:
        return self.get_bot("voice_to_text").voiceToText(voiceFile)

    def fetch_text_to_voice(self, text) -> Reply:
        return self.get_bot("text_to_voice").textToVoice(text)
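`Bridge.get_bot` is a lazy get-or-create cache: a bot is constructed on first use and reused afterwards. The pattern in isolation (with a hypothetical `make_bot` standing in for the real factories):

```python
created = []

def make_bot(btype):
    # Stand-in for bot_factory.create_bot / voice_factory.create_voice.
    created.append(btype)
    return "bot:" + btype

bots = {}
def get_bot(typename, btype):
    if bots.get(typename) is None:  # construct only on first use
        bots[typename] = make_bot(btype)
    return bots[typename]

a = get_bot("chat", "openai")
b = get_bot("chat", "openai")
assert a is b
print(created)  # → ['openai']: the factory ran exactly once
```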
bridge/context.py
ADDED
@@ -0,0 +1,57 @@
# encoding:utf-8

from enum import Enum


class ContextType(Enum):
    TEXT = 1          # text message
    VOICE = 2         # audio message
    IMAGE_CREATE = 3  # image-creation command

    def __str__(self):
        return self.name


class Context:
    def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
        self.type = type
        self.content = content
        self.kwargs = kwargs

    def __contains__(self, key):
        if key == 'type':
            return self.type is not None
        elif key == 'content':
            return self.content is not None
        else:
            return key in self.kwargs

    def __getitem__(self, key):
        if key == 'type':
            return self.type
        elif key == 'content':
            return self.content
        else:
            return self.kwargs[key]

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key, value):
        if key == 'type':
            self.type = value
        elif key == 'content':
            self.content = value
        else:
            self.kwargs[key] = value

    def __delitem__(self, key):
        if key == 'type':
            self.type = None
        elif key == 'content':
            self.content = None
        else:
            del self.kwargs[key]

    def __str__(self):
        return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
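`Context` routes dict-style access: the keys `'type'` and `'content'` map to attributes, anything else lands in `kwargs`. A minimal standalone copy of that routing (just the parts needed to run here) demonstrates the behavior:

```python
class MiniContext:
    # Trimmed-down copy of Context's dict-protocol routing.
    def __init__(self, type=None, content=None):
        self.type = type
        self.content = content
        self.kwargs = {}

    def __contains__(self, key):
        if key == 'type':
            return self.type is not None
        elif key == 'content':
            return self.content is not None
        return key in self.kwargs

    def __setitem__(self, key, value):
        if key == 'type':
            self.type = value
        elif key == 'content':
            self.content = value
        else:
            self.kwargs[key] = value

    def __getitem__(self, key):
        if key == 'type':
            return self.type
        elif key == 'content':
            return self.content
        return self.kwargs[key]

ctx = MiniContext(content="hello")
ctx['session_id'] = "u1"  # not a reserved key → stored in kwargs
assert 'content' in ctx and 'type' not in ctx
print(ctx['session_id'], ctx.kwargs)  # → u1 {'session_id': 'u1'}
```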
bridge/reply.py
ADDED
@@ -0,0 +1,22 @@
# encoding:utf-8

from enum import Enum


class ReplyType(Enum):
    TEXT = 1       # text
    VOICE = 2      # audio file
    IMAGE = 3      # image file
    IMAGE_URL = 4  # image URL

    INFO = 9
    ERROR = 10

    def __str__(self):
        return self.name


class Reply:
    def __init__(self, type: ReplyType = None, content=None):
        self.type = type
        self.content = content

    def __str__(self):
        return "Reply(type={}, content={})".format(self.type, self.content)
channel/channel.py
ADDED
@@ -0,0 +1,41 @@
"""
Message sending channel abstract class
"""

from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import *


class Channel(object):
    NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]

    def startup(self):
        """
        init channel
        """
        raise NotImplementedError

    def handle_text(self, msg):
        """
        process received msg
        :param msg: message object
        """
        raise NotImplementedError

    # Unified send function; each Channel implements it and dispatches on reply.type
    def send(self, reply: Reply, context: Context):
        """
        send message to user
        :param reply: reply content
        :param context: context of the message being replied to
        :return:
        """
        raise NotImplementedError

    def build_reply_content(self, query, context: Context = None) -> Reply:
        return Bridge().fetch_reply_content(query, context)

    def build_voice_to_text(self, voice_file) -> Reply:
        return Bridge().fetch_voice_to_text(voice_file)

    def build_text_to_voice(self, text) -> Reply:
        return Bridge().fetch_text_to_voice(text)
channel/channel_factory.py
ADDED
@@ -0,0 +1,23 @@
"""
channel factory
"""

def create_channel(channel_type):
    """
    create a channel instance
    :param channel_type: channel type code
    :return: channel instance
    """
    if channel_type == 'wx':
        from channel.wechat.wechat_channel import WechatChannel
        return WechatChannel()
    elif channel_type == 'wxy':
        from channel.wechat.wechaty_channel import WechatyChannel
        return WechatyChannel()
    elif channel_type == 'terminal':
        from channel.terminal.terminal_channel import TerminalChannel
        return TerminalChannel()
    elif channel_type == 'wechatmp':
        from channel.wechatmp.wechatmp_channel import WechatMPChannel
        return WechatMPChannel()
    raise RuntimeError("unknown channel_type: {}".format(channel_type))
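The factory above is a string-keyed dispatch with deferred imports. The same shape can be sketched with dummy classes standing in for the real channel imports (`WxChannel`/`TerminalChannel` here are placeholders):

```python
class WxChannel:
    pass

class TerminalChannel:
    pass

def create(channel_type):
    # Table-driven variant of the if/elif dispatch above.
    registry = {'wx': WxChannel, 'terminal': TerminalChannel}
    try:
        return registry[channel_type]()
    except KeyError:
        raise RuntimeError("unknown channel_type: {}".format(channel_type))

ch = create('terminal')
print(type(ch).__name__)  # → TerminalChannel
```

The deferred `from ... import` inside each branch in the real code avoids pulling in heavy channel dependencies (itchat, wechaty, web frameworks) that the selected channel does not need.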
channel/chat_channel.py
ADDED
@@ -0,0 +1,316 @@
from asyncio import CancelledError
from concurrent.futures import Future, ThreadPoolExecutor
import os
import re
import threading
import time
from common.dequeue import Dequeue
from channel.channel import Channel
from bridge.reply import *
from bridge.context import *
from config import conf
from common.log import logger
from plugins import *
try:
    from voice.audio_convert import any_to_wav
except Exception as e:
    pass


# Abstract class holding the channel-independent message-handling logic
class ChatChannel(Channel):
    name = None     # nickname of the logged-in user
    user_id = None  # id of the logged-in user
    futures = {}    # futures submitted to the thread pool per session_id; used to cancel not-yet-running tasks when a session is reset (running tasks are not cancelled)
    sessions = {}   # concurrency control: each session_id processes at most one context at a time
    lock = threading.Lock()  # guards access to sessions
    handler_pool = ThreadPoolExecutor(max_workers=8)  # thread pool for message handling

    def __init__(self):
        _thread = threading.Thread(target=self.consume)
        _thread.setDaemon(True)
        _thread.start()

    # Build a Context from a message; content-based trigger rules live here
    def _compose_context(self, ctype: ContextType, content, **kwargs):
        context = Context(ctype, content)
        context.kwargs = kwargs
        # On first entry origin_ctype is None. It exists because a voice message
        # spawns two nested contexts: first voice-to-text, then a text reply.
        # origin_ctype lets the second step decide whether a prefix match is
        # required; a private-chat voice message does not need a prefix.
        if 'origin_ctype' not in context:
            context['origin_ctype'] = ctype
        # On first entry receiver is unset; set it according to the message type
        first_in = 'receiver' not in context
        # Group-name matching: set session_id and receiver
        if first_in:  # on first entry, receiver is None; set it based on the message type
            config = conf()
            cmsg = context['msg']
            if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
                logger.debug("[WX]self message skipped")
                return None
            if context["isgroup"]:
                group_name = cmsg.other_user_nickname
                group_id = cmsg.other_user_id

                group_name_white_list = config.get('group_name_white_list', [])
                group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
                if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
                    group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
                    session_id = cmsg.actual_user_id
                    if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
                        session_id = group_id
                else:
                    return None
                context['session_id'] = session_id
                context['receiver'] = group_id
            else:
                context['session_id'] = cmsg.other_user_id
                context['receiver'] = cmsg.other_user_id

        # Content matching, and content post-processing
        if ctype == ContextType.TEXT:
            if first_in and "」\n- - - - - - -" in content:  # first pass: filter out quoted messages
                logger.debug("[WX]reference query skipped")
                return None

            if context["isgroup"]:  # group chat
                # check trigger keywords
                match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
                match_contain = check_contain(content, conf().get('group_chat_keyword'))
                flag = False
                if match_prefix is not None or match_contain is not None:
                    flag = True
                    if match_prefix:
                        content = content.replace(match_prefix, '', 1).strip()
                if context['msg'].is_at:
                    logger.info("[WX]receive group at")
                    if not conf().get("group_at_off", False):
                        flag = True
                    pattern = f'@{self.name}(\u2005|\u0020)'
                    content = re.sub(pattern, r'', content)

                if not flag:
                    if context["origin_ctype"] == ContextType.VOICE:
                        logger.info("[WX]receive group voice, but checkprefix didn't match")
                    return None
            else:  # private chat
                match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
                if match_prefix is not None:  # a custom prefix matched: strip the prefix (and whitespace) from content
                    content = content.replace(match_prefix, '', 1).strip()
                elif context["origin_ctype"] == ContextType.VOICE:  # a private voice message is allowed to skip the prefix check
                    pass
                else:
                    return None

            img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
            if img_match_prefix:
                content = content.replace(img_match_prefix, '', 1).strip()
                context.type = ContextType.IMAGE_CREATE
            else:
                context.type = ContextType.TEXT
            context.content = content
            if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                context['desire_rtype'] = ReplyType.VOICE
        elif context.type == ContextType.VOICE:
            if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                context['desire_rtype'] = ReplyType.VOICE

        return context

    def _handle(self, context: Context):
        if context is None or not context.content:
            return
        logger.debug('[WX] ready to handle context: {}'.format(context))
        # build the reply
        reply = self._generate_reply(context)

        logger.debug('[WX] ready to decorate reply: {}'.format(reply))
        # decorate the reply
        reply = self._decorate_reply(context, reply)

        # send the reply
        self._send_reply(context, reply)

    def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
            'channel': self, 'context': context, 'reply': reply}))
        reply = e_context['reply']
        if not e_context.is_pass():
            logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
            if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:  # text and image-creation messages
                reply = super().build_reply_content(context.content, context)
            elif context.type == ContextType.VOICE:  # voice messages
                cmsg = context['msg']
                cmsg.prepare()
                file_path = context.content
                wav_path = os.path.splitext(file_path)[0] + '.wav'
                try:
                    any_to_wav(file_path, wav_path)
                except Exception as e:  # conversion failed; fall back to the mp3, which some APIs can recognize directly
                    logger.warning("[WX]any to wav error, use raw path. " + str(e))
                    wav_path = file_path
                # speech recognition
                reply = super().build_voice_to_text(wav_path)
                # delete temporary files
                try:
                    os.remove(file_path)
                    if wav_path != file_path:
                        os.remove(wav_path)
                except Exception as e:
                    pass
                    # logger.warning("[WX]delete temp file error: " + str(e))

                if reply.type == ReplyType.TEXT:
                    new_context = self._compose_context(
                        ContextType.TEXT, reply.content, **context.kwargs)
                    if new_context:
                        reply = self._generate_reply(new_context)
                    else:
                        return
            else:
                logger.error('[WX] unknown context type: {}'.format(context.type))
                return
        return reply

    def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
        if reply and reply.type:
            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
                'channel': self, 'context': context, 'reply': reply}))
            reply = e_context['reply']
            desire_rtype = context.get('desire_rtype')
            if not e_context.is_pass() and reply and reply.type:

                if reply.type in self.NOT_SUPPORT_REPLYTYPE:
                    logger.error("[WX]reply type not support: " + str(reply.type))
                    reply.content = "不支持发送的消息类型: " + str(reply.type)
                    reply.type = ReplyType.ERROR

                if reply.type == ReplyType.TEXT:
                    reply_text = reply.content
                    if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
                        reply = super().build_text_to_voice(reply.content)
                        return self._decorate_reply(context, reply)
                    if context['isgroup']:
                        reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
                        reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
                    else:
                        reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
                    reply.content = reply_text
                elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
                    reply.content = "[" + str(reply.type) + "]\n" + reply.content
                elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
                    pass
                else:
                    logger.error('[WX] unknown reply type: {}'.format(reply.type))
                    return
            if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
                logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
            return reply

    def _send_reply(self, context: Context, reply: Reply):
        if reply and reply.type:
            e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
                'channel': self, 'context': context, 'reply': reply}))
            reply = e_context['reply']
            if not e_context.is_pass() and reply and reply.type:
                logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context))
                self._send(reply, context)

    def _send(self, reply: Reply, context: Context, retry_cnt=0):
        try:
            self.send(reply, context)
        except Exception as e:
            logger.error('[WX] sendMsg error: {}'.format(str(e)))
            if isinstance(e, NotImplementedError):
                return
            logger.exception(e)
            if retry_cnt < 2:
                time.sleep(3 + 3 * retry_cnt)
                self._send(reply, context, retry_cnt + 1)

    def thread_pool_callback(self, session_id):
        def func(worker: Future):
            try:
                worker_exception = worker.exception()
                if worker_exception:
                    logger.exception("Worker return exception: {}".format(worker_exception))
            except CancelledError as e:
                logger.info("Worker cancelled, session_id = {}".format(session_id))
            except Exception as e:
                logger.exception("Worker raise exception: {}".format(e))
            with self.lock:
                self.sessions[session_id][1].release()
        return func

    def produce(self, context: Context):
        session_id = context['session_id']
        with self.lock:
            if session_id not in self.sessions:
                self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))]
            if context.type == ContextType.TEXT and context.content.startswith("#"):
                self.sessions[session_id][0].putleft(context)  # admin commands jump the queue
            else:
                self.sessions[session_id][0].put(context)

    # Consumer loop: a dedicated thread that pulls contexts off the queues and dispatches them
    def consume(self):
        while True:
            with self.lock:
                session_ids = list(self.sessions.keys())
                for session_id in session_ids:
                    context_queue, semaphore = self.sessions[session_id]
                    if semaphore.acquire(blocking=False):  # a session may only be deleted once its worker has finished
                        if not context_queue.empty():
                            context = context_queue.get()
                            logger.debug("[WX] consume context: {}".format(context))
                            future: Future = self.handler_pool.submit(self._handle, context)
                            future.add_done_callback(self.thread_pool_callback(session_id))
                            if session_id not in self.futures:
                                self.futures[session_id] = []
                            self.futures[session_id].append(future)
                        elif semaphore._initial_value == semaphore._value + 1:  # nothing besides the current acquire holds the semaphore, so all tasks have finished
                            self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
                            assert len(self.futures[session_id]) == 0, "thread pool error"
                            del self.sessions[session_id]
                        else:
                            semaphore.release()
            time.sleep(0.1)

    # Cancel all tasks for session_id; only queued messages and submitted-but-not-started tasks can be cancelled
    def cancel_session(self, session_id):
        with self.lock:
            if session_id in self.sessions:
                for future in self.futures[session_id]:
                    future.cancel()
                cnt = self.sessions[session_id][0].qsize()
                if cnt > 0:
                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
                self.sessions[session_id][0] = Dequeue()

    def cancel_all_session(self):
        with self.lock:
            for session_id in self.sessions:
                for future in self.futures[session_id]:
                    future.cancel()
                cnt = self.sessions[session_id][0].qsize()
                if cnt > 0:
                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
                self.sessions[session_id][0] = Dequeue()


def check_prefix(content, prefix_list):
    for prefix in prefix_list:
        if content.startswith(prefix):
            return prefix
    return None


def check_contain(content, keyword_list):
    if not keyword_list:
        return None
    for ky in keyword_list:
        if content.find(ky) != -1:
            return True
    return None
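The trigger helpers at the bottom of chat_channel.py are self-contained, so their behavior is easy to demonstrate directly (copied verbatim, with illustrative inputs):

```python
def check_prefix(content, prefix_list):
    # Returns the matched prefix, or None.
    for prefix in prefix_list:
        if content.startswith(prefix):
            return prefix
    return None

def check_contain(content, keyword_list):
    # Returns True if any keyword occurs in content, else None.
    if not keyword_list:
        return None
    for ky in keyword_list:
        if content.find(ky) != -1:
            return True
    return None

print(check_prefix("bot hello", ["bot", "@bot"]))    # → bot
print(check_contain("please draw a cat", ["draw"]))  # → True
print(check_contain("hello", ["draw"]))              # → None
```

Note that `check_prefix` returns the matched prefix string (so the caller can strip it from the content), while `check_contain` only signals a match with `True`/`None`.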
channel/chat_message.py
ADDED
@@ -0,0 +1,83 @@
"""
This class represents a chat message, providing a unified wrapper around
itchat and wechaty messages.

ChatMessage
    msg_id: message id
    create_time: message creation time

    ctype: message type (ContextType)
    content: message content; for voice/image messages this is a file path

    from_user_id: sender id
    from_user_nickname: sender nickname
    to_user_id: receiver id
    to_user_nickname: receiver nickname

    other_user_id: the peer's id. If you are the sender, this is the receiver's
        id; if you are the receiver, this is the sender's id; for group
        messages it is always the group id.
    other_user_nickname: as above, but the nickname

    is_group: whether this is a group message
    is_at: whether the bot was @-mentioned

    - (for group messages there is usually an actual sender, a member of the
      group; the following fields only exist for group messages)
    actual_user_id: actual sender id
    actual_user_nickname: actual sender nickname

    _prepare_fn: preparation function, e.g. to download an image
    _prepared: whether the preparation function has been called
    _rawmsg: the raw message object
"""
class ChatMessage(object):
    msg_id = None
    create_time = None

    ctype = None
    content = None

    from_user_id = None
    from_user_nickname = None
    to_user_id = None
    to_user_nickname = None
    other_user_id = None
    other_user_nickname = None

    is_group = False
    is_at = False
    actual_user_id = None
    actual_user_nickname = None

    _prepare_fn = None
    _prepared = False
    _rawmsg = None

    def __init__(self, _rawmsg):
        self._rawmsg = _rawmsg

    def prepare(self):
        if self._prepare_fn and not self._prepared:
            self._prepared = True
            self._prepare_fn()

    def __str__(self):
        return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format(
            self.msg_id,
            self.create_time,
            self.ctype,
            self.content,
            self.from_user_id,
            self.from_user_nickname,
            self.to_user_id,
            self.to_user_nickname,
            self.other_user_id,
            self.other_user_nickname,
            self.is_group,
            self.is_at,
            self.actual_user_id,
            self.actual_user_nickname,
        )
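ChatMessage defers expensive work (e.g. downloading a voice file) behind `_prepare_fn` and guarantees it runs at most once via `_prepared`. A minimal standalone sketch of that lazy-preparation pattern (the `LazyMessage` class here is illustrative, not the project's class):

```python
class LazyMessage:
    def __init__(self, prepare_fn=None):
        self._prepare_fn = prepare_fn  # e.g. a download closure
        self._prepared = False

    def prepare(self):
        # run the preparation function at most once
        if self._prepare_fn and not self._prepared:
            self._prepared = True
            self._prepare_fn()

calls = []
msg = LazyMessage(prepare_fn=lambda: calls.append("downloaded"))
msg.prepare()
msg.prepare()  # second call is a no-op
print(calls)   # ['downloaded']
```

Channels set `_prepare_fn` at construction time; consumers call `prepare()` only when they actually need the file.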
channel/terminal/terminal_channel.py
ADDED
@@ -0,0 +1,31 @@
from bridge.context import *
from channel.channel import Channel
import sys

class TerminalChannel(Channel):
    def startup(self):
        context = Context()
        print("\nPlease input your question")
        while True:
            try:
                prompt = self.get_input("User:\n")
            except KeyboardInterrupt:
                print("\nExiting...")
                sys.exit()

            context.type = ContextType.TEXT
            context['session_id'] = "User"
            context.content = prompt
            print("Bot:")
            sys.stdout.flush()
            res = super().build_reply_content(prompt, context).content
            print(res)

    def get_input(self, prompt):
        """
        Multi-line input function
        """
        print(prompt, end="")
        line = input()
        return line
channel/wechat/wechat_channel.py
ADDED
@@ -0,0 +1,194 @@
# encoding:utf-8

"""
wechat channel
"""

import os
import threading
import requests
import io
import time
import json
from channel.chat_channel import ChatChannel
from channel.wechat.wechat_message import *
from common.singleton import singleton
from common.log import logger
from lib import itchat
from lib.itchat.content import *
from bridge.reply import *
from bridge.context import *
from config import conf
from common.time_check import time_checker
from common.expired_dict import ExpiredDict
from plugins import *

@itchat.msg_register(TEXT)
def handler_single_msg(msg):
    WechatChannel().handle_text(WeChatMessage(msg))
    return None

@itchat.msg_register(TEXT, isGroupChat=True)
def handler_group_msg(msg):
    WechatChannel().handle_group(WeChatMessage(msg, True))
    return None

@itchat.msg_register(VOICE)
def handler_single_voice(msg):
    WechatChannel().handle_voice(WeChatMessage(msg))
    return None

@itchat.msg_register(VOICE, isGroupChat=True)
def handler_group_voice(msg):
    WechatChannel().handle_group_voice(WeChatMessage(msg, True))
    return None

def _check(func):
    def wrapper(self, cmsg: ChatMessage):
        msgId = cmsg.msg_id
        if msgId in self.receivedMsgs:
            logger.info("Wechat message {} already received, ignore".format(msgId))
            return
        self.receivedMsgs[msgId] = cmsg
        create_time = cmsg.create_time  # message timestamp
        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60:  # skip history messages older than one minute
            logger.debug("[WX]history message {} skipped".format(msgId))
            return
        return func(self, cmsg)
    return wrapper

# Available QR code generation APIs:
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
def qrCallback(uuid, status, qrcode):
    # logger.debug("qrCallback: {} {}".format(uuid,status))
    if status == '0':
        try:
            from PIL import Image
            img = Image.open(io.BytesIO(qrcode))
            _thread = threading.Thread(target=img.show, args=("QRCode",))
            _thread.setDaemon(True)
            _thread.start()
        except Exception as e:
            pass

        import qrcode
        url = f"https://login.weixin.qq.com/l/{uuid}"

        qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
        qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
        qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
        qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
        print("You can also scan QRCode in any website below:")
        print(qr_api3)
        print(qr_api4)
        print(qr_api2)
        print(qr_api1)

        qr = qrcode.QRCode(border=1)
        qr.add_data(url)
        qr.make(fit=True)
        qr.print_ascii(invert=True)

@singleton
class WechatChannel(ChatChannel):
    NOT_SUPPORT_REPLYTYPE = []

    def __init__(self):
        super().__init__()
        self.receivedMsgs = ExpiredDict(60 * 60 * 24)

    def startup(self):
        itchat.instance.receivingRetryCount = 600  # raise the disconnect timeout
        # login by scanning the QR code
        hotReload = conf().get('hot_reload', False)
        try:
            itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
        except Exception as e:
            if hotReload:
                logger.error("Hot reload failed, try to login without hot reload")
                itchat.logout()
                os.remove("itchat.pkl")
                itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
            else:
                raise e
        self.user_id = itchat.instance.storageClass.userName
        self.name = itchat.instance.storageClass.nickName
        logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
        # start message listener
        itchat.run()

    # The handle_* functions build a Context from a received message, then pass it
    # to _handle, which processes the Context and sends the reply.
    # Context carries all message information via the following attributes:
    #   type: message type, one of TEXT, VOICE, IMAGE_CREATE
    #   content: message content; the text for TEXT, a voice file name for VOICE,
    #            an image-generation command for IMAGE_CREATE
    #   kwargs: extra parameters, with keys:
    #       session_id: session id
    #       isgroup: whether this is a group chat
    #       receiver: the party to reply to
    #       msg: the ChatMessage object
    #       origin_ctype: original message type; after speech-to-text, private chats
    #            relax the prefix-trigger rule if the original message was voice
    #       desire_rtype: desired reply type; defaults to text, set to
    #            ReplyType.VOICE for a voice reply

    @time_checker
    @_check
    def handle_voice(self, cmsg: ChatMessage):
        if conf().get('speech_recognition') != True:
            return
        logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
        context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg)
        if context:
            self.produce(context)

    @time_checker
    @_check
    def handle_text(self, cmsg: ChatMessage):
        logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
        context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg)
        if context:
            self.produce(context)

    @time_checker
    @_check
    def handle_group(self, cmsg: ChatMessage):
        logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
        context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg)
        if context:
            self.produce(context)

    @time_checker
    @_check
    def handle_group_voice(self, cmsg: ChatMessage):
        if conf().get('group_speech_recognition', False) != True:
            return
        logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
        context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg)
        if context:
            self.produce(context)

    # Unified send function; each Channel implements its own, dispatching on reply.type.
    def send(self, reply: Reply, context: Context):
        receiver = context["receiver"]
        if reply.type == ReplyType.TEXT:
            itchat.send(reply.content, toUserName=receiver)
            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
        elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
            itchat.send(reply.content, toUserName=receiver)
            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
        elif reply.type == ReplyType.VOICE:
            itchat.send_file(reply.content, toUserName=receiver)
            logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
        elif reply.type == ReplyType.IMAGE_URL:  # download the image from the web
            img_url = reply.content
            pic_res = requests.get(img_url, stream=True)
            image_storage = io.BytesIO()
            for block in pic_res.iter_content(1024):
                image_storage.write(block)
            image_storage.seek(0)
            itchat.send_image(image_storage, toUserName=receiver)
            logger.info('[WX] sendImage url={}, receiver={}'.format(img_url, receiver))
        elif reply.type == ReplyType.IMAGE:  # read the image from a file object
            image_storage = reply.content
            image_storage.seek(0)
            itchat.send_image(image_storage, toUserName=receiver)
            logger.info('[WX] sendImage, receiver={}'.format(receiver))
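The `_check` decorator deduplicates by `msg_id` using `ExpiredDict` with a 24-hour TTL, so replayed itchat messages are dropped without the cache growing forever. A minimal sketch of such a TTL dict (this is an illustration, not the project's `ExpiredDict` implementation):

```python
import time

class SimpleExpiredDict:
    def __init__(self, ttl):
        self.ttl = ttl
        self._data = {}  # key -> (expiry_timestamp, value)

    def __setitem__(self, key, value):
        self._data[key] = (time.time() + self.ttl, value)

    def __contains__(self, key):
        item = self._data.get(key)
        if item is None:
            return False
        if time.time() > item[0]:  # expired entries vanish on access
            del self._data[key]
            return False
        return True

cache = SimpleExpiredDict(ttl=0.05)
cache["msg1"] = "seen"
print("msg1" in cache)  # True
time.sleep(0.06)
print("msg1" in cache)  # False: entry expired
```

The membership test (`msgId in self.receivedMsgs`) is the only operation the dedup path needs, which is why expiry-on-access suffices.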
channel/wechat/wechat_message.py
ADDED
@@ -0,0 +1,57 @@
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger
from lib.itchat.content import *
from lib import itchat

class WeChatMessage(ChatMessage):

    def __init__(self, itchat_msg, is_group=False):
        super().__init__(itchat_msg)
        self.msg_id = itchat_msg['MsgId']
        self.create_time = itchat_msg['CreateTime']
        self.is_group = is_group

        if itchat_msg['Type'] == TEXT:
            self.ctype = ContextType.TEXT
            self.content = itchat_msg['Text']
        elif itchat_msg['Type'] == VOICE:
            self.ctype = ContextType.VOICE
            self.content = TmpDir().path() + itchat_msg['FileName']  # content holds the temp file path directly
            self._prepare_fn = lambda: itchat_msg.download(self.content)
        else:
            raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type']))

        self.from_user_id = itchat_msg['FromUserName']
        self.to_user_id = itchat_msg['ToUserName']

        user_id = itchat.instance.storageClass.userName
        nickname = itchat.instance.storageClass.nickName

        # from_user_id and to_user_id are rarely used, but fill them in for consistency.
        # The logic below is tedious; in short: fill in everything we can.
        if self.from_user_id == user_id:
            self.from_user_nickname = nickname
        if self.to_user_id == user_id:
            self.to_user_nickname = nickname
        try:  # the 'User' field may be missing for strangers
            self.other_user_id = itchat_msg['User']['UserName']
            self.other_user_nickname = itchat_msg['User']['NickName']
            if self.other_user_id == self.from_user_id:
                self.from_user_nickname = self.other_user_nickname
            if self.other_user_id == self.to_user_id:
                self.to_user_nickname = self.other_user_nickname
        except KeyError as e:  # handle the occasional missing peer info
            logger.warn("[WX]get other_user_id failed: " + str(e))
            if self.from_user_id == user_id:
                self.other_user_id = self.to_user_id
            else:
                self.other_user_id = self.from_user_id

        if self.is_group:
            self.is_at = itchat_msg['IsAt']
            self.actual_user_id = itchat_msg['ActualUserName']
            self.actual_user_nickname = itchat_msg['ActualNickName']
channel/wechat/wechaty_channel.py
ADDED
@@ -0,0 +1,125 @@
# encoding:utf-8

"""
wechaty channel
Python Wechaty - https://github.com/wechaty/python-wechaty
"""
import base64
import os
import time
import asyncio
from bridge.context import Context
from wechaty_puppet import FileBox
from wechaty import Wechaty, Contact
from wechaty.user import Message
from bridge.reply import *
from bridge.context import *
from channel.chat_channel import ChatChannel
from channel.wechat.wechaty_message import WechatyMessage
from common.log import logger
from common.singleton import singleton
from config import conf
try:
    from voice.audio_convert import any_to_sil
except Exception as e:
    pass

@singleton
class WechatyChannel(ChatChannel):
    NOT_SUPPORT_REPLYTYPE = []

    def __init__(self):
        super().__init__()

    def startup(self):
        config = conf()
        token = config.get('wechaty_puppet_service_token')
        os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
        asyncio.run(self.main())

    async def main(self):
        loop = asyncio.get_event_loop()
        # hand the asyncio loop to the worker threads
        self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
        self.bot = Wechaty()
        self.bot.on('login', self.on_login)
        self.bot.on('message', self.on_message)
        await self.bot.start()

    async def on_login(self, contact: Contact):
        self.user_id = contact.contact_id
        self.name = contact.name
        logger.info('[WX] login user={}'.format(contact))

    # Unified send function; each Channel implements its own, dispatching on reply.type.
    def send(self, reply: Reply, context: Context):
        receiver_id = context['receiver']
        loop = asyncio.get_event_loop()
        if context['isgroup']:
            receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
        else:
            receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
        msg = None
        if reply.type == ReplyType.TEXT:
            msg = reply.content
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
        elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
            msg = reply.content
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
        elif reply.type == ReplyType.VOICE:
            voiceLength = None
            file_path = reply.content
            sil_file = os.path.splitext(file_path)[0] + '.sil'
            voiceLength = int(any_to_sil(file_path, sil_file))
            if voiceLength >= 60000:
                voiceLength = 60000
                logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength))
            # send the voice message
            t = int(time.time())
            msg = FileBox.from_file(sil_file, name=str(t) + '.sil')
            if voiceLength is not None:
                msg.metadata['voiceLength'] = voiceLength
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            try:
                os.remove(file_path)
                if sil_file != file_path:
                    os.remove(sil_file)
            except Exception as e:
                pass
            logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver))
        elif reply.type == ReplyType.IMAGE_URL:  # download the image from the web
            img_url = reply.content
            t = int(time.time())
            msg = FileBox.from_url(url=img_url, name=str(t) + '.png')
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info('[WX] sendImage url={}, receiver={}'.format(img_url, receiver))
        elif reply.type == ReplyType.IMAGE:  # read the image from a file object
            image_storage = reply.content
            image_storage.seek(0)
            t = int(time.time())
            msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png')
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info('[WX] sendImage, receiver={}'.format(receiver))

    async def on_message(self, msg: Message):
        """
        listen for message event
        """
        try:
            cmsg = await WechatyMessage(msg)
        except NotImplementedError as e:
            logger.debug('[WX] {}'.format(e))
            return
        except Exception as e:
            logger.exception('[WX] {}'.format(e))
            return
        logger.debug('[WX] message:{}'.format(cmsg))
        room = msg.room()  # the room the message came from; None if not a group chat
        isgroup = room is not None
        ctype = cmsg.ctype
        context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
        if context:
            logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
            self.produce(context)
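`send()` runs on worker threads while the Wechaty bot owns the asyncio loop, so every coroutine is handed back to that loop with `asyncio.run_coroutine_threadsafe(...)` and blocked on with `.result()`. A standalone sketch of that thread-to-loop handoff:

```python
import asyncio
import threading

async def say(text):
    await asyncio.sleep(0.01)  # stand-in for a network call
    return "sent: " + text

def worker(loop, out):
    # submit a coroutine to a loop running in another thread, block for the result
    fut = asyncio.run_coroutine_threadsafe(say("hello"), loop)
    out.append(fut.result())

loop = asyncio.new_event_loop()
t = threading.Thread(target=loop.run_forever, daemon=True)
t.start()

results = []
w = threading.Thread(target=worker, args=(loop, results))
w.start()
w.join()
loop.call_soon_threadsafe(loop.stop)
print(results)  # ['sent: hello']
```

This is also why `main()` installs the loop into the handler pool's initializer: each worker thread needs `asyncio.get_event_loop()` to return the bot's loop.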
channel/wechat/wechaty_message.py
ADDED
@@ -0,0 +1,85 @@
import asyncio
import re
from wechaty import MessageType
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger
from wechaty.user import Message

class aobject(object):
    """Inheriting this class allows you to define an async __init__.

    So you can create objects by doing something like `await MyClass(params)`
    """
    async def __new__(cls, *a, **kw):
        instance = super().__new__(cls)
        await instance.__init__(*a, **kw)
        return instance

    async def __init__(self):
        pass

class WechatyMessage(ChatMessage, aobject):

    async def __init__(self, wechaty_msg: Message):
        super().__init__(wechaty_msg)

        room = wechaty_msg.room()

        self.msg_id = wechaty_msg.message_id
        self.create_time = wechaty_msg.payload.timestamp
        self.is_group = room is not None

        if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
            self.ctype = ContextType.TEXT
            self.content = wechaty_msg.text()
        elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
            self.ctype = ContextType.VOICE
            voice_file = await wechaty_msg.to_file_box()
            self.content = TmpDir().path() + voice_file.name  # content holds the temp file path directly

            def func():
                loop = asyncio.get_event_loop()
                asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
            self._prepare_fn = func

        else:
            raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))

        from_contact = wechaty_msg.talker()  # the sender of the message
        self.from_user_id = from_contact.contact_id
        self.from_user_nickname = from_contact.name

        # For group messages, wechaty and itchat interpret from/to differently:
        # wechaty: from is the actual sender, to is the room
        # itchat: for your own group messages, from/to are you and the room;
        #         for someone else's group message, from/to are the room and you
        # The difference does not affect the logic: groups only use from to tell
        # whether a message is self-sent, and actual_user_id for the real sender.

        if self.is_group:
            self.to_user_id = room.room_id
            self.to_user_nickname = await room.topic()
        else:
            to_contact = wechaty_msg.to()
            self.to_user_id = to_contact.contact_id
            self.to_user_nickname = to_contact.name

        if self.is_group or wechaty_msg.is_self():  # for group messages other_user is the room; for self-sent private messages it is the peer
            self.other_user_id = self.to_user_id
            self.other_user_nickname = self.to_user_nickname
        else:
            self.other_user_id = self.from_user_id
            self.other_user_nickname = self.from_user_nickname

        if self.is_group:  # in wechaty group chats the actual sender is from_user
            self.is_at = await wechaty_msg.mention_self()
            if not self.is_at:  # copy-pasted messages may contain @xxx without counting as a mention; handle that here
                name = wechaty_msg.wechaty.user_self().name
                pattern = f'@{name}(\u2005|\u0020)'
                if re.search(pattern, self.content):
                    logger.debug(f'wechaty message {self.msg_id} include at')
                    self.is_at = True

            self.actual_user_id = self.from_user_id
            self.actual_user_nickname = self.from_user_nickname
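The `aobject` trick above is what lets the channel write `await WechatyMessage(msg)`: an async `__new__` returns a coroutine, so `type.__call__` never invokes `__init__` itself, and awaiting the coroutine runs the async constructor. A standalone sketch of the same pattern:

```python
import asyncio

class AsyncInit:
    # an async __new__ awaits __init__ itself, enabling `await Cls(...)`
    async def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        await instance.__init__(*args, **kwargs)
        return instance

class DemoMessage(AsyncInit):
    async def __init__(self, text):
        await asyncio.sleep(0)  # stand-in for async I/O during construction
        self.text = text

async def main():
    msg = await DemoMessage("hi")
    return msg.text

print(asyncio.run(main()))  # hi
```

`DemoMessage` is a made-up class for illustration; the real subclass is `WechatyMessage`, whose `__init__` awaits `to_file_box()`, `room.topic()` and `mention_self()`.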
channel/wechatmp/README.md
ADDED
@@ -0,0 +1,46 @@
# Personal WeChat Official Account channel

Since logging in a personal WeChat account on a server via itchat risks an account ban, this channel serves a personal WeChat Official Account instead, which carries no such risk.
Due to the many API restrictions on personal official accounts, the supported features are limited and the implementation is basic: it provides plain text conversation, supports plugins, refines the command format, and supports a private api_key. Image and voice input/output are not implemented yet.
If your official account has a corporate subject and can pass WeChat verification, more APIs become available and most restrictions are lifted. Further contributions are welcome.

## Usage

Before deploying, you need a server with a public IP so that WeChat's servers can connect to yours, or an intranet tunnel; otherwise WeChat cannot deliver messages to your server.

You also need the Python web framework web.py installed on the server.
On Ubuntu (tested on Ubuntu 22.04):
```
pip3 install web.py
```

Then register your own official account on the [WeChat Official Account Platform](https://mp.weixin.qq.com); choose the Subscription Account type with an individual subject.

Following the [access guide](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html), fill in the server address `URL` and the `Token` under "设置与开发" - "基本配置" - "服务器配置" on the platform. The `URL` takes the form `example.com/wx` (an IP address is not allowed), and the `Token` is an arbitrary token you make up. Message encryption is currently set to plaintext mode.

The server-verification code is already in place; you do not need to add any code. Just add the following to `config.json` in the project root:
```
"channel_type": "wechatmp",
"wechatmp_token": "your Token",
"wechatmp_port": 8080,
```
Then run `python3 app.py` to start the web server. It listens on port 8080 by default, but the official-account server configuration only allows ports 80/443; there are two ways around this. The first (recommended) is to forward port 80 to 8080 (similarly for 443; note that 443 requires SSL, i.e. HTTPS access, and the corresponding certificate paths in `wechatmp_channel.py` must be adjusted):
```
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
sudo iptables-save > /etc/iptables/rules.v4
```
The second is to let the Python process listen on port 80 directly. This raises permission issues (`sudo` is needed on Linux) and then causes permission problems for cached files, so it is not recommended.
Finally, click `提交` (submit) on the server-configuration page to verify your server.

Then enable the server on the [WeChat Official Account Platform](https://mp.weixin.qq.com) and turn off the manually configured auto-reply rules to get ChatGPT auto replies.

## Limitations of personal official accounts

The tested account is not a corporate subject, so there is no customer-service API: the account cannot send messages proactively and can only reply passively. WeChat imposes a 5-second limit on passive replies, with at most 2 retries, giving an auto-reply window of at most 15 seconds. If the question is complex or the server is busy, ChatGPT's answer cannot be returned to the user in time. To work around this, answers are cached: after a reply timeout, the user sends any text (e.g. "1") to fetch the cached answer. To improve the experience a 120-second timeout is currently configured, so the user gets the reply, or the error reason, within at most two minutes.

WeChat also limits the length of passive replies, so ChatGPT's answer is split into chunks of 600 characters each (the limit is roughly 700).

## Private api_key

The public API is rate-limited (a free account gets at most 20 ChatGPT API calls per minute), which becomes a problem when serving many users, so a private api_key can be configured per user, currently via the godcmd plugin's commands.

## Test coverage

Tested on the `RoboStyle` official account; feel free to follow it and try it out. The godcmd, Banwords, role, dungeon, and finish plugins are enabled; the other plugins are untested. The Baidu API and voice conversation are also untested. Images are replied as links (no permission for the temporary-media upload API).
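The server verification mentioned above follows the standard WeChat scheme: sort `[token, timestamp, nonce]`, concatenate, SHA-1, and compare with the `signature` query parameter. A minimal sketch (the parameter values are made up for illustration):

```python
import hashlib

def check_signature(token, timestamp, nonce, signature):
    # WeChat MP verification: sha1 over the sorted, concatenated parameters
    data = sorted([token, timestamp, nonce])
    digest = hashlib.sha1("".join(data).encode("utf-8")).hexdigest()
    return digest == signature

token, timestamp, nonce = "mytoken", "1680000000", "12345"
expected = hashlib.sha1("".join(sorted([token, timestamp, nonce])).encode("utf-8")).hexdigest()
print(check_signature(token, timestamp, nonce, expected))  # True
```

On success the endpoint echoes back the `echostr` parameter, which is what the platform's "提交" button checks for.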
|
channel/wechatmp/receive.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -*- coding: utf-8 -*-#
# filename: receive.py
import xml.etree.ElementTree as ET
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger


def parse_xml(web_data):
    if len(web_data) == 0:
        return None
    xmlData = ET.fromstring(web_data)
    return WeChatMPMessage(xmlData)

class WeChatMPMessage(ChatMessage):
    def __init__(self, xmlData):
        super().__init__(xmlData)
        self.to_user_id = xmlData.find('ToUserName').text
        self.from_user_id = xmlData.find('FromUserName').text
        self.create_time = xmlData.find('CreateTime').text
        self.msg_type = xmlData.find('MsgType').text
        self.msg_id = xmlData.find('MsgId').text
        self.is_group = False

        # reply to other_user_id
        self.other_user_id = self.from_user_id

        if self.msg_type == 'text':
            self.ctype = ContextType.TEXT
            self.content = xmlData.find('Content').text.encode("utf-8")
        elif self.msg_type == 'voice':
            self.ctype = ContextType.TEXT
            self.content = xmlData.find('Recognition').text.encode("utf-8")  # use the speech-recognition result
        elif self.msg_type == 'image':
            # not implemented
            self.pic_url = xmlData.find('PicUrl').text
            self.media_id = xmlData.find('MediaId').text
        elif self.msg_type == 'event':
            self.event = xmlData.find('Event').text
        else:  # video, shortvideo, location, link
            # not implemented
            pass
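`parse_xml` feeds the raw POST body straight into ElementTree; a standalone sketch with a sample text-message payload of the shape WeChat sends (field values are made up):

```python
import xml.etree.ElementTree as ET

sample = """<xml>
<ToUserName><![CDATA[gh_123]]></ToUserName>
<FromUserName><![CDATA[user_abc]]></FromUserName>
<CreateTime>1680000000</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[hello]]></Content>
<MsgId>1234567890</MsgId>
</xml>"""

root = ET.fromstring(sample)
msg_type = root.find('MsgType').text  # CDATA sections parse as plain text
content = root.find('Content').text
print(msg_type, content)  # text hello
```

ElementTree flattens CDATA into the element's `.text`, which is why the parser above can read every field with a plain `find(...).text`.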
|
channel/wechatmp/reply.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -*- coding: utf-8 -*-#
# filename: reply.py
import time

class Msg(object):
    def __init__(self):
        pass

    def send(self):
        return "success"

class TextMsg(Msg):
    def __init__(self, toUserName, fromUserName, content):
        self.__dict = dict()
        self.__dict['ToUserName'] = toUserName
        self.__dict['FromUserName'] = fromUserName
        self.__dict['CreateTime'] = int(time.time())
        self.__dict['Content'] = content

    def send(self):
        XmlForm = """
            <xml>
                <ToUserName><![CDATA[{ToUserName}]]></ToUserName>
                <FromUserName><![CDATA[{FromUserName}]]></FromUserName>
                <CreateTime>{CreateTime}</CreateTime>
                <MsgType><![CDATA[text]]></MsgType>
                <Content><![CDATA[{Content}]]></Content>
            </xml>
            """
        return XmlForm.format(**self.__dict)

class ImageMsg(Msg):
    def __init__(self, toUserName, fromUserName, mediaId):
        self.__dict = dict()
        self.__dict['ToUserName'] = toUserName
        self.__dict['FromUserName'] = fromUserName
        self.__dict['CreateTime'] = int(time.time())
        self.__dict['MediaId'] = mediaId

    def send(self):
        XmlForm = """
            <xml>
                <ToUserName><![CDATA[{ToUserName}]]></ToUserName>
                <FromUserName><![CDATA[{FromUserName}]]></FromUserName>
                <CreateTime>{CreateTime}</CreateTime>
                <MsgType><![CDATA[image]]></MsgType>
                <Image>
                    <MediaId><![CDATA[{MediaId}]]></MediaId>
                </Image>
            </xml>
            """
        return XmlForm.format(**self.__dict)
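A passive reply is nothing more than an XML document rendered from a template, with To/From swapped relative to the incoming message. A self-contained sketch of the text-reply rendering:

```python
import time

TEXT_XML = """<xml>
<ToUserName><![CDATA[{ToUserName}]]></ToUserName>
<FromUserName><![CDATA[{FromUserName}]]></FromUserName>
<CreateTime>{CreateTime}</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[{Content}]]></Content>
</xml>"""

def render_text_reply(to_user, from_user, content):
    # when replying, the incoming FromUserName becomes ToUserName and vice versa
    return TEXT_XML.format(ToUserName=to_user, FromUserName=from_user,
                           CreateTime=int(time.time()), Content=content)

xml = render_text_reply("user_abc", "gh_123", "hello back")
print(xml)
```

The CDATA wrappers keep user-supplied text from being interpreted as markup; numeric fields like `CreateTime` are emitted bare.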
|
channel/wechatmp/wechatmp_channel.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -*- coding: utf-8 -*-
import web
import time
import math
import hashlib
import textwrap
from channel.chat_channel import ChatChannel
import channel.wechatmp.reply as reply
import channel.wechatmp.receive as receive
from common.singleton import singleton
from common.log import logger
from config import conf
from bridge.reply import *
from bridge.context import *
from plugins import *
import traceback

# If using SSL, uncomment the following lines, and modify the certificate path.
# from cheroot.server import HTTPServer
# from cheroot.ssl.builtin import BuiltinSSLAdapter
# HTTPServer.ssl_adapter = BuiltinSSLAdapter(
#     certificate='/ssl/cert.pem',
#     private_key='/ssl/cert.key')


# from concurrent.futures import ThreadPoolExecutor
# thread_pool = ThreadPoolExecutor(max_workers=8)

@singleton
class WechatMPChannel(ChatChannel):
    NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]

    def __init__(self):
        super().__init__()
        self.cache_dict = dict()
        self.query1 = dict()
        self.query2 = dict()
        self.query3 = dict()

    def startup(self):
        urls = (
            '/wx', 'SubsribeAccountQuery',
        )
        app = web.application(urls, globals())
        port = conf().get('wechatmp_port', 8080)
        web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port))

    def send(self, reply: Reply, context: Context):
        reply_cnt = math.ceil(len(reply.content) / 600)
        receiver = context["receiver"]
        self.cache_dict[receiver] = (reply_cnt, reply.content)
        logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply))


def verify_server():
    try:
        data = web.input()
        if len(data) == 0:
            return "None"
        signature = data.signature
        timestamp = data.timestamp
        nonce = data.nonce
        echostr = data.echostr
        token = conf().get('wechatmp_token')  # fill in according to the platform's 基本配置 page

        data_list = [token, timestamp, nonce]
        data_list.sort()
        sha1 = hashlib.sha1()
        # map(sha1.update, data_list) #python2
        sha1.update("".join(data_list).encode('utf-8'))
        hashcode = sha1.hexdigest()
        print("handle/GET func: hashcode, signature: ", hashcode, signature)
        if hashcode == signature:
            return echostr
+
else:
|
77 |
+
return ""
|
78 |
+
except Exception as Argument:
|
79 |
+
return Argument
|
80 |
+
|
81 |
+
|
82 |
+
# This class is instantiated once per query
|
83 |
+
class SubsribeAccountQuery():
|
84 |
+
|
85 |
+
def GET(self):
|
86 |
+
return verify_server()
|
87 |
+
|
88 |
+
def POST(self):
|
89 |
+
channel_instance = WechatMPChannel()
|
90 |
+
try:
|
91 |
+
query_time = time.time()
|
92 |
+
webData = web.data()
|
93 |
+
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
94 |
+
wechat_msg = receive.parse_xml(webData)
|
95 |
+
if wechat_msg.msg_type == 'text':
|
96 |
+
from_user = wechat_msg.from_user_id
|
97 |
+
to_user = wechat_msg.to_user_id
|
98 |
+
message = wechat_msg.content.decode("utf-8")
|
99 |
+
message_id = wechat_msg.msg_id
|
100 |
+
|
101 |
+
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
|
102 |
+
|
103 |
+
cache_key = from_user
|
104 |
+
cache = channel_instance.cache_dict.get(cache_key)
|
105 |
+
|
106 |
+
reply_text = ""
|
107 |
+
# New request
|
108 |
+
if cache == None:
|
109 |
+
# The first query begin, reset the cache
|
110 |
+
context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechat_msg)
|
111 |
+
logger.debug("[wechatmp] context: {} {}".format(context, wechat_msg))
|
112 |
+
if context:
|
113 |
+
# set private openai_api_key
|
114 |
+
# if from_user is not changed in itchat, this can be placed at chat_channel
|
115 |
+
user_data = conf().get_user_data(from_user)
|
116 |
+
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
|
117 |
+
channel_instance.cache_dict[cache_key] = (0, "")
|
118 |
+
channel_instance.produce(context)
|
119 |
+
else:
|
120 |
+
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
|
121 |
+
if trigger_prefix:
|
122 |
+
content = textwrap.dedent(f"""\
|
123 |
+
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
124 |
+
例如:
|
125 |
+
{trigger_prefix}你好,很高兴见到你。""")
|
126 |
+
else:
|
127 |
+
logger.error(f"[wechatmp] unknown error")
|
128 |
+
content = textwrap.dedent("""\
|
129 |
+
未知错误,请稍后再试""")
|
130 |
+
replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content)
|
131 |
+
return replyMsg.send()
|
132 |
+
channel_instance.query1[cache_key] = False
|
133 |
+
channel_instance.query2[cache_key] = False
|
134 |
+
channel_instance.query3[cache_key] = False
|
135 |
+
# Request again
|
136 |
+
elif cache[0] == 0 and channel_instance.query1.get(cache_key) == True and channel_instance.query2.get(cache_key) == True and channel_instance.query3.get(cache_key) == True:
|
137 |
+
channel_instance.query1[cache_key] = False #To improve waiting experience, this can be set to True.
|
138 |
+
channel_instance.query2[cache_key] = False #To improve waiting experience, this can be set to True.
|
139 |
+
channel_instance.query3[cache_key] = False
|
140 |
+
elif cache[0] >= 1:
|
141 |
+
# Skip the waiting phase
|
142 |
+
channel_instance.query1[cache_key] = True
|
143 |
+
channel_instance.query2[cache_key] = True
|
144 |
+
channel_instance.query3[cache_key] = True
|
145 |
+
|
146 |
+
|
147 |
+
cache = channel_instance.cache_dict.get(cache_key)
|
148 |
+
if channel_instance.query1.get(cache_key) == False:
|
149 |
+
# The first query from wechat official server
|
150 |
+
logger.debug("[wechatmp] query1 {}".format(cache_key))
|
151 |
+
channel_instance.query1[cache_key] = True
|
152 |
+
cnt = 0
|
153 |
+
while cache[0] == 0 and cnt < 45:
|
154 |
+
cnt = cnt + 1
|
155 |
+
time.sleep(0.1)
|
156 |
+
cache = channel_instance.cache_dict.get(cache_key)
|
157 |
+
if cnt == 45:
|
158 |
+
# waiting for timeout (the POST query will be closed by wechat official server)
|
159 |
+
time.sleep(5)
|
160 |
+
# and do nothing
|
161 |
+
return
|
162 |
+
else:
|
163 |
+
pass
|
164 |
+
elif channel_instance.query2.get(cache_key) == False:
|
165 |
+
# The second query from wechat official server
|
166 |
+
logger.debug("[wechatmp] query2 {}".format(cache_key))
|
167 |
+
channel_instance.query2[cache_key] = True
|
168 |
+
cnt = 0
|
169 |
+
while cache[0] == 0 and cnt < 45:
|
170 |
+
cnt = cnt + 1
|
171 |
+
time.sleep(0.1)
|
172 |
+
cache = channel_instance.cache_dict.get(cache_key)
|
173 |
+
if cnt == 45:
|
174 |
+
# waiting for timeout (the POST query will be closed by wechat official server)
|
175 |
+
time.sleep(5)
|
176 |
+
# and do nothing
|
177 |
+
return
|
178 |
+
else:
|
179 |
+
pass
|
180 |
+
elif channel_instance.query3.get(cache_key) == False:
|
181 |
+
# The third query from wechat official server
|
182 |
+
logger.debug("[wechatmp] query3 {}".format(cache_key))
|
183 |
+
channel_instance.query3[cache_key] = True
|
184 |
+
cnt = 0
|
185 |
+
while cache[0] == 0 and cnt < 40:
|
186 |
+
cnt = cnt + 1
|
187 |
+
time.sleep(0.1)
|
188 |
+
cache = channel_instance.cache_dict.get(cache_key)
|
189 |
+
if cnt == 40:
|
190 |
+
# Have waiting for 3x5 seconds
|
191 |
+
# return timeout message
|
192 |
+
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
193 |
+
logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
|
194 |
+
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
195 |
+
return replyPost
|
196 |
+
else:
|
197 |
+
pass
|
198 |
+
|
199 |
+
if float(time.time()) - float(query_time) > 4.8:
|
200 |
+
logger.info("[wechatmp] Timeout for {} {}".format(from_user, message_id))
|
201 |
+
return
|
202 |
+
|
203 |
+
|
204 |
+
if cache[0] > 1:
|
205 |
+
reply_text = cache[1][:600] + "\n【未完待续,回复任意文字以继续】" #wechatmp auto_reply length limit
|
206 |
+
channel_instance.cache_dict[cache_key] = (cache[0] - 1, cache[1][600:])
|
207 |
+
elif cache[0] == 1:
|
208 |
+
reply_text = cache[1]
|
209 |
+
channel_instance.cache_dict.pop(cache_key)
|
210 |
+
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
|
211 |
+
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
212 |
+
return replyPost
|
213 |
+
|
214 |
+
elif wechat_msg.msg_type == 'event':
|
215 |
+
logger.info("[wechatmp] Event {} from {}".format(wechat_msg.Event, wechat_msg.from_user_id))
|
216 |
+
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
|
217 |
+
content = textwrap.dedent(f"""\
|
218 |
+
感谢您的关注!
|
219 |
+
��里是ChatGPT,可以自由对话。
|
220 |
+
资源有限,回复较慢,请勿着急。
|
221 |
+
支持通用表情输入。
|
222 |
+
暂时不支持图片输入。
|
223 |
+
支持图片输出,画字开头的问题将回复图片链接。
|
224 |
+
支持角色扮演和文字冒险两种定制模式对话。
|
225 |
+
输入'{trigger_prefix}#帮助' 查看详细指令。""")
|
226 |
+
replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content)
|
227 |
+
return replyMsg.send()
|
228 |
+
else:
|
229 |
+
logger.info("暂且不处理")
|
230 |
+
return "success"
|
231 |
+
except Exception as exc:
|
232 |
+
logger.exception(exc)
|
233 |
+
return exc
|
234 |
+
|
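The verify_server() routine above implements the standard WeChat server-verification handshake: sort [token, timestamp, nonce] lexicographically, concatenate, take the SHA-1 hex digest, and compare it with the signature query parameter. A standalone sketch of that check (the function name and sample values are illustrative, not from the repo):

```python
import hashlib

def check_signature(token: str, timestamp: str, nonce: str, signature: str) -> bool:
    # WeChat sorts the three strings lexicographically, joins them,
    # and compares the SHA-1 hex digest against the signature parameter.
    data = sorted([token, timestamp, nonce])
    digest = hashlib.sha1("".join(data).encode("utf-8")).hexdigest()
    return digest == signature

# Precompute the expected digest for a known triple, then verify.
expected = hashlib.sha1("".join(sorted(["tok", "123", "abc"])).encode("utf-8")).hexdigest()
print(check_signature("tok", "123", "abc", expected))  # True
print(check_signature("tok", "123", "abc", "bogus"))   # False
```

When the check passes, the handler must echo back `echostr` verbatim, which is what the GET branch above does.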
common/const.py
ADDED
@@ -0,0 +1,5 @@
# bot_type
OPEN_AI = "openAI"
CHATGPT = "chatGPT"
BAIDU = "baidu"
CHATGPTONAZURE = "chatGPTOnAzure"
common/dequeue.py
ADDED
@@ -0,0 +1,33 @@
from queue import Full, Queue
from time import monotonic as time

# add an implementation of putleft to Queue
class Dequeue(Queue):
    def putleft(self, item, block=True, timeout=None):
        with self.not_full:
            if self.maxsize > 0:
                if not block:
                    if self._qsize() >= self.maxsize:
                        raise Full
                elif timeout is None:
                    while self._qsize() >= self.maxsize:
                        self.not_full.wait()
                elif timeout < 0:
                    raise ValueError("'timeout' must be a non-negative number")
                else:
                    endtime = time() + timeout
                    while self._qsize() >= self.maxsize:
                        remaining = endtime - time()
                        if remaining <= 0.0:
                            raise Full
                        self.not_full.wait(remaining)
            self._putleft(item)
            self.unfinished_tasks += 1
            self.not_empty.notify()

    def putleft_nowait(self, item):
        return self.putleft(item, block=False)

    def _putleft(self, item):
        self.queue.appendleft(item)
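The putleft() above pushes an item to the front of the underlying deque, so re-queued items are consumed before normally put() ones. A compact, self-contained demo of that idea (this re-states only the non-blocking front-insert, relying on the same Queue internals the class above uses):

```python
from queue import Queue

class Dequeue(Queue):
    """Compact re-statement of common/dequeue.py above, for a runnable demo."""
    def putleft_nowait(self, item):
        with self.not_full:
            self.queue.appendleft(item)  # front of the line
            self.unfinished_tasks += 1
            self.not_empty.notify()

dq = Dequeue()
dq.put("second")
dq.putleft_nowait("first")
print(dq.get())  # first
print(dq.get())  # second
```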
common/expired_dict.py
ADDED
@@ -0,0 +1,42 @@
from datetime import datetime, timedelta


class ExpiredDict(dict):
    def __init__(self, expires_in_seconds):
        super().__init__()
        self.expires_in_seconds = expires_in_seconds

    def __getitem__(self, key):
        value, expiry_time = super().__getitem__(key)
        if datetime.now() > expiry_time:
            del self[key]
            raise KeyError("expired {}".format(key))
        self.__setitem__(key, value)
        return value

    def __setitem__(self, key, value):
        expiry_time = datetime.now() + timedelta(seconds=self.expires_in_seconds)
        super().__setitem__(key, (value, expiry_time))

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        try:
            self[key]
            return True
        except KeyError:
            return False

    def keys(self):
        keys = list(super().keys())
        return [key for key in keys if key in self]

    def items(self):
        return [(key, self[key]) for key in self.keys()]

    def __iter__(self):
        return self.keys().__iter__()
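ExpiredDict stores each value with its expiry timestamp and evicts on read; every successful read also refreshes the expiry (a sliding window). A compact, self-contained demo of that eviction behavior (this re-states only the set/get pair from the class above; a negative TTL is used to force immediate expiry):

```python
from datetime import datetime, timedelta

class ExpiredDict(dict):
    """Compact re-statement of common/expired_dict.py above, for a runnable demo."""
    def __init__(self, expires_in_seconds):
        super().__init__()
        self.expires_in_seconds = expires_in_seconds

    def __setitem__(self, key, value):
        # store (value, expiry) pairs
        super().__setitem__(key, (value, datetime.now() + timedelta(seconds=self.expires_in_seconds)))

    def get(self, key, default=None):
        try:
            value, expiry = super().__getitem__(key)
        except KeyError:
            return default
        if datetime.now() > expiry:
            del self[key]  # evict on read
            return default
        return value

d = ExpiredDict(expires_in_seconds=3600)
d["session"] = "ctx"
print(d.get("session"))      # ctx

stale = ExpiredDict(expires_in_seconds=-1)  # entries are born expired
stale["session"] = "ctx"
print(stale.get("session"))  # None
```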
common/log.py
ADDED
@@ -0,0 +1,20 @@
import logging
import sys


def _get_logger():
    log = logging.getLogger('log')
    log.setLevel(logging.INFO)
    console_handle = logging.StreamHandler(sys.stdout)
    console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
                                                  datefmt='%Y-%m-%d %H:%M:%S'))
    file_handle = logging.FileHandler('run.log', encoding='utf-8')
    file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
                                               datefmt='%Y-%m-%d %H:%M:%S'))
    log.addHandler(file_handle)
    log.addHandler(console_handle)
    return log


# logger handle
logger = _get_logger()
common/singleton.py
ADDED
@@ -0,0 +1,9 @@
def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
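The decorator above caches one instance per decorated class, so repeated constructor calls return the same object, which is how WechatMPChannel stays a single shared instance across requests. A self-contained demo (the Channel class here is illustrative):

```python
def singleton(cls):
    """Same decorator as common/singleton.py above."""
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)  # constructed once
        return instances[cls]

    return get_instance

@singleton
class Channel:
    def __init__(self):
        self.cache = {}

a = Channel()
b = Channel()
print(a is b)  # True
```

Note that after decoration, `Channel` is actually the `get_instance` closure, so constructor arguments only take effect on the first call.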
common/sorted_dict.py
ADDED
@@ -0,0 +1,65 @@
import heapq


class SortedDict(dict):
    def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False):
        if init_dict is None:
            init_dict = []
        if isinstance(init_dict, dict):
            init_dict = init_dict.items()
        self.sort_func = sort_func
        self.sorted_keys = None
        self.reverse = reverse
        self.heap = []
        for k, v in init_dict:
            self[k] = v

    def __setitem__(self, key, value):
        if key in self:
            super().__setitem__(key, value)
            for i, (priority, k) in enumerate(self.heap):
                if k == key:
                    self.heap[i] = (self.sort_func(key, value), key)
                    heapq.heapify(self.heap)
                    break
            self.sorted_keys = None
        else:
            super().__setitem__(key, value)
            heapq.heappush(self.heap, (self.sort_func(key, value), key))
            self.sorted_keys = None

    def __delitem__(self, key):
        super().__delitem__(key)
        for i, (priority, k) in enumerate(self.heap):
            if k == key:
                del self.heap[i]
                heapq.heapify(self.heap)
                break
        self.sorted_keys = None

    def keys(self):
        if self.sorted_keys is None:
            self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
        return self.sorted_keys

    def items(self):
        if self.sorted_keys is None:
            self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
        sorted_items = [(k, self[k]) for k in self.sorted_keys]
        return sorted_items

    def _update_heap(self, key):
        for i, (priority, k) in enumerate(self.heap):
            if k == key:
                new_priority = self.sort_func(key, self[key])
                if new_priority != priority:
                    self.heap[i] = (new_priority, key)
                    heapq.heapify(self.heap)
                    self.sorted_keys = None
                break

    def __iter__(self):
        return iter(self.keys())

    def __repr__(self):
        return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})'
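SortedDict above maintains a heap of (priority, key) pairs next to the dict and lazily rebuilds sorted_keys on access. A minimal sketch of that core idea in isolation (the sample data is illustrative):

```python
import heapq

# Keep a heap of (priority, key) pairs; here sort_func(k, v) = v,
# i.e. keys are ordered by their values.
heap = []
data = {"b": 2, "a": 3, "c": 1}
for k, v in data.items():
    heapq.heappush(heap, (v, k))

# Lazy materialization, as in SortedDict.keys() above.
sorted_keys = [k for _, k in sorted(heap)]
print(sorted_keys)  # ['c', 'b', 'a']
```

The heap makes inserts O(log n), while a full sort is only paid when a sorted view is actually requested.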
common/time_check.py
ADDED
@@ -0,0 +1,38 @@
import time, re, hashlib
import config
from common.log import logger

def time_checker(f):
    def _time_checker(self, *args, **kwargs):
        _config = config.conf()
        chat_time_module = _config.get("chat_time_module", False)
        if chat_time_module:
            chat_start_time = _config.get("chat_start_time", "00:00")
            chat_stopt_time = _config.get("chat_stop_time", "24:00")
            time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$')  # matches HH:MM, including 24:00

            starttime_format_check = time_regex.match(chat_start_time)  # check the start time format
            stoptime_format_check = time_regex.match(chat_stopt_time)  # check the stop time format
            chat_time_check = chat_start_time < chat_stopt_time  # ensure start time < stop time

            # time format check
            if not (starttime_format_check and stoptime_format_check and chat_time_check):
                logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check, stoptime_format_check))
                if chat_start_time > "23:59":
                    logger.error('启动时间可能存在问题,请修改!')

            # service-hours check
            now_time = time.strftime("%H:%M", time.localtime())
            if chat_start_time <= now_time <= chat_stopt_time:  # within service hours, answer normally
                f(self, *args, **kwargs)
                return None
            else:
                if args[0]['Content'] == "#更新配置":  # config updates are allowed even outside service hours
                    f(self, *args, **kwargs)
                else:
                    logger.info('非服务时间内,不接受访问')
                    return None
        else:
            f(self, *args, **kwargs)  # time module disabled: answer directly
    return _time_checker
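The decorator above validates "HH:MM" strings (up to "24:00") with a regex and then compares times as plain strings, which is chronologically correct as long as both sides are zero-padded. A quick check of that behavior:

```python
import re

# Same pattern as common/time_check.py above.
time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$')

print(bool(time_regex.match("08:30")))  # True
print(bool(time_regex.match("24:00")))  # True
print(bool(time_regex.match("25:00")))  # False

# Zero-padded "HH:MM" strings sort chronologically, so string
# comparison works as a time comparison.
print("08:30" <= "09:15" <= "22:00")    # True
```

One caveat: the `[01]?` makes the leading zero optional, and an unpadded value like "8:30" would pass the regex but break the string-ordering assumption.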
common/tmp_dir.py
ADDED
@@ -0,0 +1,20 @@
import os
import pathlib
from config import conf


class TmpDir(object):
    """A temporary directory for intermediate files such as downloaded voice clips.
    """

    tmpFilePath = pathlib.Path('./tmp/')

    def __init__(self):
        pathExists = os.path.exists(self.tmpFilePath)
        if not pathExists and conf().get('speech_recognition') == True:
            os.makedirs(self.tmpFilePath)

    def path(self):
        return str(self.tmpFilePath) + '/'
common/token_bucket.py
ADDED
@@ -0,0 +1,45 @@
import threading
import time


class TokenBucket:
    def __init__(self, tpm, timeout=None):
        self.capacity = int(tpm)  # bucket capacity
        self.tokens = 0  # start with zero tokens
        self.rate = int(tpm) / 60  # tokens generated per second
        self.timeout = timeout  # timeout while waiting for a token
        self.cond = threading.Condition()  # condition variable
        self.is_running = True
        # start the token-generator thread
        threading.Thread(target=self._generate_tokens).start()

    def _generate_tokens(self):
        """Generate tokens."""
        while self.is_running:
            with self.cond:
                if self.tokens < self.capacity:
                    self.tokens += 1
                self.cond.notify()  # wake a thread waiting for a token
            time.sleep(1 / self.rate)

    def get_token(self):
        """Acquire a token."""
        with self.cond:
            while self.tokens <= 0:
                flag = self.cond.wait(self.timeout)
                if not flag:  # timed out
                    return False
            self.tokens -= 1
        return True

    def close(self):
        self.is_running = False


if __name__ == "__main__":
    token_bucket = TokenBucket(20, None)  # a bucket that produces 20 tokens per minute
    # token_bucket = TokenBucket(20, 0.1)
    for i in range(3):
        if token_bucket.get_token():
            print(f"第{i+1}次请求成功")
    token_bucket.close()
config-template.json
ADDED
@@ -0,0 +1,18 @@
{
  "open_ai_api_key": "YOUR API KEY",
  "model": "gpt-3.5-turbo",
  "proxy": "",
  "use_azure_chatgpt": false,
  "single_chat_prefix": ["bot", "@bot"],
  "single_chat_reply_prefix": "[bot] ",
  "group_chat_prefix": ["@bot"],
  "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],
  "group_chat_in_one_session": ["ChatGPT测试群"],
  "image_create_prefix": ["画", "看", "找"],
  "speech_recognition": false,
  "group_speech_recognition": false,
  "voice_reply_voice": false,
  "conversation_max_tokens": 1000,
  "expires_in_seconds": 3600,
  "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
}
config.json
ADDED
@@ -0,0 +1,18 @@
{
  "open_ai_api_key": "YOUR API KEY",
  "model": "gpt-3.5-turbo",
  "proxy": "",
  "use_azure_chatgpt": false,
  "single_chat_prefix": [""],
  "single_chat_reply_prefix": "[bot] ",
  "group_chat_prefix": ["@bot"],
  "group_name_white_list": [""],
  "group_chat_in_one_session": [""],
  "image_create_prefix": ["画", "看", "找"],
  "speech_recognition": true,
  "group_speech_recognition": false,
  "voice_reply_voice": false,
  "conversation_max_tokens": 1000,
  "expires_in_seconds": 3600,
  "character_desc": "你是接入微信的ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
}
config.py
ADDED
@@ -0,0 +1,198 @@
# encoding:utf-8

import json
import logging
import os
from common.log import logger
import pickle

# All available settings are listed in this dict; use lowercase keys
available_setting = {
    # OpenAI API config
    "open_ai_api_key": "",  # OpenAI API key
    # OpenAI API base; when use_azure_chatgpt is true, set the corresponding API base
    "open_ai_api_base": "https://api.openai.com/v1",
    "proxy": "",  # proxy used for OpenAI requests
    # ChatGPT model; when use_azure_chatgpt is true, this is the model deployment name on Azure
    "model": "gpt-3.5-turbo",
    "use_azure_chatgpt": False,  # whether to use Azure ChatGPT

    # bot trigger config
    "single_chat_prefix": ["bot", "@bot"],  # in private chats, text must contain one of these prefixes to trigger a reply
    "single_chat_reply_prefix": "[bot] ",  # prefix prepended to auto-replies in private chats, to distinguish the bot from a human
    "group_chat_prefix": ["@bot"],  # in group chats, text containing one of these prefixes triggers a reply
    "group_chat_reply_prefix": "",  # prefix prepended to auto-replies in group chats
    "group_chat_keyword": [],  # in group chats, text containing one of these keywords triggers a reply
    "group_at_off": False,  # whether to disable triggering via @bot in group chats
    "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],  # group names with auto-reply enabled
    "group_name_keyword_white_list": [],  # group-name keywords with auto-reply enabled
    "group_chat_in_one_session": ["ChatGPT测试群"],  # group names that share one session context
    "trigger_by_self": False,  # whether the bot's own messages may trigger it
    "image_create_prefix": ["画", "看", "找"],  # prefixes that trigger image replies
    "concurrency_in_session": 1,  # max in-flight messages per session; values > 1 may reorder replies

    # ChatGPT session parameters
    "expires_in_seconds": 3600,  # expiry time for idle sessions
    "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",  # persona description
    "conversation_max_tokens": 1000,  # max number of characters kept as conversation context

    # ChatGPT rate limiting
    "rate_limit_chatgpt": 20,  # ChatGPT call rate limit
    "rate_limit_dalle": 50,  # OpenAI DALL-E call rate limit

    # ChatGPT API parameters, see https://platform.openai.com/docs/api-reference/chat/create
    "temperature": 0.9,
    "top_p": 1,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "request_timeout": 60,  # ChatGPT request timeout; the OpenAI API defaults to 600, and hard questions usually need longer
    "timeout": 120,  # ChatGPT retry window; requests are retried automatically within this time

    # voice settings
    "speech_recognition": False,  # enable speech recognition
    "group_speech_recognition": False,  # enable speech recognition in group chats
    "voice_reply_voice": False,  # reply to voice with voice; requires an API key for the chosen speech-synthesis engine
    "always_reply_voice": False,  # always reply with voice
    "voice_to_text": "openai",  # speech-recognition engine: openai, google, azure
    "text_to_voice": "baidu",  # speech-synthesis engine: baidu, google, pytts (offline), azure

    # Baidu voice API config, required when using Baidu speech recognition/synthesis
    "baidu_app_id": "",
    "baidu_api_key": "",
    "baidu_secret_key": "",
    # 1536 Mandarin (with basic English), 1737 English, 1637 Cantonese, 1837 Sichuanese, 1936 far-field Mandarin
    "baidu_dev_pid": "1536",

    # Azure voice API config, required when using Azure speech recognition/synthesis
    "azure_voice_api_key": "",
    "azure_voice_region": "japaneast",

    # service-hours limit, currently supported on itchat
    "chat_time_module": False,  # enable the service-hours limit
    "chat_start_time": "00:00",  # service start time
    "chat_stop_time": "24:00",  # service stop time

    # itchat config
    "hot_reload": False,  # enable hot reload

    # wechaty config
    "wechaty_puppet_service_token": "",  # wechaty token

    # wechatmp config
    "wechatmp_token": "",  # WeChat Official Account platform token
    "wechatmp_port": 8080,  # WeChat Official Account port; must be forwarded to 80 or 443

    # custom ChatGPT command trigger words
    "clear_memory_commands": ['#清除记忆'],  # session-reset commands; must start with #

    # channel config
    "channel_type": "wx",  # channel type, one of {wx, wxy, terminal, wechatmp}

    "debug": False,  # enable debug mode for more verbose logging

    # plugin config
    "plugin_trigger_prefix": "$",  # prefix for plugin chat commands; avoid clashing with the admin command prefix "#"
}


class Config(dict):
    def __init__(self, d: dict = {}):
        super().__init__(d)
        # user_datas: per-user data, keyed by user name; each value is itself a dict
        self.user_datas = {}

    def __getitem__(self, key):
        if key not in available_setting:
            raise Exception("key {} not in available_setting".format(key))
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        if key not in available_setting:
            raise Exception("key {} not in available_setting".format(key))
        return super().__setitem__(key, value)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError as e:
            return default
        except Exception as e:
            raise e

    # Make sure to return a dictionary to ensure atomicity
    def get_user_data(self, user) -> dict:
        if self.user_datas.get(user) is None:
            self.user_datas[user] = {}
        return self.user_datas[user]

    def load_user_datas(self):
        try:
            with open('user_datas.pkl', 'rb') as f:
                self.user_datas = pickle.load(f)
                logger.info("[Config] User datas loaded.")
        except FileNotFoundError as e:
            logger.info("[Config] User datas file not found, ignore.")
        except Exception as e:
            logger.info("[Config] User datas error: {}".format(e))
            self.user_datas = {}

    def save_user_datas(self):
        try:
            with open('user_datas.pkl', 'wb') as f:
                pickle.dump(self.user_datas, f)
                logger.info("[Config] User datas saved.")
        except Exception as e:
            logger.info("[Config] User datas error: {}".format(e))

config = Config()


def load_config():
    global config
    config_path = "./config.json"
    if not os.path.exists(config_path):
        logger.info('配置文件不存在,将使用config-template.json模板')
        config_path = "./config-template.json"

    config_str = read_file(config_path)
    logger.debug("[INIT] config str: {}".format(config_str))

    # deserialize the JSON string into a dict
    config = Config(json.loads(config_str))

    # Override config with environment variables.
    # Some online deployment platforms (e.g. Railway) deploy the project from GitHub directly,
    # so you shouldn't put secrets like the API key in a config file; use environment
    # variables to override the default config instead.
    for name, value in os.environ.items():
        name = name.lower()
        if name in available_setting:
            logger.info(
                "[INIT] override config by environ args: {}={}".format(name, value))
            try:
                config[name] = eval(value)
            except:
                if value == "false":
                    config[name] = False
                elif value == "true":
                    config[name] = True
                else:
                    config[name] = value

    if config.get("debug", False):
        logger.setLevel(logging.DEBUG)
        logger.debug("[INIT] set log level to DEBUG")

    logger.info("[INIT] load config: {}".format(config))

    config.load_user_datas()

def get_root():
    return os.path.dirname(os.path.abspath(__file__))


def read_file(path):
    with open(path, mode='r', encoding='utf-8') as f:
        return f.read()


def conf():
    return config
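The environment-override loop above first tries eval() on the variable's value and falls back to "true"/"false" handling, so "8080" becomes an int while plain words stay strings. The same coercion can be sketched as a standalone helper (coerce is illustrative, not part of the repo):

```python
def coerce(value: str):
    # Mirror config.py's behavior: booleans handled explicitly,
    # eval() for numbers/lists, raw string as the fallback.
    if value == "true":
        return True
    if value == "false":
        return False
    try:
        return eval(value)
    except Exception:
        return value

print(coerce("false"))  # False
print(coerce("8080"))   # 8080
print(coerce("hello"))  # hello
```

Note that eval() on arbitrary environment values is convenient but unsafe if the environment is not fully trusted; an explicit per-key parser would be the stricter alternative.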
docker/Dockerfile.alpine
ADDED
@@ -0,0 +1,39 @@
FROM python:3.10-alpine

LABEL maintainer="[email protected]"
ARG TZ='Asia/Shanghai'

ARG CHATGPT_ON_WECHAT_VER

ENV BUILD_PREFIX=/app

RUN apk add --no-cache \
        bash \
        curl \
        wget \
    && export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
        grep '"tag_name":' | \
        sed -E 's/.*"([^"]+)".*/\1/'`} \
    && wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
        https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \
    && tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
    && mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \
    && rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
    && cd ${BUILD_PREFIX} \
    && cp config-template.json ${BUILD_PREFIX}/config.json \
    && /usr/local/bin/python -m pip install --no-cache --upgrade pip \
    && pip install --no-cache -r requirements.txt \
    && pip install --no-cache -r requirements-optional.txt \
    && apk del curl wget

WORKDIR ${BUILD_PREFIX}

ADD ./entrypoint.sh /entrypoint.sh

RUN chmod +x /entrypoint.sh \
    && adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \
    && chown -R noroot:noroot ${BUILD_PREFIX}

USER noroot

ENTRYPOINT ["/entrypoint.sh"]
docker/Dockerfile.debian
ADDED
@@ -0,0 +1,40 @@
FROM python:3.10

LABEL maintainer="[email protected]"
ARG TZ='Asia/Shanghai'

ARG CHATGPT_ON_WECHAT_VER

ENV BUILD_PREFIX=/app

RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        wget \
        curl \
    && rm -rf /var/lib/apt/lists/* \
    && export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
        grep '"tag_name":' | \
        sed -E 's/.*"([^"]+)".*/\1/'`} \
    && wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
        https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \
    && tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
    && mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \
    && rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
    && cd ${BUILD_PREFIX} \
    && cp config-template.json ${BUILD_PREFIX}/config.json \
    && /usr/local/bin/python -m pip install --no-cache --upgrade pip \
    && pip install --no-cache -r requirements.txt \
    && pip install --no-cache -r requirements-optional.txt

WORKDIR ${BUILD_PREFIX}

ADD ./entrypoint.sh /entrypoint.sh

RUN chmod +x /entrypoint.sh \
    && groupadd -r noroot \
    && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
    && chown -R noroot:noroot ${BUILD_PREFIX}

USER noroot

ENTRYPOINT ["/entrypoint.sh"]
docker/Dockerfile.debian.latest
ADDED
@@ -0,0 +1,33 @@
FROM python:3.10-slim

LABEL maintainer="[email protected]"
ARG TZ='Asia/Shanghai'

ARG CHATGPT_ON_WECHAT_VER

ENV BUILD_PREFIX=/app

ADD . ${BUILD_PREFIX}

RUN apt-get update \
    && apt-get install -y --no-install-recommends bash \
        ffmpeg espeak \
    && cd ${BUILD_PREFIX} \
    && cp config-template.json config.json \
    && /usr/local/bin/python -m pip install --no-cache --upgrade pip \
    && pip install --no-cache -r requirements.txt \
    && pip install --no-cache -r requirements-optional.txt \
    && pip install azure-cognitiveservices-speech

WORKDIR ${BUILD_PREFIX}

ADD docker/entrypoint.sh /entrypoint.sh

RUN chmod +x /entrypoint.sh \
    && groupadd -r noroot \
    && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
    && chown -R noroot:noroot ${BUILD_PREFIX}

USER noroot

ENTRYPOINT ["/entrypoint.sh"]
docker/Dockerfile.latest
ADDED
@@ -0,0 +1,29 @@
FROM python:3.10-alpine

LABEL maintainer="[email protected]"
ARG TZ='Asia/Shanghai'

ARG CHATGPT_ON_WECHAT_VER

ENV BUILD_PREFIX=/app

ADD . ${BUILD_PREFIX}

RUN apk add --no-cache bash ffmpeg espeak \
    && cd ${BUILD_PREFIX} \
    && cp config-template.json config.json \
    && /usr/local/bin/python -m pip install --no-cache --upgrade pip \
    && pip install --no-cache -r requirements.txt \
    && pip install --no-cache -r requirements-optional.txt

WORKDIR ${BUILD_PREFIX}

ADD docker/entrypoint.sh /entrypoint.sh

RUN chmod +x /entrypoint.sh \
    && adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \
    && chown -R noroot:noroot ${BUILD_PREFIX}

USER noroot

ENTRYPOINT ["/entrypoint.sh"]
|
docker/build.alpine.sh
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/bin/bash

# fetch latest release tag
CHATGPT_ON_WECHAT_TAG=`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
    grep '"tag_name":' | \
    sed -E 's/.*"([^"]+)".*/\1/'`

# build image
docker build -f Dockerfile.alpine \
    --build-arg CHATGPT_ON_WECHAT_VER=$CHATGPT_ON_WECHAT_TAG \
    -t zhayujie/chatgpt-on-wechat .

# tag image
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine
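The `grep`/`sed` pipeline that both the Dockerfiles and `build.alpine.sh` use to extract the latest release tag can be exercised offline. Below is a minimal sketch that runs the same pipeline against a hard-coded sample of the GitHub releases API response (the JSON payload here is illustrative, not a real release):

```shell
#!/bin/sh
# Illustrative sample of the GitHub /releases/latest response body.
# Each field is on its own line, matching the API's pretty-printed output.
RESPONSE='{
  "tag_name": "1.2.2",
  "name": "chatgpt-on-wechat 1.2.2"
}'

# Same pipeline as build.alpine.sh: isolate the tag_name line, then
# capture the last double-quoted value on that line.
TAG=$(printf '%s\n' "$RESPONSE" | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/')
echo "$TAG"
# prints: 1.2.2
```

Note the pipeline relies on `tag_name` occupying its own line; on a single-line JSON body the greedy `.*` in the `sed` expression would capture the last quoted string in the document instead.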