// DeepInfra-163213 / main.go
// 9um3yhdu's picture
// Update main.go
// df15cb8 verified

package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"strings"
"sync"
"time"
)
// OpenAIRequest is the JSON request body that handleChat decodes from the
// client.
type OpenAIRequest struct {
	// Model is the client-facing model name; handleChat maps it to an
	// upstream model identifier.
	Model string `json:"model"`
	// Messages is the conversation that is forwarded upstream as-is.
	Messages []Message `json:"messages"`
	// Stream reports whether the client wants the upstream event stream
	// relayed instead of a single JSON document.
	Stream bool `json:"stream"`
}
// Message is a single chat message: who said it and what was said.
type Message struct {
	// Role names the speaker; handleChat uses "assistant" for replies.
	Role string `json:"role"`
	// Content is the plain-text body of the message.
	Content string `json:"content"`
}
// DeepSeekResponse describes a JSON envelope with a numeric code, two
// message fields and an API source tag.
//
// NOTE(review): nothing in the visible part of this file references this
// type; the meaning of the individual fields should be confirmed against the
// upstream API before relying on it.
type DeepSeekResponse struct {
	Code      int    `json:"code"`
	Msg       string `json:"msg"`
	Message   string `json:"message"`
	APISource string `json:"api_source"`
}
// OpenAIResponse is the JSON document handleChat writes for non-streaming
// requests.
type OpenAIResponse struct {
	// ID identifies the completion; handleChat prefixes it with "chatcmpl-".
	ID string `json:"id"`
	// Object is the document kind; handleChat sets "chat.completion".
	Object string `json:"object"`
	// Created is a Unix timestamp in seconds.
	Created int64 `json:"created"`
	// Model echoes the model name the client asked for.
	Model string `json:"model"`
	// Choices holds the generated replies; handleChat emits exactly one.
	Choices []Choice `json:"choices"`
}
// Choice is one generated reply inside an OpenAIResponse.
type Choice struct {
	// Index is the position of this choice in the response.
	Index int `json:"index"`
	// Message is the reply itself.
	Message Message `json:"message"`
	// FinishReason says why generation ended; handleChat always sets "stop".
	FinishReason string `json:"finish_reason"`
}
// StreamChoice is one choice inside a streamed chunk (StreamResponse).
type StreamChoice struct {
	// Delta carries the incremental part of the message for this chunk.
	Delta StreamMessage `json:"delta"`
	// Index is the position of this choice in the chunk.
	Index int `json:"index"`
	// FinishReason is a pointer so that a JSON null decodes to nil rather
	// than to an empty string.
	FinishReason *string `json:"finish_reason"`
}
// StreamMessage is the incremental message fragment of a streamed chunk.
// Both fields are optional and are left out of the encoding when empty.
type StreamMessage struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}
// StreamResponse is the shape handleChat decodes each upstream "data: "
// payload into. Only Choices[0].Delta.Content is read there; the remaining
// fields mirror the envelope of OpenAIResponse.
type StreamResponse struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}
// Request statistics shared by all handlers.
var (
	// logMutex guards every variable below; handlers take it before reading
	// or updating any of them.
	logMutex sync.Mutex

	// requestCount is the number of requests handleChat has seen since the
	// process started.
	requestCount int64
	// rpm is the number of requests counted in the window that began at
	// lastMinute.
	rpm int
	// lastMinute is the start of the current counting window.
	lastMinute time.Time
	// requestLog holds one formatted line per request, oldest first.
	requestLog []string
)
func init() {
lastMinute = time.Now()
requestLog = make([]string, 0, 10000)
}
// main registers the three endpoints on the default mux and serves them on
// port 7860. It only returns by exiting the process when the listener fails.
func main() {
	routes := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"/", handleStats},
		{"/log", handleLogs},
		{"/hf/v1/chat/completions", handleChat},
	}
	for _, rt := range routes {
		http.HandleFunc(rt.pattern, rt.handler)
	}

	err := http.ListenAndServe(":7860", nil)
	log.Fatal(err)
}
// handleStats writes the total request count and the current
// requests-per-minute counter as plain text.
//
// Any path other than exactly "/" is answered with 404.
func handleStats(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path != "/" {
		http.NotFound(w, r)
		return
	}

	// Take a consistent snapshot of both counters under the lock, then
	// format outside of it.
	total, perMinute := func() (int64, int) {
		logMutex.Lock()
		defer logMutex.Unlock()
		return requestCount, rpm
	}()

	fmt.Fprintf(w, "总请求次数: %d\n每分钟请求数(RPM): %d", total, perMinute)
}
// handleLogs writes the recorded request log, one entry per line.
//
// The caller must pass the expected value in the "auth" query parameter;
// otherwise the handler answers 401.
//
// NOTE(review): the secret is hard-coded and travels in the URL, so it will
// show up in access logs and proxies. Consider moving it to configuration
// and a request header.
func handleLogs(w http.ResponseWriter, r *http.Request) {
	if r.URL.Query().Get("auth") != "smnet" {
		http.Error(w, "未授权访问", http.StatusUnauthorized)
		return
	}

	// Copy the log under the lock so writing to a slow client never blocks
	// the other handlers.
	logMutex.Lock()
	snapshot := append([]string(nil), requestLog...)
	logMutex.Unlock()

	for i := range snapshot {
		fmt.Fprintln(w, snapshot[i])
	}
}
func handleChat(w http.ResponseWriter, r *http.Request) {
logMutex.Lock()
requestCount++
now := time.Now()
if now.Sub(lastMinute) >= time.Minute {
rpm = 1
lastMinute = now
} else {
rpm++
}
clientIP := r.Header.Get("X-Real-IP")
if clientIP == "" {
clientIP = r.Header.Get("X-Forwarded-For")
if clientIP == "" {
clientIP = r.RemoteAddr
}
}
logEntry := fmt.Sprintf("[%s] IP:%s 新请求处理", now.Format("2006-01-02 15:04:05"), clientIP)
if len(requestLog) >= 5000 {
requestLog = requestLog[1:]
}
requestLog = append(requestLog, logEntry)
logMutex.Unlock()
if r.Method != http.MethodPost {
log.Printf("错误: 不支持的请求方法 %s", r.Method)
http.Error(w, "仅支持 POST 请求", http.StatusMethodNotAllowed)
return
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
log.Printf("错误: 读取请求失败 - %v", err)
http.Error(w, "读取请求失败", http.StatusBadRequest)
return
}
var openAIReq OpenAIRequest
if err := json.Unmarshal(body, &openAIReq); err != nil {
log.Printf("错误: 请求格式错误 - %v", err)
http.Error(w, "请求格式错误", http.StatusBadRequest)
return
}
log.Printf("用户问题: %s", openAIReq.Messages[len(openAIReq.Messages)-1].Content)
var apiURL string
var modelName string
switch openAIReq.Model {
case "deepseek-r1":
apiURL = "https://api.deepinfra.com/v1/openai/chat/completions"
modelName = "deepseek-ai/DeepSeek-R1"
default:
apiURL = "https://api.deepinfra.com/v1/openai/chat/completions"
modelName = "deepseek-ai/DeepSeek-V3"
}
deepseekReq := map[string]interface{}{
"messages": openAIReq.Messages,
"stream": true,
"model": modelName,
}
deepseekBody, err := json.Marshal(deepseekReq)
if err != nil {
log.Printf("错误: 构造请求失败 - %v", err)
http.Error(w, "构造请求失败", http.StatusInternalServerError)
return
}
maxRetries := 10
var tryRequest func() (string, error)
tryRequest = func() (string, error) {
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(deepseekBody))
if err != nil {
return "", fmt.Errorf("创建请求失败: %v", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("请求失败: %v", err)
}
defer resp.Body.Close()
var fullMessage string
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if openAIReq.Stream {
_, err = fmt.Fprintf(w, "%s\n", line)
if err != nil {
return "", fmt.Errorf("写入流式响应失败: %v", err)
}
w.(http.Flusher).Flush()
}
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
break
}
var streamResp StreamResponse
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
continue
}
if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta.Content != "" {
fullMessage += streamResp.Choices[0].Delta.Content
}
}
if fullMessage == "" {
return "", fmt.Errorf("收到空回复")
}
return fullMessage, nil
}
var fullMessage string
var lastError error
for i := 0; i < maxRetries; i++ {
fullMessage, lastError = tryRequest()
if lastError == nil {
break
}
log.Printf("第 %d 次尝试失败: %v,准备重试", i+1, lastError)
time.Sleep(time.Second * time.Duration(i+1))
}
if lastError != nil {
log.Printf("错误: 所有重试都失败 - %v", lastError)
http.Error(w, "服务暂时不可用", http.StatusServiceUnavailable)
return
}
log.Printf("AI回答: %s", fullMessage)
if !openAIReq.Stream {
openAIResp := OpenAIResponse{
ID: "chatcmpl-" + generateRandomString(10),
Object: "chat.completion",
Created: getCurrentTimestamp(),
Model: openAIReq.Model,
Choices: []Choice{
{
Index: 0,
Message: Message{
Role: "assistant",
Content: fullMessage,
},
FinishReason: "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(openAIResp)
}
}
// generateRandomString returns the suffix handleChat appends to completion
// IDs.
//
// NOTE(review): despite the name this is a stub. It ignores length and always
// returns the same placeholder, so every completion ID is identical.
func generateRandomString(length int) string {
	const placeholder = "SomeApiResponse"
	return placeholder
}
// getCurrentTimestamp returns the current wall-clock time as Unix seconds.
func getCurrentTimestamp() int64 {
	now := time.Now()
	return now.Unix()
}
// writeSSE encodes data as JSON and writes it to w as a single server-sent
// event ("data: <json>" followed by a blank line), then flushes the event if
// w supports flushing.
//
// If data cannot be encoded, a 500 response is written and the encoding error
// is returned. If the write itself fails, the error is returned and nothing
// more is written: the connection is already broken, so attempting an error
// response would only trigger a late WriteHeader.
func writeSSE(w http.ResponseWriter, data interface{}) error {
	jsonData, err := json.Marshal(data)
	if err != nil {
		http.Error(w, "JSON编码失败", http.StatusInternalServerError)
		return err
	}

	if _, err := fmt.Fprintf(w, "data: %s\n\n", jsonData); err != nil {
		return err
	}

	// Push the event out now instead of leaving it in the response buffer.
	if flusher, ok := w.(http.Flusher); ok {
		flusher.Flush()
	}
	return nil
}