package krpc import ( "encoding/json" "fmt" "io" "net" "strings" "sync" "time" "krwu.top/krpc.v1/codec" ) type Server struct { serviceMap sync.Map } func NewServer() *Server { return &Server{} } var DefaultServer = NewServer() func (s *Server) Accept(lis net.Listener) { for { conn, err := lis.Accept() if err != nil { fmt.Println("rpc server: accept error: ", err) return } go s.ServeConn(conn) } } func (s *Server) ServeConn(conn io.ReadWriteCloser) { defer func() { _ = conn.Close() }() var opts Options if err := json.NewDecoder(conn).Decode(&opts); err != nil { fmt.Println("rpc server: options error: ", err) return } if opts.MagicNumber != codec.MagicNumber { fmt.Printf("rpc server: invalid magic number %x\n", opts.MagicNumber) return } f := codec.NewCodecFuncMap[opts.CodecType] if f == nil { fmt.Printf("rpc server: invalid codec type %s\n", opts.CodecType) return } s.ServeCodec(f(conn)) } var invalidRequest = struct{}{} func (s *Server) ServeCodec(cc codec.Codec) { sending := new(sync.Mutex) wg := new(sync.WaitGroup) for { req, err := s.readRequest(cc) if err != nil { fmt.Println("rpc server invalid request: ", err) if req == nil { break } req.h.Error = err.Error() s.sendResponse(cc, req.h, invalidRequest, sending) continue } wg.Add(1) go s.handleRequest(cc, req, sending, wg, time.Second*3) } wg.Wait() _ = cc.Close() } func (s *Server) Register(rcvr interface{}) error { svc := newService(rcvr) if _, dup := s.serviceMap.LoadOrStore(svc.name, svc); dup { return fmt.Errorf("rpc: service already defined: %s\n", svc.name) } return nil } func (s *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) { dot := strings.LastIndex(serviceMethod, ".") if dot < 0 { err = fmt.Errorf("rpc: service/method request ill-formed: %s", serviceMethod) return } serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:] svci, ok := s.serviceMap.Load(serviceName) if !ok { err = fmt.Errorf("rpc: server can't find service %s", serviceName) return } svc = svci.(*service) mtype = svc.method[methodName] if mtype == nil { err = fmt.Errorf("rpc: server can't find method %s", methodName) } return } func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } func Accept(lis net.Listener) { DefaultServer.Accept(lis) }