147 lines
3.2 KiB
Go
147 lines
3.2 KiB
Go
package krpc
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"krwu.top/krpc.v1/codec"
|
|
)
|
|
|
|
// Server holds the set of services exposed over kRPC. Services are stored in
// serviceMap keyed by service name (see Register).
type Server struct {
	serviceMap sync.Map
}

// NewServer returns a fresh Server with no services registered.
func NewServer() *Server {
	srv := new(Server)
	return srv
}

// DefaultServer is the Server used by the package-level Register, Accept and
// HandleHTTP helpers.
var DefaultServer = NewServer()
|
|
|
|
func (s *Server) Accept(lis net.Listener) {
|
|
for {
|
|
conn, err := lis.Accept()
|
|
if err != nil {
|
|
fmt.Println("rpc server: accept error: ", err)
|
|
return
|
|
}
|
|
go s.ServeConn(conn)
|
|
}
|
|
}
|
|
|
|
func (s *Server) ServeConn(conn io.ReadWriteCloser) {
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
var opts Options
|
|
if err := json.NewDecoder(conn).Decode(&opts); err != nil {
|
|
fmt.Println("rpc server: options error: ", err)
|
|
return
|
|
}
|
|
if opts.MagicNumber != codec.MagicNumber {
|
|
fmt.Printf("rpc server: invalid magic number %x\n", opts.MagicNumber)
|
|
return
|
|
}
|
|
f := codec.NewCodecFuncMap[opts.CodecType]
|
|
if f == nil {
|
|
fmt.Printf("rpc server: invalid codec type %s\n", opts.CodecType)
|
|
return
|
|
}
|
|
s.ServeCodec(f(conn))
|
|
}
|
|
|
|
// invalidRequest is the empty placeholder body sent with an error response
// for a request that could not be read.
var invalidRequest struct{}
|
|
|
|
func (s *Server) ServeCodec(cc codec.Codec) {
|
|
sending := new(sync.Mutex)
|
|
wg := new(sync.WaitGroup)
|
|
for {
|
|
req, err := s.readRequest(cc)
|
|
if err != nil {
|
|
fmt.Println("rpc server invalid request: ", err)
|
|
if req == nil {
|
|
break
|
|
}
|
|
req.h.Error = err.Error()
|
|
s.sendResponse(cc, req.h, invalidRequest, sending)
|
|
continue
|
|
}
|
|
wg.Add(1)
|
|
go s.handleRequest(cc, req, sending, wg, time.Second*3)
|
|
}
|
|
wg.Wait()
|
|
_ = cc.Close()
|
|
}
|
|
|
|
func (s *Server) Register(rcvr interface{}) error {
|
|
svc := newService(rcvr)
|
|
if _, dup := s.serviceMap.LoadOrStore(svc.name, svc); dup {
|
|
return fmt.Errorf("rpc: service already defined: %s\n", svc.name)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
|
|
dot := strings.LastIndex(serviceMethod, ".")
|
|
if dot < 0 {
|
|
err = fmt.Errorf("rpc: service/method request ill-formed: %s", serviceMethod)
|
|
return
|
|
}
|
|
serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
|
|
svci, ok := s.serviceMap.Load(serviceName)
|
|
if !ok {
|
|
err = fmt.Errorf("rpc: server can't find service %s", serviceName)
|
|
return
|
|
}
|
|
svc = svci.(*service)
|
|
mtype = svc.method[methodName]
|
|
if mtype == nil {
|
|
err = fmt.Errorf("rpc: server can't find method %s", methodName)
|
|
}
|
|
return
|
|
}
|
|
|
|
func Register(rcvr interface{}) error {
|
|
return DefaultServer.Register(rcvr)
|
|
}
|
|
|
|
// Accept serves connections from lis on DefaultServer; see Server.Accept.
func Accept(lis net.Listener) {
	DefaultServer.Accept(lis)
}
|
|
|
|
// connected is the status text written back to an HTTP client once its
// CONNECT request has been accepted and the connection hijacked.
const connected = "200 Connected to kRPC"

// defaultRPCPath is the URL path HandleHTTP mounts the RPC endpoint on.
const defaultRPCPath = "/_krpc_"

// defaultDebugPath is the URL path HandleHTTP mounts the debug page on.
const defaultDebugPath = "/debug/krpc"
|
|
|
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodConnect {
|
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
_, _ = fmt.Fprintln(w, "405 method not allowed")
|
|
return
|
|
}
|
|
conn, _, err := w.(http.Hijacker).Hijack()
|
|
if err != nil {
|
|
log.Printf("rpc hijacking %s: %s", r.RemoteAddr, err.Error())
|
|
return
|
|
}
|
|
_, _ = fmt.Fprintf(conn, "HTTP/1.0 %s\n\n", connected)
|
|
s.ServeConn(conn)
|
|
}
|
|
|
|
func (s *Server) HandleHTTP() {
|
|
http.Handle(defaultRPCPath, s)
|
|
http.Handle(defaultDebugPath, debugHTTP{s})
|
|
log.Println("rpc: server debug path:", defaultDebugPath)
|
|
}
|
|
|
|
func HandleHTTP() {
|
|
DefaultServer.HandleHTTP()
|
|
}
|