|
package main
|
|
|
|
import (
	"bufio"
	"bytes"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"
)
|
|
|
|
type OpenAIRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
}
|
|
|
|
// Message is a single chat message: the speaker's role and the text content.
type Message struct {
	// Role is serialized under the "role" key.
	Role string `json:"role"`
	// Content is serialized under the "content" key.
	Content string `json:"content"`
}
|
|
|
|
// DeepSeekResponse describes a JSON envelope with a numeric code, two message
// fields and an API source label.
//
// NOTE(review): nothing in the visible code references this type; confirm
// whether it is still needed.
type DeepSeekResponse struct {
	// Code is serialized under the "code" key.
	Code int `json:"code"`
	// Msg is serialized under the "msg" key.
	Msg string `json:"msg"`
	// Message is serialized under the "message" key.
	Message string `json:"message"`
	// APISource is serialized under the "api_source" key.
	APISource string `json:"api_source"`
}
|
|
|
|
type OpenAIResponse struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
Choices []Choice `json:"choices"`
|
|
}
|
|
|
|
type Choice struct {
|
|
Index int `json:"index"`
|
|
Message Message `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
}
|
|
|
|
type StreamChoice struct {
|
|
Delta StreamMessage `json:"delta"`
|
|
Index int `json:"index"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}
|
|
|
|
// StreamMessage is the delta payload of a streamed chunk. Both fields are
// omitted from the JSON encoding when empty.
type StreamMessage struct {
	// Role is serialized under "role" when non-empty.
	Role string `json:"role,omitempty"`
	// Content is serialized under "content" when non-empty.
	Content string `json:"content,omitempty"`
}
|
|
|
|
type StreamResponse struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
Choices []StreamChoice `json:"choices"`
|
|
}
|
|
|
|
// Request statistics shared by the HTTP handlers.
var (
	// logMutex is held by the handlers while they read or update the
	// variables below.
	logMutex sync.Mutex

	// requestCount is the number of chat requests received so far.
	requestCount int64
	// rpm counts the chat requests seen in the current one-minute window.
	rpm int
	// lastMinute is the start of the current one-minute window.
	lastMinute time.Time
	// requestLog holds one formatted line per recent chat request.
	requestLog []string
)
|
|
|
|
func init() {
|
|
lastMinute = time.Now()
|
|
requestLog = make([]string, 0, 10000)
|
|
}
|
|
|
|
func main() {
|
|
http.HandleFunc("/", handleStats)
|
|
http.HandleFunc("/log", handleLogs)
|
|
http.HandleFunc("/hf/v1/chat/completions", handleChat)
|
|
log.Fatal(http.ListenAndServe(":7860", nil))
|
|
}
|
|
|
|
func handleStats(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
|
|
logMutex.Lock()
|
|
currentRPM := rpm
|
|
totalRequests := requestCount
|
|
logMutex.Unlock()
|
|
|
|
fmt.Fprintf(w, "总请求次数: %d\n每分钟请求数(RPM): %d", totalRequests, currentRPM)
|
|
}
|
|
|
|
func handleLogs(w http.ResponseWriter, r *http.Request) {
|
|
auth := r.URL.Query().Get("auth")
|
|
if auth != "smnet" {
|
|
http.Error(w, "未授权访问", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
logMutex.Lock()
|
|
logs := make([]string, len(requestLog))
|
|
copy(logs, requestLog)
|
|
logMutex.Unlock()
|
|
|
|
for _, log := range logs {
|
|
fmt.Fprintln(w, log)
|
|
}
|
|
}
|
|
|
|
func handleChat(w http.ResponseWriter, r *http.Request) {
|
|
logMutex.Lock()
|
|
requestCount++
|
|
now := time.Now()
|
|
if now.Sub(lastMinute) >= time.Minute {
|
|
rpm = 1
|
|
lastMinute = now
|
|
} else {
|
|
rpm++
|
|
}
|
|
|
|
clientIP := r.Header.Get("X-Real-IP")
|
|
if clientIP == "" {
|
|
clientIP = r.Header.Get("X-Forwarded-For")
|
|
if clientIP == "" {
|
|
clientIP = r.RemoteAddr
|
|
}
|
|
}
|
|
|
|
logEntry := fmt.Sprintf("[%s] IP:%s 新请求处理", now.Format("2006-01-02 15:04:05"), clientIP)
|
|
if len(requestLog) >= 5000 {
|
|
requestLog = requestLog[1:]
|
|
}
|
|
requestLog = append(requestLog, logEntry)
|
|
logMutex.Unlock()
|
|
|
|
if r.Method != http.MethodPost {
|
|
log.Printf("错误: 不支持的请求方法 %s", r.Method)
|
|
http.Error(w, "仅支持 POST 请求", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
body, err := ioutil.ReadAll(r.Body)
|
|
if err != nil {
|
|
log.Printf("错误: 读取请求失败 - %v", err)
|
|
http.Error(w, "读取请求失败", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var openAIReq OpenAIRequest
|
|
if err := json.Unmarshal(body, &openAIReq); err != nil {
|
|
log.Printf("错误: 请求格式错误 - %v", err)
|
|
http.Error(w, "请求格式错误", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
log.Printf("用户问题: %s", openAIReq.Messages[len(openAIReq.Messages)-1].Content)
|
|
|
|
var apiURL string
|
|
var modelName string
|
|
switch openAIReq.Model {
|
|
case "deepseek-r1":
|
|
apiURL = "https://api.deepinfra.com/v1/openai/chat/completions"
|
|
modelName = "deepseek-ai/DeepSeek-R1"
|
|
default:
|
|
apiURL = "https://api.deepinfra.com/v1/openai/chat/completions"
|
|
modelName = "deepseek-ai/DeepSeek-V3"
|
|
}
|
|
|
|
deepseekReq := map[string]interface{}{
|
|
"messages": openAIReq.Messages,
|
|
"stream": true,
|
|
"model": modelName,
|
|
}
|
|
|
|
deepseekBody, err := json.Marshal(deepseekReq)
|
|
if err != nil {
|
|
log.Printf("错误: 构造请求失败 - %v", err)
|
|
http.Error(w, "构造请求失败", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
maxRetries := 200
|
|
var tryRequest func() (string, error)
|
|
|
|
tryRequest = func() (string, error) {
|
|
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(deepseekBody))
|
|
if err != nil {
|
|
return "", fmt.Errorf("创建请求失败: %v", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("请求失败: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var fullMessage string
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if openAIReq.Stream {
|
|
_, err = fmt.Fprintf(w, "%s\n", line)
|
|
if err != nil {
|
|
return "", fmt.Errorf("写入流式响应失败: %v", err)
|
|
}
|
|
w.(http.Flusher).Flush()
|
|
}
|
|
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var streamResp StreamResponse
|
|
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
|
|
continue
|
|
}
|
|
|
|
if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta.Content != "" {
|
|
fullMessage += streamResp.Choices[0].Delta.Content
|
|
}
|
|
}
|
|
|
|
if fullMessage == "" {
|
|
return "", fmt.Errorf("收到空回复")
|
|
}
|
|
|
|
return fullMessage, nil
|
|
}
|
|
|
|
var fullMessage string
|
|
var lastError error
|
|
|
|
for i := 0; i < maxRetries; i++ {
|
|
fullMessage, lastError = tryRequest()
|
|
if lastError == nil {
|
|
break
|
|
}
|
|
log.Printf("第 %d 次尝试失败: %v,准备重试", i+1, lastError)
|
|
time.Sleep(time.Second * time.Duration(i+1))
|
|
}
|
|
|
|
if lastError != nil {
|
|
log.Printf("错误: 所有重试都失败 - %v", lastError)
|
|
http.Error(w, "服务暂时不可用", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
log.Printf("AI回答: %s", fullMessage)
|
|
|
|
if !openAIReq.Stream {
|
|
openAIResp := OpenAIResponse{
|
|
ID: "chatcmpl-" + generateRandomString(10),
|
|
Object: "chat.completion",
|
|
Created: getCurrentTimestamp(),
|
|
Model: openAIReq.Model,
|
|
Choices: []Choice{
|
|
{
|
|
Index: 0,
|
|
Message: Message{
|
|
Role: "assistant",
|
|
Content: fullMessage,
|
|
},
|
|
FinishReason: "stop",
|
|
},
|
|
},
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(openAIResp)
|
|
}
|
|
}
|
|
|
|
// generateRandomString returns a random string of exactly length characters
// drawn from the ASCII letters and digits. A length of zero or less yields the
// empty string.
//
// The result is meant for identifiers, not secrets: mapping a random byte onto
// 62 symbols with a modulo is very slightly biased.
func generateRandomString(length int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

	if length <= 0 {
		return ""
	}

	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		// The system source failed; fall back to a clock-seeded generator so
		// the caller still gets a usable identifier.
		state := uint64(time.Now().UnixNano())
		for i := range buf {
			state = state*6364136223846793005 + 1442695040888963407
			buf[i] = byte(state >> 56)
		}
	}

	for i, b := range buf {
		buf[i] = alphabet[int(b)%len(alphabet)]
	}
	return string(buf)
}
|
|
|
|
// getCurrentTimestamp returns the current time as Unix seconds.
func getCurrentTimestamp() int64 {
	now := time.Now()
	return now.Unix()
}
|
|
|
|
// writeSSE encodes data as JSON and writes it to w as one server-sent event
// ("data: <json>" followed by a blank line), then flushes the event when w
// supports flushing.
//
// If data cannot be encoded, a 500 response is written and the encoding error
// is returned. If the write itself fails, the error is returned without
// touching w again: the connection is already broken, so a further error
// response could not be delivered.
func writeSSE(w http.ResponseWriter, data interface{}) error {
	jsonData, err := json.Marshal(data)
	if err != nil {
		http.Error(w, "JSON编码失败", http.StatusInternalServerError)
		return err
	}

	if _, err := fmt.Fprintf(w, "data: %s\n\n", jsonData); err != nil {
		return err
	}

	// Push the event out immediately instead of leaving it in the buffer.
	if flusher, ok := w.(http.Flusher); ok {
		flusher.Flush()
	}
	return nil
}
|
|
|