105 lines
2.7 KiB
Go
105 lines
2.7 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/gofiber/fiber/v2/middleware/cors"
|
|
jsoniter "github.com/json-iterator/go"
|
|
"log"
|
|
"math/rand"
|
|
"net/http"
|
|
"slices"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
// ModelMsg is the minimal view of an incoming JSON request body: only the
// "model" field is decoded, and proxy uses it to look up config.ApiPools.
// Every other field in the body is ignored here and forwarded untouched.
type ModelMsg struct {
	// Model is the requested model name, taken from the "model" JSON key.
	Model string `json:"model"`
}
|
|
|
|
func main() {
|
|
InitConfig()
|
|
app := fiber.New(fiber.Config{
|
|
StreamRequestBody: true,
|
|
DisablePreParseMultipartForm: true,
|
|
ReduceMemoryUsage: true,
|
|
JSONEncoder: jsoniter.Marshal,
|
|
JSONDecoder: jsoniter.Unmarshal,
|
|
})
|
|
app.Use(cors.New(cors.ConfigDefault))
|
|
app.Get("/ping", ping)
|
|
app.Get("/version", version)
|
|
app.All("/v1/*", proxy)
|
|
log.Fatal(app.Listen("0.0.0.0:8080"))
|
|
}
|
|
|
|
func pickTarget(all []AiTarget) *AiTarget {
|
|
if len(all) < 1 {
|
|
return nil
|
|
}
|
|
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
n := r.Intn(len(all))
|
|
return &all[n]
|
|
}
|
|
|
|
func proxy(c *fiber.Ctx) (fErr error) {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
if e, ok := err.(error); ok {
|
|
fErr = e
|
|
}
|
|
}
|
|
}()
|
|
author := c.Get("Authorization")
|
|
if author == "" {
|
|
panic(fiber.NewError(fiber.StatusUnauthorized, "missing token"))
|
|
}
|
|
if author[:7] != "Bearer " {
|
|
panic(fiber.NewError(fiber.StatusUnauthorized, "invalid token"))
|
|
}
|
|
token := author[7:]
|
|
if !slices.Contains(config.InKeys, token) {
|
|
panic(fiber.NewError(fiber.StatusForbidden, "invalid token"))
|
|
}
|
|
raw := c.BodyRaw()
|
|
if len(raw) == 0 {
|
|
panic(fiber.NewError(fiber.StatusBadRequest, "empty body"))
|
|
}
|
|
var msg ModelMsg
|
|
if err := jsoniter.Unmarshal(raw, &msg); err != nil {
|
|
panic(fiber.NewError(fiber.StatusBadRequest, err.Error()))
|
|
}
|
|
targets, ok := config.ApiPools[msg.Model]
|
|
if !ok {
|
|
panic(fiber.NewError(fiber.StatusBadRequest, "invalid model"))
|
|
}
|
|
t := pickTarget(targets)
|
|
if t == nil {
|
|
panic(fiber.NewError(fiber.StatusInternalServerError, "no target"))
|
|
}
|
|
url := t.TargetURL + c.OriginalURL()
|
|
req, _ := http.NewRequest(c.Method(), url, bytes.NewReader(raw))
|
|
req.Header.Set("Authorization", "Bearer "+t.OpenAIKey)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
c.Context().SetContentType(resp.Header.Get("Content-Type"))
|
|
if resp.Header.Get("Transfer-Encoding") != "" {
|
|
c.Set("Transfer-Encoding", resp.Header.Get("Transfer-Encoding"))
|
|
} else if resp.ContentLength != 0 {
|
|
c.Set("Content-Length", strconv.Itoa(int(resp.ContentLength)))
|
|
}
|
|
log.Printf("api: %s, model: %s, key: %s, status: %d", t.TargetURL, msg.Model, t.OpenAIKey, resp.StatusCode)
|
|
|
|
return c.SendStream(resp.Body, -1)
|
|
}
|
|
|
|
func version(c *fiber.Ctx) error {
|
|
return c.SendString("v0.1.0")
|
|
}
|
|
func ping(c *fiber.Ctx) error {
|
|
return c.SendStatus(200)
|
|
}
|