|
package main |
|
|
|
import (
	"bufio"
	"bytes"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"
)
|
|
|
// Wire types for the OpenAI-compatible chat completions API handled by this
// proxy. Field order and JSON tags determine the encoded output, so keep them
// stable when editing.
type (
	// OpenAIRequest is the JSON body a client POSTs to the chat endpoint.
	OpenAIRequest struct {
		Model    string    `json:"model"`
		Messages []Message `json:"messages"`
		Stream   bool      `json:"stream"`
	}

	// Message is one chat turn: the speaker's role (e.g. "assistant") and its text.
	Message struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// DeepSeekResponse describes a code/msg style JSON reply. It is not
	// referenced by the code visible here; presumably kept for another
	// upstream format — confirm before removing.
	DeepSeekResponse struct {
		Code      int    `json:"code"`
		Msg       string `json:"msg"`
		Message   string `json:"message"`
		APISource string `json:"api_source"`
	}

	// OpenAIResponse is the non-streaming "chat.completion" body sent to clients.
	OpenAIResponse struct {
		ID      string   `json:"id"`
		Object  string   `json:"object"`
		Created int64    `json:"created"`
		Model   string   `json:"model"`
		Choices []Choice `json:"choices"`
	}

	// Choice is one completion alternative within an OpenAIResponse.
	Choice struct {
		Index        int     `json:"index"`
		Message      Message `json:"message"`
		FinishReason string  `json:"finish_reason"`
	}

	// StreamChoice is one entry of a streamed chunk. FinishReason is a pointer
	// so that a JSON null decodes to nil instead of an empty string.
	StreamChoice struct {
		Delta        StreamMessage `json:"delta"`
		Index        int           `json:"index"`
		FinishReason *string       `json:"finish_reason"`
	}

	// StreamMessage is the incremental role/content of a streamed chunk; empty
	// fields are omitted when encoded.
	StreamMessage struct {
		Role    string `json:"role,omitempty"`
		Content string `json:"content,omitempty"`
	}

	// StreamResponse is one decoded "data:" payload of an upstream SSE stream.
	StreamResponse struct {
		ID      string         `json:"id"`
		Object  string         `json:"object"`
		Created int64          `json:"created"`
		Model   string         `json:"model"`
		Choices []StreamChoice `json:"choices"`
	}
)
|
|
|
// Request statistics and the in-memory request log shared by the HTTP
// handlers. The handlers read and write these only while holding logMutex.
var (
	// logMutex serializes access to every variable below.
	logMutex sync.Mutex

	// requestCount is the running total of requests seen by the chat handler.
	requestCount int64

	// rpm counts chat requests in the current one-minute window, which
	// started at lastMinute.
	rpm        int
	lastMinute time.Time

	// requestLog holds formatted per-request log lines, oldest first.
	requestLog []string
)
|
|
|
func init() { |
|
lastMinute = time.Now() |
|
requestLog = make([]string, 0, 10000) |
|
} |
|
|
|
func main() { |
|
http.HandleFunc("/", handleStats) |
|
http.HandleFunc("/log", handleLogs) |
|
http.HandleFunc("/hf/v1/chat/completions", handleChat) |
|
log.Fatal(http.ListenAndServe(":7860", nil)) |
|
} |
|
|
|
func handleStats(w http.ResponseWriter, r *http.Request) { |
|
if r.URL.Path != "/" { |
|
http.NotFound(w, r) |
|
return |
|
} |
|
|
|
logMutex.Lock() |
|
currentRPM := rpm |
|
totalRequests := requestCount |
|
logMutex.Unlock() |
|
|
|
fmt.Fprintf(w, "总请求次数: %d\n每分钟请求数(RPM): %d", totalRequests, currentRPM) |
|
} |
|
|
|
func handleLogs(w http.ResponseWriter, r *http.Request) { |
|
auth := r.URL.Query().Get("auth") |
|
if auth != "smnet" { |
|
http.Error(w, "未授权访问", http.StatusUnauthorized) |
|
return |
|
} |
|
|
|
logMutex.Lock() |
|
logs := make([]string, len(requestLog)) |
|
copy(logs, requestLog) |
|
logMutex.Unlock() |
|
|
|
for _, log := range logs { |
|
fmt.Fprintln(w, log) |
|
} |
|
} |
|
|
|
func handleChat(w http.ResponseWriter, r *http.Request) { |
|
logMutex.Lock() |
|
requestCount++ |
|
now := time.Now() |
|
if now.Sub(lastMinute) >= time.Minute { |
|
rpm = 1 |
|
lastMinute = now |
|
} else { |
|
rpm++ |
|
} |
|
|
|
clientIP := r.Header.Get("X-Real-IP") |
|
if clientIP == "" { |
|
clientIP = r.Header.Get("X-Forwarded-For") |
|
if clientIP == "" { |
|
clientIP = r.RemoteAddr |
|
} |
|
} |
|
|
|
logEntry := fmt.Sprintf("[%s] IP:%s 新请求处理", now.Format("2006-01-02 15:04:05"), clientIP) |
|
if len(requestLog) >= 5000 { |
|
requestLog = requestLog[1:] |
|
} |
|
requestLog = append(requestLog, logEntry) |
|
logMutex.Unlock() |
|
|
|
if r.Method != http.MethodPost { |
|
log.Printf("错误: 不支持的请求方法 %s", r.Method) |
|
http.Error(w, "仅支持 POST 请求", http.StatusMethodNotAllowed) |
|
return |
|
} |
|
|
|
body, err := ioutil.ReadAll(r.Body) |
|
if err != nil { |
|
log.Printf("错误: 读取请求失败 - %v", err) |
|
http.Error(w, "读取请求失败", http.StatusBadRequest) |
|
return |
|
} |
|
|
|
var openAIReq OpenAIRequest |
|
if err := json.Unmarshal(body, &openAIReq); err != nil { |
|
log.Printf("错误: 请求格式错误 - %v", err) |
|
http.Error(w, "请求格式错误", http.StatusBadRequest) |
|
return |
|
} |
|
|
|
log.Printf("用户问题: %s", openAIReq.Messages[len(openAIReq.Messages)-1].Content) |
|
|
|
var apiURL string |
|
var modelName string |
|
switch openAIReq.Model { |
|
case "deepseek-r1": |
|
apiURL = "https://api.deepinfra.com/v1/openai/chat/completions" |
|
modelName = "deepseek-ai/DeepSeek-R1" |
|
default: |
|
apiURL = "https://api.deepinfra.com/v1/openai/chat/completions" |
|
modelName = "deepseek-ai/DeepSeek-V3" |
|
} |
|
|
|
deepseekReq := map[string]interface{}{ |
|
"messages": openAIReq.Messages, |
|
"stream": true, |
|
"model": modelName, |
|
} |
|
|
|
deepseekBody, err := json.Marshal(deepseekReq) |
|
if err != nil { |
|
log.Printf("错误: 构造请求失败 - %v", err) |
|
http.Error(w, "构造请求失败", http.StatusInternalServerError) |
|
return |
|
} |
|
|
|
maxRetries := 10 |
|
var tryRequest func() (string, error) |
|
|
|
tryRequest = func() (string, error) { |
|
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(deepseekBody)) |
|
if err != nil { |
|
return "", fmt.Errorf("创建请求失败: %v", err) |
|
} |
|
req.Header.Set("Content-Type", "application/json") |
|
|
|
client := &http.Client{} |
|
resp, err := client.Do(req) |
|
if err != nil { |
|
return "", fmt.Errorf("请求失败: %v", err) |
|
} |
|
defer resp.Body.Close() |
|
|
|
var fullMessage string |
|
scanner := bufio.NewScanner(resp.Body) |
|
|
|
for scanner.Scan() { |
|
line := scanner.Text() |
|
if openAIReq.Stream { |
|
_, err = fmt.Fprintf(w, "%s\n", line) |
|
if err != nil { |
|
return "", fmt.Errorf("写入流式响应失败: %v", err) |
|
} |
|
w.(http.Flusher).Flush() |
|
} |
|
|
|
if !strings.HasPrefix(line, "data: ") { |
|
continue |
|
} |
|
|
|
data := strings.TrimPrefix(line, "data: ") |
|
if data == "[DONE]" { |
|
break |
|
} |
|
|
|
var streamResp StreamResponse |
|
if err := json.Unmarshal([]byte(data), &streamResp); err != nil { |
|
continue |
|
} |
|
|
|
if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta.Content != "" { |
|
fullMessage += streamResp.Choices[0].Delta.Content |
|
} |
|
} |
|
|
|
if fullMessage == "" { |
|
return "", fmt.Errorf("收到空回复") |
|
} |
|
|
|
return fullMessage, nil |
|
} |
|
|
|
var fullMessage string |
|
var lastError error |
|
|
|
for i := 0; i < maxRetries; i++ { |
|
fullMessage, lastError = tryRequest() |
|
if lastError == nil { |
|
break |
|
} |
|
log.Printf("第 %d 次尝试失败: %v,准备重试", i+1, lastError) |
|
time.Sleep(time.Second * time.Duration(i+1)) |
|
} |
|
|
|
if lastError != nil { |
|
log.Printf("错误: 所有重试都失败 - %v", lastError) |
|
http.Error(w, "服务暂时不可用", http.StatusServiceUnavailable) |
|
return |
|
} |
|
|
|
log.Printf("AI回答: %s", fullMessage) |
|
|
|
if !openAIReq.Stream { |
|
openAIResp := OpenAIResponse{ |
|
ID: "chatcmpl-" + generateRandomString(10), |
|
Object: "chat.completion", |
|
Created: getCurrentTimestamp(), |
|
Model: openAIReq.Model, |
|
Choices: []Choice{ |
|
{ |
|
Index: 0, |
|
Message: Message{ |
|
Role: "assistant", |
|
Content: fullMessage, |
|
}, |
|
FinishReason: "stop", |
|
}, |
|
}, |
|
} |
|
|
|
w.Header().Set("Content-Type", "application/json") |
|
json.NewEncoder(w).Encode(openAIResp) |
|
} |
|
} |
|
|
|
// generateRandomString returns a random string of exactly length characters,
// drawn uniformly from [a-zA-Z0-9]. It is used to build completion IDs such as
// "chatcmpl-<random>". A length of zero or less yields "".
func generateRandomString(length int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// Bytes at or above limit are discarded so every character is equally
	// likely (248 is the largest multiple of 62 that fits in a byte).
	const limit = 256 - 256%len(alphabet)

	if length <= 0 {
		return ""
	}

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			// Not expected in practice. Fall back to a time-seeded linear
			// congruential generator so callers still get an ID of the requested
			// length (distinct enough for IDs, not suitable for secrets).
			seed := uint64(time.Now().UnixNano())
			for len(out) < length {
				seed = seed*6364136223846793005 + 1442695040888963407
				out = append(out, alphabet[(seed>>33)%uint64(len(alphabet))])
			}
			break
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			out = append(out, alphabet[int(b)%len(alphabet)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
|
|
|
// getCurrentTimestamp reports the current wall-clock time in whole seconds
// since the Unix epoch, as stored in the "created" field of responses.
func getCurrentTimestamp() int64 {
	now := time.Now()
	return now.Unix()
}
|
|
|
// writeSSE encodes data as JSON and writes it to w as a single Server-Sent
// Events "data:" event, flushing immediately when w supports http.Flusher.
//
// If data cannot be encoded, a 500 response is sent and the encoding error is
// returned; the 500 is only meaningful if nothing has been written yet. A write
// error is returned without any further output. The connection is already
// failing, and headers are typically committed, so writing an error response
// would only add noise to a broken stream.
func writeSSE(w http.ResponseWriter, data interface{}) error {
	jsonData, err := json.Marshal(data)
	if err != nil {
		http.Error(w, "JSON编码失败", http.StatusInternalServerError)
		return err
	}

	if _, err := fmt.Fprintf(w, "data: %s\n\n", jsonData); err != nil {
		return err
	}

	// Without a flush the event may sit in the response buffer instead of
	// reaching the client.
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
	return nil
}
|
|