// aiproxy/main.go
//
// 105 lines
// 2.7 KiB
// Go

package main
import (
"bytes"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
jsoniter "github.com/json-iterator/go"
"log"
"math/rand"
"net/http"
"slices"
"strconv"
"time"
)
// ModelMsg is the slice of an incoming JSON request body that the proxy
// inspects. Only the "model" field is decoded; every other field is ignored
// here.
type ModelMsg struct {
	// Model names the requested model and is used as the key into the
	// configured API pools.
	Model string `json:"model"`
}
// main initialises configuration, wires up the HTTP routes and serves on
// 0.0.0.0:8080. If Listen returns an error, log.Fatal logs it and exits.
func main() {
	// InitConfig is defined elsewhere; presumably it populates the
	// package-level config read by the proxy handler — TODO confirm.
	InitConfig()

	serverCfg := fiber.Config{
		StreamRequestBody:            true,
		DisablePreParseMultipartForm: true,
		ReduceMemoryUsage:            true,
		JSONEncoder:                  jsoniter.Marshal,
		JSONDecoder:                  jsoniter.Unmarshal,
	}
	app := fiber.New(serverCfg)

	app.Use(cors.New(cors.ConfigDefault))

	// Service endpoints.
	app.Get("/ping", ping)
	app.Get("/version", version)

	// Every method under /v1/ is forwarded upstream.
	app.All("/v1/*", proxy)

	const listenAddr = "0.0.0.0:8080"
	log.Fatal(app.Listen(listenAddr))
}
// pickTarget returns a pointer to a uniformly chosen element of all, or nil
// when all is empty.
//
// It uses math/rand's top-level functions. These share one generator that is
// safe for concurrent use and, since Go 1.20, is seeded automatically, so no
// per-call generator has to be allocated and seeded.
func pickTarget(all []AiTarget) *AiTarget {
	if len(all) == 0 {
		return nil
	}
	return &all[rand.Intn(len(all))]
}

// maskKey redacts a secret for logging. It keeps only the last four bytes so
// that different keys can still be told apart in the log.
func maskKey(key string) string {
	const visible = 4
	if len(key) <= visible {
		return "****"
	}
	return "****" + key[len(key)-visible:]
}

// proxy authenticates the caller against config.InKeys and picks an upstream
// target for the model named in the JSON body. It forwards the unmodified
// body to that target using the target's own key, then relays the upstream
// status, Content-Type and body back to the client.
//
// Validation failures are returned as *fiber.Error with a 4xx/5xx status. A
// failed upstream round trip is reported as 502 Bad Gateway.
func proxy(c *fiber.Ctx) (fErr error) {
	// main installs no recover middleware, so a stray panic here would take
	// the process down. Convert any panic into an error response instead,
	// including panics whose value is not an error.
	defer func() {
		if r := recover(); r != nil {
			if e, ok := r.(error); ok {
				fErr = e
				return
			}
			fErr = fiber.NewError(fiber.StatusInternalServerError, "internal error")
		}
	}()

	const bearerPrefix = "Bearer "
	author := c.Get("Authorization")
	if author == "" {
		return fiber.NewError(fiber.StatusUnauthorized, "missing token")
	}
	// Check the length before slicing. A header shorter than the prefix
	// would otherwise panic with an out-of-range index.
	if len(author) < len(bearerPrefix) || author[:len(bearerPrefix)] != bearerPrefix {
		return fiber.NewError(fiber.StatusUnauthorized, "invalid token")
	}
	token := author[len(bearerPrefix):]
	if !slices.Contains(config.InKeys, token) {
		return fiber.NewError(fiber.StatusForbidden, "invalid token")
	}

	raw := c.BodyRaw()
	if len(raw) == 0 {
		return fiber.NewError(fiber.StatusBadRequest, "empty body")
	}
	// Only "model" is decoded. The raw bytes are what get forwarded.
	var msg ModelMsg
	if err := jsoniter.Unmarshal(raw, &msg); err != nil {
		return fiber.NewError(fiber.StatusBadRequest, err.Error())
	}
	targets, ok := config.ApiPools[msg.Model]
	if !ok {
		return fiber.NewError(fiber.StatusBadRequest, "invalid model")
	}
	t := pickTarget(targets)
	if t == nil {
		return fiber.NewError(fiber.StatusInternalServerError, "no target")
	}

	req, err := http.NewRequest(c.Method(), t.TargetURL+c.OriginalURL(), bytes.NewReader(raw))
	if err != nil {
		return fiber.NewError(fiber.StatusInternalServerError, err.Error())
	}
	req.Header.Set("Authorization", bearerPrefix+t.OpenAIKey)
	req.Header.Set("Content-Type", "application/json")

	start := time.Now()
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fiber.NewError(fiber.StatusBadGateway, err.Error())
	}

	// Relay the upstream status. Otherwise every reply, including upstream
	// 4xx/5xx errors, would reach the client with fiber's default 200.
	c.Status(resp.StatusCode)
	c.Context().SetContentType(resp.Header.Get("Content-Type"))
	// Never log the full upstream key. "elapsed" is measured up to the
	// arrival of the upstream response headers.
	log.Printf("api: %s, model: %s, key: %s, status: %d, elapsed: %s",
		t.TargetURL, msg.Model, maskKey(t.OpenAIKey), resp.StatusCode, time.Since(start))
	// ContentLength is -1 when the upstream length is unknown (e.g. streamed
	// replies), and SendStream then reads the body until EOF. fasthttp closes
	// resp.Body after writing it, because the body implements io.Closer.
	return c.SendStream(resp.Body, int(resp.ContentLength))
}
// version replies with the proxy's release string as the response body.
func version(c *fiber.Ctx) error {
	const release = "v0.1.0"
	return c.SendString(release)
}
// ping is a liveness check that answers with HTTP 200 OK.
func ping(c *fiber.Ctx) error {
	return c.SendStatus(fiber.StatusOK)
}