feat: implement http support

Signed-off-by: Kairee Wu <kaireewu@gmail.com>
This commit is contained in:
2022-04-23 17:34:50 +08:00
parent 74ef998a9b
commit 4477a63c49
6 changed files with 150 additions and 4 deletions

View File

@ -1,6 +1,7 @@
package krpc package krpc
import ( import (
"bufio"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@ -8,6 +9,8 @@ import (
"io" "io"
"log" "log"
"net" "net"
"net/http"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -241,3 +244,34 @@ func dialTimeout(f newClientFunc, network, address string,
return result.client, result.err return result.client, result.err
} }
} }
func NewHTTPClient(conn net.Conn, opts ...Option) (*Client, error) {
_, _ = fmt.Fprintf(conn,
"CONNECT %s HTTP/1.0\n\n", defaultRPCPath)
resp, err := http.ReadResponse(bufio.NewReader(conn),
&http.Request{Method: http.MethodConnect})
if err == nil && resp.Status == connected {
return NewClient(conn, opts...)
}
return nil, err
}
func DialHTTP(network, address string, opts ...Option) (*Client, error) {
return dialTimeout(NewHTTPClient, network, address, opts...)
}
func XDial(rpcAddr string, opts ...Option) (*Client, error) {
parts := strings.Split(rpcAddr, "://")
if len(parts) != 2 {
return nil, fmt.Errorf("rpc: client wrong format '%s', expect protocol://addr", rpcAddr)
}
protocol, addr := parts[0], parts[1]
switch protocol {
case "http":
return DialHTTP("tcp", addr, opts...)
default:
return Dial(protocol, addr, opts...)
}
}

View File

@ -2,6 +2,8 @@ package krpc
import ( import (
"net" "net"
"os"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -36,5 +38,24 @@ func TestClient_dialTimeout(t *testing.T) {
}) })
} }
}
func TestXDial(t *testing.T) {
if runtime.GOOS == "linux" {
ch := make(chan struct{})
addr := "/tmp/krpc.sock"
go func() {
_ = os.Remove(addr)
l, err := net.Listen("unix", addr)
if err != nil {
t.Fatal("failed to listen unix socket")
}
ch <- struct{}{}
Accept(l)
}()
<-ch
_, err := XDial("unix://"+addr)
_assert(err == nil, "failed to connect unix socket")
}
} }

54
debug.go Normal file
View File

@ -0,0 +1,54 @@
package krpc
import (
"fmt"
"net/http"
"text/template"
)
const debugText = `<html>
<head><title>kRPC Services</title></head>
<body>
{{range .}}
<hr>
Service {{.Name}}
<hr>
<table>
<tr>
<th>Method</th><th>Calls</th>
</tr>
<tr>
{{range $name, $mtype := .Method}}
<tr>
<td>{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error</td>
<td>{{$mtype.NumCalls}}
</tr>
{{end}}
</tr>
</table>
{{end}}
</body>`
var debug = template.Must(template.New("RPC debug").Parse(debugText))
type debugHTTP struct {
*Server
}
type debugService struct {
Name string
Method map[string]*methodType
}
func (s debugHTTP) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var services []debugService
s.serviceMap.Range(func(namei, svci interface{}) bool {
svc := svci.(*service)
services = append(services, debugService{Name: namei.(string), Method: svc.method,})
return true
})
err := debug.Execute(w, services)
if err != nil {
_, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error())
}
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"log" "log"
"net" "net"
"net/http"
"sync" "sync"
"krwu.top/krpc.v1" "krwu.top/krpc.v1"
@ -14,6 +15,7 @@ func startServer(addr chan string) {
if err := krpc.Register(&foo); err != nil { if err := krpc.Register(&foo); err != nil {
log.Fatal("register error: ", err) log.Fatal("register error: ", err)
} }
krpc.HandleHTTP()
// pick a free port // pick a free port
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
@ -21,7 +23,8 @@ func startServer(addr chan string) {
} }
log.Println("start rpc server on", l.Addr()) log.Println("start rpc server on", l.Addr())
addr <- l.Addr().String() addr <- l.Addr().String()
krpc.Accept(l) _ = http.Serve(l, nil)
//krpc.Accept(l)
} }
func main() { func main() {
@ -29,7 +32,7 @@ func main() {
addr := make(chan string) addr := make(chan string)
go startServer(addr) go startServer(addr)
cli, _ := krpc.Dial("tcp", <-addr) cli, _ := krpc.DialHTTP("tcp", <-addr)
defer func() { _ = cli.Close() }() defer func() { _ = cli.Close() }()
// send options // send options

View File

@ -29,7 +29,7 @@ func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
} }
func (s *Server) readRequest(cc codec.Codec) (*request, error) { func (s *Server) readRequest(cc codec.Codec) (*request, error) {
h, err := s.readRequestHeader(cc) h, err := s.readRequestHeader(cc)
if err != nil { if err != nil && err != io.EOF {
fmt.Println("rpc server: read request error: ", err) fmt.Println("rpc server: read request error: ", err)
return nil, err return nil, err
} }

View File

@ -4,7 +4,9 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"net/http"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -110,3 +112,35 @@ func Register(rcvr interface{}) error {
} }
func Accept(lis net.Listener) { DefaultServer.Accept(lis) } func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
const (
connected = "200 Connected to kRPC"
defaultRPCPath = "/_krpc_"
defaultDebugPath = "/debug/krpc"
)
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusMethodNotAllowed)
_, _ = fmt.Fprintln(w, "405 method not allowed")
return
}
conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
log.Printf("rpc hijacking %s: %s", r.RemoteAddr, err.Error())
return
}
_, _ = fmt.Fprintf(conn, "HTTP/1.0 %s\n\n", connected)
s.ServeConn(conn)
}
func (s *Server) HandleHTTP() {
http.Handle(defaultRPCPath, s)
http.Handle(defaultDebugPath, debugHTTP{s})
log.Println("rpc: server debug path:", defaultDebugPath)
}
func HandleHTTP() {
DefaultServer.HandleHTTP()
}