diff --git a/client.go b/client.go
index 4f9c3de..8f0d849 100644
--- a/client.go
+++ b/client.go
@@ -1,6 +1,7 @@
package krpc
import (
+ "bufio"
"context"
"encoding/json"
"errors"
@@ -8,6 +9,8 @@ import (
"io"
"log"
"net"
+ "net/http"
+ "strings"
"sync"
"sync/atomic"
"time"
@@ -241,3 +244,34 @@ func dialTimeout(f newClientFunc, network, address string,
return result.client, result.err
}
}
+
+func NewHTTPClient(conn net.Conn, opts ...Option) (*Client, error) {
+ _, _ = fmt.Fprintf(conn,
+ "CONNECT %s HTTP/1.0\n\n", defaultRPCPath)
+
+ resp, err := http.ReadResponse(bufio.NewReader(conn),
+ &http.Request{Method: http.MethodConnect})
+ if err == nil && resp.Status == connected {
+ return NewClient(conn, opts...)
+ }
+ return nil, err
+}
+
+func DialHTTP(network, address string, opts ...Option) (*Client, error) {
+ return dialTimeout(NewHTTPClient, network, address, opts...)
+}
+
+func XDial(rpcAddr string, opts ...Option) (*Client, error) {
+ parts := strings.Split(rpcAddr, "://")
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("rpc: client wrong format '%s', expect protocol://addr", rpcAddr)
+ }
+ protocol, addr := parts[0], parts[1]
+
+ switch protocol {
+ case "http":
+ return DialHTTP("tcp", addr, opts...)
+ default:
+ return Dial(protocol, addr, opts...)
+ }
+}
diff --git a/client_test.go b/client_test.go
index cc5251c..c954edc 100644
--- a/client_test.go
+++ b/client_test.go
@@ -2,6 +2,8 @@ package krpc
import (
"net"
+ "os"
+ "runtime"
"strings"
"testing"
"time"
@@ -36,5 +38,24 @@ func TestClient_dialTimeout(t *testing.T) {
})
}
-
+}
+
+func TestXDial(t *testing.T) {
+ if runtime.GOOS == "linux" {
+ ch := make(chan struct{})
+ addr := "/tmp/krpc.sock"
+
+ go func() {
+ _ = os.Remove(addr)
+ l, err := net.Listen("unix", addr)
+ if err != nil {
+ t.Fatal("failed to listen unix socket")
+ }
+ ch <- struct{}{}
+ Accept(l)
+ }()
+ <-ch
+ _, err := XDial("unix://"+addr)
+ _assert(err == nil, "failed to connect unix socket")
+ }
}
diff --git a/debug.go b/debug.go
new file mode 100644
index 0000000..a76afb9
--- /dev/null
+++ b/debug.go
@@ -0,0 +1,54 @@
+package krpc
+
+import (
+ "fmt"
+ "net/http"
+ "text/template"
+)
+
+const debugText = `
+
kRPC Services
+
+{{range .}}
+
+Service {{.Name}}
+
+
+
+Method | Calls |
+
+
+{{range $name, $mtype := .Method}}
+
+{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error |
+{{$mtype.NumCalls}}
+ |
+{{end}}
+
+
+{{end}}
+`
+
+var debug = template.Must(template.New("RPC debug").Parse(debugText))
+
+type debugHTTP struct {
+ *Server
+}
+
+type debugService struct {
+ Name string
+ Method map[string]*methodType
+}
+
+func (s debugHTTP) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ var services []debugService
+ s.serviceMap.Range(func(namei, svci interface{}) bool {
+ svc := svci.(*service)
+ services = append(services, debugService{Name: namei.(string), Method: svc.method,})
+ return true
+ })
+ err := debug.Execute(w, services)
+ if err != nil {
+ _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error())
+ }
+}
diff --git a/example/main.go b/example/main.go
index af19da4..799023c 100644
--- a/example/main.go
+++ b/example/main.go
@@ -4,6 +4,7 @@ import (
"context"
"log"
"net"
+ "net/http"
"sync"
"krwu.top/krpc.v1"
@@ -14,6 +15,7 @@ func startServer(addr chan string) {
if err := krpc.Register(&foo); err != nil {
log.Fatal("register error: ", err)
}
+ krpc.HandleHTTP()
// pick a free port
l, err := net.Listen("tcp", ":0")
if err != nil {
@@ -21,7 +23,8 @@ func startServer(addr chan string) {
}
log.Println("start rpc server on", l.Addr())
addr <- l.Addr().String()
- krpc.Accept(l)
+ _ = http.Serve(l, nil)
+ //krpc.Accept(l)
}
func main() {
@@ -29,7 +32,7 @@ func main() {
addr := make(chan string)
go startServer(addr)
- cli, _ := krpc.Dial("tcp", <-addr)
+ cli, _ := krpc.DialHTTP("tcp", <-addr)
defer func() { _ = cli.Close() }()
// send options
diff --git a/request.go b/request.go
index a508e5f..1f61ff0 100644
--- a/request.go
+++ b/request.go
@@ -29,7 +29,7 @@ func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
}
func (s *Server) readRequest(cc codec.Codec) (*request, error) {
h, err := s.readRequestHeader(cc)
- if err != nil {
+ if err != nil && err != io.EOF {
fmt.Println("rpc server: read request error: ", err)
return nil, err
}
diff --git a/server.go b/server.go
index ed71247..f9404e8 100644
--- a/server.go
+++ b/server.go
@@ -4,7 +4,9 @@ import (
"encoding/json"
"fmt"
"io"
+ "log"
"net"
+ "net/http"
"strings"
"sync"
"time"
@@ -110,3 +112,35 @@ func Register(rcvr interface{}) error {
}
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
+
+const (
+ connected = "200 Connected to kRPC"
+ defaultRPCPath = "/_krpc_"
+ defaultDebugPath = "/debug/krpc"
+)
+
+func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodConnect {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ _, _ = fmt.Fprintln(w, "405 method not allowed")
+ return
+ }
+ conn, _, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ log.Printf("rpc hijacking %s: %s", r.RemoteAddr, err.Error())
+ return
+ }
+ _, _ = fmt.Fprintf(conn, "HTTP/1.0 %s\n\n", connected)
+ s.ServeConn(conn)
+}
+
+func (s *Server) HandleHTTP() {
+ http.Handle(defaultRPCPath, s)
+ http.Handle(defaultDebugPath, debugHTTP{s})
+ log.Println("rpc: server debug path:", defaultDebugPath)
+}
+
+func HandleHTTP() {
+ DefaultServer.HandleHTTP()
+}