feat: implement http support
Signed-off-by: Kairee Wu <kaireewu@gmail.com>
This commit is contained in:
34
client.go
34
client.go
@ -1,6 +1,7 @@
|
||||
package krpc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@ -8,6 +9,8 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -241,3 +244,34 @@ func dialTimeout(f newClientFunc, network, address string,
|
||||
return result.client, result.err
|
||||
}
|
||||
}
|
||||
|
||||
func NewHTTPClient(conn net.Conn, opts ...Option) (*Client, error) {
|
||||
_, _ = fmt.Fprintf(conn,
|
||||
"CONNECT %s HTTP/1.0\n\n", defaultRPCPath)
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn),
|
||||
&http.Request{Method: http.MethodConnect})
|
||||
if err == nil && resp.Status == connected {
|
||||
return NewClient(conn, opts...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func DialHTTP(network, address string, opts ...Option) (*Client, error) {
|
||||
return dialTimeout(NewHTTPClient, network, address, opts...)
|
||||
}
|
||||
|
||||
func XDial(rpcAddr string, opts ...Option) (*Client, error) {
|
||||
parts := strings.Split(rpcAddr, "://")
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("rpc: client wrong format '%s', expect protocol://addr", rpcAddr)
|
||||
}
|
||||
protocol, addr := parts[0], parts[1]
|
||||
|
||||
switch protocol {
|
||||
case "http":
|
||||
return DialHTTP("tcp", addr, opts...)
|
||||
default:
|
||||
return Dial(protocol, addr, opts...)
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,8 @@ package krpc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -36,5 +38,24 @@ func TestClient_dialTimeout(t *testing.T) {
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestXDial(t *testing.T) {
|
||||
if runtime.GOOS == "linux" {
|
||||
ch := make(chan struct{})
|
||||
addr := "/tmp/krpc.sock"
|
||||
|
||||
go func() {
|
||||
_ = os.Remove(addr)
|
||||
l, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
t.Fatal("failed to listen unix socket")
|
||||
}
|
||||
ch <- struct{}{}
|
||||
Accept(l)
|
||||
}()
|
||||
<-ch
|
||||
_, err := XDial("unix://"+addr)
|
||||
_assert(err == nil, "failed to connect unix socket")
|
||||
}
|
||||
}
|
||||
|
54
debug.go
Normal file
54
debug.go
Normal file
@ -0,0 +1,54 @@
|
||||
package krpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
const debugText = `<html>
|
||||
<head><title>kRPC Services</title></head>
|
||||
<body>
|
||||
{{range .}}
|
||||
<hr>
|
||||
Service {{.Name}}
|
||||
<hr>
|
||||
<table>
|
||||
<tr>
|
||||
<th>Method</th><th>Calls</th>
|
||||
</tr>
|
||||
<tr>
|
||||
{{range $name, $mtype := .Method}}
|
||||
<tr>
|
||||
<td>{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error</td>
|
||||
<td>{{$mtype.NumCalls}}
|
||||
</tr>
|
||||
{{end}}
|
||||
</tr>
|
||||
</table>
|
||||
{{end}}
|
||||
</body>`
|
||||
|
||||
var debug = template.Must(template.New("RPC debug").Parse(debugText))
|
||||
|
||||
type debugHTTP struct {
|
||||
*Server
|
||||
}
|
||||
|
||||
type debugService struct {
|
||||
Name string
|
||||
Method map[string]*methodType
|
||||
}
|
||||
|
||||
func (s debugHTTP) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var services []debugService
|
||||
s.serviceMap.Range(func(namei, svci interface{}) bool {
|
||||
svc := svci.(*service)
|
||||
services = append(services, debugService{Name: namei.(string), Method: svc.method,})
|
||||
return true
|
||||
})
|
||||
err := debug.Execute(w, services)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error())
|
||||
}
|
||||
}
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"krwu.top/krpc.v1"
|
||||
@ -14,6 +15,7 @@ func startServer(addr chan string) {
|
||||
if err := krpc.Register(&foo); err != nil {
|
||||
log.Fatal("register error: ", err)
|
||||
}
|
||||
krpc.HandleHTTP()
|
||||
// pick a free port
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
@ -21,7 +23,8 @@ func startServer(addr chan string) {
|
||||
}
|
||||
log.Println("start rpc server on", l.Addr())
|
||||
addr <- l.Addr().String()
|
||||
krpc.Accept(l)
|
||||
_ = http.Serve(l, nil)
|
||||
//krpc.Accept(l)
|
||||
}
|
||||
|
||||
func main() {
|
||||
@ -29,7 +32,7 @@ func main() {
|
||||
addr := make(chan string)
|
||||
go startServer(addr)
|
||||
|
||||
cli, _ := krpc.Dial("tcp", <-addr)
|
||||
cli, _ := krpc.DialHTTP("tcp", <-addr)
|
||||
defer func() { _ = cli.Close() }()
|
||||
|
||||
// send options
|
||||
|
@ -29,7 +29,7 @@ func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
|
||||
}
|
||||
func (s *Server) readRequest(cc codec.Codec) (*request, error) {
|
||||
h, err := s.readRequestHeader(cc)
|
||||
if err != nil {
|
||||
if err != nil && err != io.EOF {
|
||||
fmt.Println("rpc server: read request error: ", err)
|
||||
return nil, err
|
||||
}
|
||||
|
34
server.go
34
server.go
@ -4,7 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -110,3 +112,35 @@ func Register(rcvr interface{}) error {
|
||||
}
|
||||
|
||||
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
|
||||
|
||||
const (
|
||||
connected = "200 Connected to kRPC"
|
||||
defaultRPCPath = "/_krpc_"
|
||||
defaultDebugPath = "/debug/krpc"
|
||||
)
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
_, _ = fmt.Fprintln(w, "405 method not allowed")
|
||||
return
|
||||
}
|
||||
conn, _, err := w.(http.Hijacker).Hijack()
|
||||
if err != nil {
|
||||
log.Printf("rpc hijacking %s: %s", r.RemoteAddr, err.Error())
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintf(conn, "HTTP/1.0 %s\n\n", connected)
|
||||
s.ServeConn(conn)
|
||||
}
|
||||
|
||||
func (s *Server) HandleHTTP() {
|
||||
http.Handle(defaultRPCPath, s)
|
||||
http.Handle(defaultDebugPath, debugHTTP{s})
|
||||
log.Println("rpc: server debug path:", defaultDebugPath)
|
||||
}
|
||||
|
||||
func HandleHTTP() {
|
||||
DefaultServer.HandleHTTP()
|
||||
}
|
||||
|
Reference in New Issue
Block a user