diff --git a/client.go b/client.go index 4f9c3de..8f0d849 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package krpc import ( + "bufio" "context" "encoding/json" "errors" @@ -8,6 +9,8 @@ import ( "io" "log" "net" + "net/http" + "strings" "sync" "sync/atomic" "time" @@ -241,3 +244,34 @@ func dialTimeout(f newClientFunc, network, address string, return result.client, result.err } } + +func NewHTTPClient(conn net.Conn, opts ...Option) (*Client, error) { + _, _ = fmt.Fprintf(conn, + "CONNECT %s HTTP/1.0\n\n", defaultRPCPath) + + resp, err := http.ReadResponse(bufio.NewReader(conn), + &http.Request{Method: http.MethodConnect}) + if err == nil && resp.Status == connected { + return NewClient(conn, opts...) + } + return nil, err +} + +func DialHTTP(network, address string, opts ...Option) (*Client, error) { + return dialTimeout(NewHTTPClient, network, address, opts...) +} + +func XDial(rpcAddr string, opts ...Option) (*Client, error) { + parts := strings.Split(rpcAddr, "://") + if len(parts) != 2 { + return nil, fmt.Errorf("rpc: client wrong format '%s', expect protocol://addr", rpcAddr) + } + protocol, addr := parts[0], parts[1] + + switch protocol { + case "http": + return DialHTTP("tcp", addr, opts...) + default: + return Dial(protocol, addr, opts...) + } +} diff --git a/client_test.go b/client_test.go index cc5251c..c954edc 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,8 @@ package krpc import ( "net" + "os" + "runtime" "strings" "testing" "time" @@ -36,5 +38,24 @@ func TestClient_dialTimeout(t *testing.T) { }) } - +} + +func TestXDial(t *testing.T) { + if runtime.GOOS == "linux" { + ch := make(chan struct{}) + addr := "/tmp/krpc.sock" + + go func() { + _ = os.Remove(addr) + l, err := net.Listen("unix", addr) + if err != nil { + t.Fatal("failed to listen unix socket") + } + ch <- struct{}{} + Accept(l) + }() + <-ch + _, err := XDial("unix://"+addr) + _assert(err == nil, "failed to connect unix socket") + } } diff --git a/debug.go b/debug.go new file mode 100644 index 0000000..a76afb9 --- /dev/null +++ b/debug.go @@ -0,0 +1,54 @@ +package krpc + +import ( + "fmt" + "net/http" + "text/template" +) + +const debugText = ` +kRPC Services + +{{range .}} +
+Service {{.Name}} +
+ + + + + +{{range $name, $mtype := .Method}} + + + +{{end}} + +
MethodCalls
{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error{{$mtype.NumCalls}} +
+{{end}} +` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +type debugHTTP struct { + *Server +} + +type debugService struct { + Name string + Method map[string]*methodType +} + +func (s debugHTTP) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var services []debugService + s.serviceMap.Range(func(namei, svci interface{}) bool { + svc := svci.(*service) + services = append(services, debugService{Name: namei.(string), Method: svc.method,}) + return true + }) + err := debug.Execute(w, services) + if err != nil { + _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} diff --git a/example/main.go b/example/main.go index af19da4..799023c 100644 --- a/example/main.go +++ b/example/main.go @@ -4,6 +4,7 @@ import ( "context" "log" "net" + "net/http" "sync" "krwu.top/krpc.v1" @@ -14,6 +15,7 @@ func startServer(addr chan string) { if err := krpc.Register(&foo); err != nil { log.Fatal("register error: ", err) } + krpc.HandleHTTP() // pick a free port l, err := net.Listen("tcp", ":0") if err != nil { @@ -21,7 +23,8 @@ func startServer(addr chan string) { } log.Println("start rpc server on", l.Addr()) addr <- l.Addr().String() - krpc.Accept(l) + _ = http.Serve(l, nil) + //krpc.Accept(l) } func main() { @@ -29,7 +32,7 @@ func main() { addr := make(chan string) go startServer(addr) - cli, _ := krpc.Dial("tcp", <-addr) + cli, _ := krpc.DialHTTP("tcp", <-addr) defer func() { _ = cli.Close() }() // send options diff --git a/request.go b/request.go index a508e5f..1f61ff0 100644 --- a/request.go +++ b/request.go @@ -29,7 +29,7 @@ func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { } func (s *Server) readRequest(cc codec.Codec) (*request, error) { h, err := s.readRequestHeader(cc) - if err != nil { + if err != nil && err != io.EOF { fmt.Println("rpc server: read request error: ", err) return nil, err } diff --git a/server.go b/server.go index ed71247..f9404e8 100644 --- a/server.go +++ b/server.go @@ -4,7 +4,9 @@ import ( "encoding/json" "fmt" "io" + "log" "net" + "net/http" "strings" "sync" "time" @@ -110,3 +112,35 @@ func Register(rcvr interface{}) error { } func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +const ( + connected = "200 Connected to kRPC" + defaultRPCPath = "/_krpc_" + defaultDebugPath = "/debug/krpc" +) + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = fmt.Fprintln(w, "405 method not allowed") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Printf("rpc hijacking %s: %s", r.RemoteAddr, err.Error()) + return + } + _, _ = fmt.Fprintf(conn, "HTTP/1.0 %s\n\n", connected) + s.ServeConn(conn) +} + +func (s *Server) HandleHTTP() { + http.Handle(defaultRPCPath, s) + http.Handle(defaultDebugPath, debugHTTP{s}) + log.Println("rpc: server debug path:", defaultDebugPath) +} + +func HandleHTTP() { + DefaultServer.HandleHTTP() +}