feat: implement server and client

This commit is contained in:
2022-04-23 01:24:17 +08:00
parent cf8fa335a6
commit 74ef998a9b
13 changed files with 631 additions and 120 deletions

65
.gitignore vendored
View File

@ -22,58 +22,8 @@
go.work
# ---> JetBrains
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
.idea/
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# AWS User-specific
.idea/**/aws.xml
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
@ -82,24 +32,12 @@ out/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# SonarLint plugin
.idea/sonarlint/
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
# ---> VisualStudioCode
.vscode/*
!.vscode/settings.json
@ -123,6 +61,7 @@ fabric.properties
# Icon must end with two \r
Icon
# Thumbnails
._*

243
client.go Normal file
View File

@ -0,0 +1,243 @@
package krpc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"krwu.top/krpc.v1/codec"
)
type Call struct {
Seq uint64
ServiceMethod string
Args interface{}
Reply interface{}
Error error
Done chan *Call
}
func (call *Call) done() {
call.Done <- call
}
type Client struct {
cc codec.Codec
opt *Options
sending sync.Mutex
header codec.Header
mu sync.Mutex
seq uint64
pending map[uint64]*Call
closing bool
shutdown bool
}
type clientResult struct {
client *Client
err error
}
type newClientFunc func(conn net.Conn, opts ...Option) (client *Client, err error)
var _ io.Closer = (*Client)(nil)
var ErrShutdown = errors.New("connection is shutdown")
func NewClient(conn net.Conn, opts ...Option) (*Client, error) {
for i := range opts {
opts[i](DefaultOptions)
}
f := codec.NewCodecFuncMap[DefaultOptions.CodecType]
if f == nil {
err := fmt.Errorf("invalid codec type %s", DefaultOptions.CodecType)
log.Println("rpc client: codec error: ", err)
return nil, err
}
if err := json.NewEncoder(conn).Encode(DefaultOptions); err != nil {
log.Println("rpc client: options error:", err)
_ = conn.Close()
return nil, err
}
return newClientCodec(f(conn), DefaultOptions), nil
}
func Dial(network, address string, opts ...Option) (client *Client, err error) {
return dialTimeout(NewClient, network, address, opts...)
}
func newClientCodec(cc codec.Codec, opts *Options) *Client {
client := &Client{
seq: 1,
cc: cc,
opt: opts,
pending: make(map[uint64]*Call),
}
go client.receive()
return client
}
func (client *Client) Close() error {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing {
return ErrShutdown
}
client.closing = true
return client.cc.Close()
}
func (client *Client) IsAvailable() bool {
client.mu.Lock()
defer client.mu.Unlock()
return !client.shutdown && !client.closing
}
func (client *Client) registerCall(call *Call) (uint64, error) {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing || client.shutdown {
return 0, ErrShutdown
}
call.Seq = client.seq
client.pending[call.Seq] = call
atomic.StoreUint64(&client.seq, client.seq+1)
return call.Seq, nil
}
func (client *Client) removeCall(seq uint64) *Call {
client.mu.Lock()
defer client.mu.Unlock()
call := client.pending[seq]
delete(client.pending, seq)
return call
}
func (client *Client) terminateCalls(err error) {
client.sending.Lock()
defer client.sending.Unlock()
client.mu.Lock()
defer client.mu.Unlock()
client.shutdown = true
for i := range client.pending {
client.pending[i].Error = err
client.pending[i].done()
}
}
func (client *Client) receive() {
var err error
for err == nil {
var h codec.Header
if err = client.cc.ReadHeader(&h); err != nil {
break
}
call := client.removeCall(h.Seq)
switch {
case call == nil:
err = client.cc.ReadBody(nil)
case h.Error != "":
call.Error = fmt.Errorf(h.Error)
err = client.cc.ReadBody(nil)
call.done()
default:
err = client.cc.ReadBody(call.Reply)
if err != nil {
call.Error = fmt.Errorf("read body %s", err.Error())
}
call.done()
}
}
client.terminateCalls(err)
}
func (client *Client) send(call *Call) {
client.sending.Lock()
defer client.sending.Unlock()
seq, err := client.registerCall(call)
if err != nil {
call.Error = err
call.done()
return
}
client.header.ServiceMethod = call.ServiceMethod
client.header.Seq = seq
client.header.Error = ""
if err := client.cc.Write(&client.header, call.Args); err != nil {
call := client.removeCall(seq)
if call != nil {
call.Error = err
call.done()
}
}
}
func (client *Client) Go(serviceMethod string,
args, reply interface{}, done chan *Call) *Call {
if done == nil {
done = make(chan *Call, 10)
} else if cap(done) == 0 {
log.Panic("rpc client: done channel is unbuffered")
}
call := &Call{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
}
go client.send(call)
return call
}
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
select {
case <-ctx.Done():
client.removeCall(call.Seq)
return fmt.Errorf("rpc: client call failed: %v", ctx.Err())
case c := <-call.Done:
return c.Error
}
}
func dialTimeout(f newClientFunc, network, address string,
opts ...Option) (client *Client, err error) {
o := apply(DefaultOptions, opts...)
conn, err := net.DialTimeout(network, address, o.ConnectTimeout)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = conn.Close()
}
}()
ch := make(chan clientResult)
go func() {
client, err := f(conn, opts...)
ch <- clientResult{client: client, err: err}
}()
if o.ConnectTimeout == 0 {
result := <-ch
return result.client, result.err
}
select {
case <-time.After(o.ConnectTimeout):
return nil, fmt.Errorf("rpc: client connection timeout after %s", o.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}

View File

@ -1,21 +0,0 @@
package client
import "krwu.top/krpc.v1/codec"
type Options struct {
MagicNumber int
CodecType codec.Type
}
var DefaultOptions = &Options{
MagicNumber: codec.MagicNumber,
CodecType: codec.GobType,
}
type Option func(*Options)
func WithCodecType(t codec.Type) Option {
return func(o *Options) {
o.CodecType = t
}
}

40
client_test.go Normal file
View File

@ -0,0 +1,40 @@
package krpc
import (
"net"
"strings"
"testing"
"time"
)
func TestClient_dialTimeout(t *testing.T) {
t.Parallel()
l, _ := net.Listen("tcp", ":0")
f := func(conn net.Conn, opts ...Option) (client *Client, err error) {
_ = conn.Close()
time.Sleep(time.Millisecond * 500)
return nil, nil
}
tests := map[string]time.Duration{
"timeout": time.Millisecond * 100,
"unlimited": 0,
}
for name, timeout := range tests {
t.Run(name, func(t *testing.T) {
_, err := dialTimeout(f, "tcp", l.Addr().String(),
WithConnectTimeout(timeout))
if strings.Contains(name, "timeout") {
_assert(err != nil &&
strings.Contains(err.Error(), "timeout"),
"expect a timeout error",
)
} else {
_assert(err == nil, "0 means no limit")
}
})
}
}

51
example/main.go Normal file
View File

@ -0,0 +1,51 @@
package main
import (
"context"
"log"
"net"
"sync"
"krwu.top/krpc.v1"
)
func startServer(addr chan string) {
var foo Foo
if err := krpc.Register(&foo); err != nil {
log.Fatal("register error: ", err)
}
// pick a free port
l, err := net.Listen("tcp", ":0")
if err != nil {
log.Fatal("network error:", err)
}
log.Println("start rpc server on", l.Addr())
addr <- l.Addr().String()
krpc.Accept(l)
}
func main() {
log.SetFlags(0)
addr := make(chan string)
go startServer(addr)
cli, _ := krpc.Dial("tcp", <-addr)
defer func() { _ = cli.Close() }()
// send options
var wg sync.WaitGroup
// send request & receive response
for i := 0; i < 5; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
args := &Args{Num1: i, Num2: i * i}
var reply int
if err := cli.Call(context.TODO(), "Foo.Sum", args, &reply); err != nil {
log.Fatal("call Foo.Sum error: ", err)
}
log.Printf("%d + %d = %d\n", args.Num1, args.Num2, reply)
}(i)
}
wg.Wait()
}

12
example/service.go Normal file
View File

@ -0,0 +1,12 @@
package main
type Foo int
type Args struct {
Num1, Num2 int
}
func (f Foo) Sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}

9
go.mod
View File

@ -1,12 +1,3 @@
module krwu.top/krpc.v1
go 1.17
require github.com/stretchr/testify v1.7.1
require (
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.1.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

12
go.sum
View File

@ -1,12 +0,0 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

48
option.go Normal file
View File

@ -0,0 +1,48 @@
package krpc
import (
"time"
"krwu.top/krpc.v1/codec"
)
type Options struct {
MagicNumber int
CodecType codec.Type
ConnectTimeout time.Duration
HandleTimeout time.Duration
}
var DefaultOptions = &Options{
MagicNumber: codec.MagicNumber,
CodecType: codec.GobType,
ConnectTimeout: time.Second * 3,
}
type Option func(*Options)
func WithCodecType(t codec.Type) Option {
return func(o *Options) {
o.CodecType = t
}
}
func WithConnectTimeout(t time.Duration) Option {
return func(o *Options) {
o.ConnectTimeout = t
}
}
func WithHandleTimeout(t time.Duration) Option {
return func(o *Options) {
o.HandleTimeout = t
}
}
func apply(o *Options, opts ...Option) Options {
newOpts := *o
for i := range opts {
opts[i](&newOpts)
}
return newOpts
}

View File

@ -5,14 +5,16 @@ import (
"io"
"reflect"
"sync"
"time"
"krwu.top/krpc.v1/codec"
)
type request struct {
h *codec.Header
argv reflect.Value
replyv reflect.Value
argv, replyv reflect.Value
mtype *methodType
svc *service
}
func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
@ -32,16 +34,52 @@ func (s *Server) readRequest(cc codec.Codec) (*request, error) {
return nil, err
}
req := &request{h: h}
req.argv = reflect.New(reflect.TypeOf(""))
if err = cc.ReadBody(req.argv.Interface()); err != nil {
fmt.Println("rpc server: read argv err: ", err)
req.svc, req.mtype, err = s.findService(h.ServiceMethod)
if err != nil {
return req, err
}
req.argv = req.mtype.newArgv()
req.replyv = req.mtype.newReplyv()
argvi := req.argv.Interface()
if req.argv.Type().Kind() != reflect.Ptr {
argvi = req.argv.Addr().Interface()
}
if err = cc.ReadBody(argvi); err != nil {
fmt.Println("rpc: server read body error: ", err)
return req, err
}
return req, nil
}
func (s *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
func (s *Server) handleRequest(cc codec.Codec,
req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()
fmt.Println("rcp server: ", req.h, req.argv.Elem())
req.replyv = reflect.ValueOf(fmt.Sprintf("krpc resp %d", req.h.Seq))
s.sendResponse(cc, req.h, req.replyv.Interface(), sending)
called := make(chan struct{}, 1)
sent := make(chan struct{}, 1)
go func() {
err := req.svc.call(req.mtype, req.argv, req.replyv)
called <- struct{}{}
if err != nil {
req.h.Error = err.Error()
s.sendResponse(cc, req.h, invalidRequest, sending)
sent <- struct{}{}
return
}
s.sendResponse(cc, req.h, req.replyv.Interface(), sending)
sent <- struct{}{}
}()
if timeout == 0 {
<-called
<-sent
return
}
select {
case <-time.After(timeout):
req.h.Error = fmt.Sprintf("prc: server request handle timeout with %s", timeout)
s.sendResponse(cc, req.h, invalidRequest, sending)
case <-called:
<-sent
}
}

View File

@ -5,13 +5,16 @@ import (
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"krwu.top/krpc.v1/client"
"krwu.top/krpc.v1/codec"
)
type Server struct{}
type Server struct {
serviceMap sync.Map
}
func NewServer() *Server {
return &Server{}
@ -34,7 +37,7 @@ func (s *Server) ServeConn(conn io.ReadWriteCloser) {
defer func() {
_ = conn.Close()
}()
var opts client.Options
var opts Options
if err := json.NewDecoder(conn).Decode(&opts); err != nil {
fmt.Println("rpc server: options error: ", err)
return
@ -68,10 +71,42 @@ func (s *Server) ServeCodec(cc codec.Codec) {
continue
}
wg.Add(1)
go s.handleRequest(cc, req, sending, wg)
go s.handleRequest(cc, req, sending, wg, time.Second*3)
}
wg.Wait()
_ = cc.Close()
}
func (s *Server) Register(rcvr interface{}) error {
svc := newService(rcvr)
if _, dup := s.serviceMap.LoadOrStore(svc.name, svc); dup {
return fmt.Errorf("rpc: service already defined: %s\n", svc.name)
}
return nil
}
func (s *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
dot := strings.LastIndex(serviceMethod, ".")
if dot < 0 {
err = fmt.Errorf("rpc: service/method request ill-formed: %s", serviceMethod)
return
}
serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
svci, ok := s.serviceMap.Load(serviceName)
if !ok {
err = fmt.Errorf("rpc: server can't find service %s", serviceName)
return
}
svc = svci.(*service)
mtype = svc.method[methodName]
if mtype == nil {
err = fmt.Errorf("rpc: server can't find method %s", methodName)
}
return
}
func Register(rcvr interface{}) error {
return DefaultServer.Register(rcvr)
}
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }

99
service.go Normal file
View File

@ -0,0 +1,99 @@
package krpc
import (
"go/ast"
"log"
"reflect"
"sync/atomic"
)
type methodType struct {
method reflect.Method
ArgType reflect.Type
ReplyType reflect.Type
numCalls uint64
}
func (m *methodType) NumCalls() uint64 {
return atomic.LoadUint64(&m.numCalls)
}
func (m *methodType) newArgv() reflect.Value {
var argv reflect.Value
if m.ArgType.Kind() == reflect.Ptr {
argv = reflect.New(m.ArgType.Elem())
} else {
argv = reflect.New(m.ArgType).Elem()
}
return argv
}
func (m *methodType) newReplyv() reflect.Value {
replyv := reflect.New(m.ReplyType.Elem())
switch m.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
}
return replyv
}
type service struct {
name string
typ reflect.Type
rcvr reflect.Value
method map[string]*methodType
}
func newService(rcvr interface{}) *service {
s := new(service)
s.rcvr = reflect.ValueOf(rcvr)
s.name = reflect.Indirect(s.rcvr).Type().Name()
s.typ = reflect.TypeOf(rcvr)
if !ast.IsExported(s.name) {
log.Fatalf("rpc server: %s is not a valid service name", s.name)
}
s.registerMethods()
return s
}
func (s *service) registerMethods() {
s.method = make(map[string]*methodType, s.typ.NumMethod())
for i := 0; i < s.typ.NumMethod(); i++ {
method := s.typ.Method(i)
mType := method.Type
if mType.NumIn() != 3 || mType.NumOut() != 1 {
continue
}
if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
continue
}
argType, replyType := mType.In(1), mType.In(2)
if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
continue
}
s.method[method.Name] = &methodType{
method: method,
ArgType: argType,
ReplyType: replyType,
}
log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
}
}
func isExportedOrBuiltinType(t reflect.Type) bool {
return ast.IsExported(t.Name()) || t.PkgPath() == ""
}
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
atomic.AddUint64(&m.numCalls, 1)
f := m.method.Func
returnValues := f.Call([]reflect.Value{
s.rcvr, argv, replyv,
})
if errInter := returnValues[0].Interface(); errInter != nil {
return errInter.(error)
}
return nil
}

48
service_test.go Normal file
View File

@ -0,0 +1,48 @@
package krpc
import (
"fmt"
"reflect"
"testing"
)
type Foo int
type Args struct{ Num1, Num2 int }
func (f Foo) Sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}
func (f Foo) sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}
func _assert(condition bool, msg string, v ...interface{}) {
if !condition {
panic(fmt.Sprintf("assertion failed: "+msg+"\n", v...))
}
}
func TestNewService(t *testing.T) {
var foo Foo
s := newService(&foo)
want := 1
_assert(len(s.method) == want, "wrong service Method, expect %d, got %d", want, len(s.method))
}
func TestMethodType_Call(t *testing.T) {
var foo Foo
s := newService(&foo)
mType := s.method["Sum"]
argv := mType.newArgv()
replyv := mType.newReplyv()
argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3}))
err := s.call(mType, argv, replyv)
want := 4
_assert(err == nil && *replyv.Interface().(*int) == want &&
mType.NumCalls() == 1, "failed to call Foo.Sum")
}