Spaces:
Running
Running
package main | |
import ( | |
"bufio" | |
"bytes" | |
"encoding/json" | |
"fmt" | |
"io/ioutil" | |
"log" | |
"net/http" | |
"strings" | |
"sync" | |
"time" | |
) | |
// Wire types for the OpenAI-compatible proxy. Field order and JSON tags
// determine the encoded output, so keep them stable.
type (
	// OpenAIRequest is the body a client POSTs to the chat-completions endpoint.
	OpenAIRequest struct {
		Model    string    `json:"model"`
		Messages []Message `json:"messages"`
		Stream   bool      `json:"stream"`
	}

	// Message is one chat turn; it appears both in incoming requests and in
	// non-streaming responses.
	Message struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// DeepSeekResponse is a code/msg style response envelope.
	// NOTE(review): not referenced by the visible code — confirm it is still needed.
	DeepSeekResponse struct {
		Code      int    `json:"code"`
		Msg       string `json:"msg"`
		Message   string `json:"message"`
		APISource string `json:"api_source"`
	}

	// OpenAIResponse is a complete (non-streaming) chat.completion object.
	OpenAIResponse struct {
		ID      string   `json:"id"`
		Object  string   `json:"object"`
		Created int64    `json:"created"`
		Model   string   `json:"model"`
		Choices []Choice `json:"choices"`
	}

	// Choice is one completion alternative inside an OpenAIResponse.
	Choice struct {
		Index        int     `json:"index"`
		Message      Message `json:"message"`
		FinishReason string  `json:"finish_reason"`
	}

	// StreamChoice is one alternative inside a streamed chunk. Being a
	// pointer, FinishReason decodes JSON null to nil, which keeps it
	// distinguishable from an empty string.
	StreamChoice struct {
		Delta        StreamMessage `json:"delta"`
		Index        int           `json:"index"`
		FinishReason *string       `json:"finish_reason"`
	}

	// StreamMessage is the incremental delta of a streamed chunk; empty
	// fields are omitted when encoded.
	StreamMessage struct {
		Role    string `json:"role,omitempty"`
		Content string `json:"content,omitempty"`
	}

	// StreamResponse is one "data: {...}" chunk of a streamed completion; the
	// chat handler decodes upstream chunks into it to collect delta content.
	StreamResponse struct {
		ID      string         `json:"id"`
		Object  string         `json:"object"`
		Created int64          `json:"created"`
		Model   string         `json:"model"`
		Choices []StreamChoice `json:"choices"`
	}
)
// Request statistics and the in-memory access log shared by the HTTP
// handlers. The handlers read and write these only while holding logMutex.
var (
	// requestCount is the running total of chat requests received.
	requestCount int64
	// requestLog holds recent access-log lines, oldest first; room for
	// 10000 entries is reserved up front.
	requestLog = make([]string, 0, 10000)
	// lastMinute marks when the current one-minute RPM window started.
	lastMinute = time.Now()
	// rpm counts chat requests seen in the current window.
	rpm int
	// logMutex guards every variable above.
	logMutex sync.Mutex
)
func main() { | |
http.HandleFunc("/", handleStats) | |
http.HandleFunc("/log", handleLogs) | |
http.HandleFunc("/hf/v1/chat/completions", handleChat) | |
log.Fatal(http.ListenAndServe(":7860", nil)) | |
} | |
func handleStats(w http.ResponseWriter, r *http.Request) { | |
if r.URL.Path != "/" { | |
http.NotFound(w, r) | |
return | |
} | |
logMutex.Lock() | |
currentRPM := rpm | |
totalRequests := requestCount | |
logMutex.Unlock() | |
fmt.Fprintf(w, "总请求次数: %d\n每分钟请求数(RPM): %d", totalRequests, currentRPM) | |
} | |
func handleLogs(w http.ResponseWriter, r *http.Request) { | |
auth := r.URL.Query().Get("auth") | |
if auth != "smnet" { | |
http.Error(w, "未授权访问", http.StatusUnauthorized) | |
return | |
} | |
logMutex.Lock() | |
logs := make([]string, len(requestLog)) | |
copy(logs, requestLog) | |
logMutex.Unlock() | |
for _, log := range logs { | |
fmt.Fprintln(w, log) | |
} | |
} | |
func handleChat(w http.ResponseWriter, r *http.Request) { | |
logMutex.Lock() | |
requestCount++ | |
now := time.Now() | |
if now.Sub(lastMinute) >= time.Minute { | |
rpm = 1 | |
lastMinute = now | |
} else { | |
rpm++ | |
} | |
clientIP := r.Header.Get("X-Real-IP") | |
if clientIP == "" { | |
clientIP = r.Header.Get("X-Forwarded-For") | |
if clientIP == "" { | |
clientIP = r.RemoteAddr | |
} | |
} | |
logEntry := fmt.Sprintf("[%s] IP:%s 新请求处理", now.Format("2006-01-02 15:04:05"), clientIP) | |
if len(requestLog) >= 5000 { | |
requestLog = requestLog[1:] | |
} | |
requestLog = append(requestLog, logEntry) | |
logMutex.Unlock() | |
if r.Method != http.MethodPost { | |
log.Printf("错误: 不支持的请求方法 %s", r.Method) | |
http.Error(w, "仅支持 POST 请求", http.StatusMethodNotAllowed) | |
return | |
} | |
body, err := ioutil.ReadAll(r.Body) | |
if err != nil { | |
log.Printf("错误: 读取请求失败 - %v", err) | |
http.Error(w, "读取请求失败", http.StatusBadRequest) | |
return | |
} | |
var openAIReq OpenAIRequest | |
if err := json.Unmarshal(body, &openAIReq); err != nil { | |
log.Printf("错误: 请求格式错误 - %v", err) | |
http.Error(w, "请求格式错误", http.StatusBadRequest) | |
return | |
} | |
log.Printf("用户问题: %s", openAIReq.Messages[len(openAIReq.Messages)-1].Content) | |
var apiURL string | |
var modelName string | |
switch openAIReq.Model { | |
case "deepseek-r1": | |
apiURL = "https://api.deepinfra.com/v1/openai/chat/completions" | |
modelName = "deepseek-ai/DeepSeek-R1" | |
default: | |
apiURL = "https://api.deepinfra.com/v1/openai/chat/completions" | |
modelName = "deepseek-ai/DeepSeek-V3" | |
} | |
deepseekReq := map[string]interface{}{ | |
"messages": openAIReq.Messages, | |
"stream": true, | |
"model": modelName, | |
} | |
deepseekBody, err := json.Marshal(deepseekReq) | |
if err != nil { | |
log.Printf("错误: 构造请求失败 - %v", err) | |
http.Error(w, "构造请求失败", http.StatusInternalServerError) | |
return | |
} | |
maxRetries := 10 | |
var tryRequest func() (string, error) | |
tryRequest = func() (string, error) { | |
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(deepseekBody)) | |
if err != nil { | |
return "", fmt.Errorf("创建请求失败: %v", err) | |
} | |
req.Header.Set("Content-Type", "application/json") | |
client := &http.Client{} | |
resp, err := client.Do(req) | |
if err != nil { | |
return "", fmt.Errorf("请求失败: %v", err) | |
} | |
defer resp.Body.Close() | |
var fullMessage string | |
scanner := bufio.NewScanner(resp.Body) | |
for scanner.Scan() { | |
line := scanner.Text() | |
if openAIReq.Stream { | |
_, err = fmt.Fprintf(w, "%s\n", line) | |
if err != nil { | |
return "", fmt.Errorf("写入流式响应失败: %v", err) | |
} | |
w.(http.Flusher).Flush() | |
} | |
if !strings.HasPrefix(line, "data: ") { | |
continue | |
} | |
data := strings.TrimPrefix(line, "data: ") | |
if data == "[DONE]" { | |
break | |
} | |
var streamResp StreamResponse | |
if err := json.Unmarshal([]byte(data), &streamResp); err != nil { | |
continue | |
} | |
if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta.Content != "" { | |
fullMessage += streamResp.Choices[0].Delta.Content | |
} | |
} | |
if fullMessage == "" { | |
return "", fmt.Errorf("收到空回复") | |
} | |
return fullMessage, nil | |
} | |
var fullMessage string | |
var lastError error | |
for i := 0; i < maxRetries; i++ { | |
fullMessage, lastError = tryRequest() | |
if lastError == nil { | |
break | |
} | |
log.Printf("第 %d 次尝试失败: %v,准备重试", i+1, lastError) | |
time.Sleep(time.Second * time.Duration(i+1)) | |
} | |
if lastError != nil { | |
log.Printf("错误: 所有重试都失败 - %v", lastError) | |
http.Error(w, "服务暂时不可用", http.StatusServiceUnavailable) | |
return | |
} | |
log.Printf("AI回答: %s", fullMessage) | |
if !openAIReq.Stream { | |
openAIResp := OpenAIResponse{ | |
ID: "chatcmpl-" + generateRandomString(10), | |
Object: "chat.completion", | |
Created: getCurrentTimestamp(), | |
Model: openAIReq.Model, | |
Choices: []Choice{ | |
{ | |
Index: 0, | |
Message: Message{ | |
Role: "assistant", | |
Content: fullMessage, | |
}, | |
FinishReason: "stop", | |
}, | |
}, | |
} | |
w.Header().Set("Content-Type", "application/json") | |
json.NewEncoder(w).Encode(openAIResp) | |
} | |
} | |
// randomStringMu guards randomStringState, the generator state shared by all
// generateRandomString callers (HTTP handlers may call it concurrently).
var (
	randomStringMu    sync.Mutex
	randomStringState = uint64(time.Now().UnixNano())
)

// generateRandomString returns length characters drawn from [a-zA-Z0-9], or
// "" when length <= 0. It is meant for identifiers such as "chatcmpl-…" IDs.
// The output is pseudo-random (a clock-seeded splitmix64 sequence) and must
// NOT be used for secrets or tokens.
//
// A small splitmix64 generator is used instead of math/rand so the file needs
// no imports beyond the ones it already has.
func generateRandomString(length int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if length <= 0 {
		return ""
	}
	out := make([]byte, length)
	randomStringMu.Lock()
	defer randomStringMu.Unlock()
	for i := range out {
		// splitmix64 step: advance the state, then mix it into an output word.
		randomStringState += 0x9E3779B97F4A7C15
		z := randomStringState
		z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9
		z = (z ^ (z >> 27)) * 0x94D049BB133111EB
		z ^= z >> 31
		out[i] = alphabet[z%uint64(len(alphabet))]
	}
	return string(out)
}
// getCurrentTimestamp reports the present wall-clock time as whole seconds
// elapsed since the Unix epoch (1970-01-01T00:00:00Z).
func getCurrentTimestamp() int64 {
	now := time.Now()
	return now.Unix()
}
// writeSSE encodes data as JSON and writes it to w as one Server-Sent Events
// message ("data: <json>\n\n"), then flushes when w implements http.Flusher so
// the event reaches the client immediately.
//
// If data cannot be encoded, a 500 "JSON编码失败" response is written and the
// error is returned; nothing for this event has been written at that point.
// NOTE(review): if earlier events were already sent, that 500 cannot take
// effect and its text lands in the stream — callers mid-stream may prefer to
// validate data first.
//
// If the write itself fails, the error is returned without touching w again:
// the connection is presumably broken, and any further output (such as an
// error page) could only corrupt the stream.
func writeSSE(w http.ResponseWriter, data interface{}) error {
	payload, err := json.Marshal(data)
	if err != nil {
		http.Error(w, "JSON编码失败", http.StatusInternalServerError)
		return fmt.Errorf("encoding SSE payload: %w", err)
	}
	if _, err := fmt.Fprintf(w, "data: %s\n\n", payload); err != nil {
		return fmt.Errorf("writing SSE event: %w", err)
	}
	if flusher, ok := w.(http.Flusher); ok {
		flusher.Flush()
	}
	return nil
}