feat: implement server and client
This commit is contained in:
67
.gitignore
vendored
67
.gitignore
vendored
@ -22,58 +22,8 @@
|
|||||||
go.work
|
go.work
|
||||||
|
|
||||||
# ---> JetBrains
|
# ---> JetBrains
|
||||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
.idea/
|
||||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
|
||||||
|
|
||||||
# User-specific stuff
|
|
||||||
.idea/**/workspace.xml
|
|
||||||
.idea/**/tasks.xml
|
|
||||||
.idea/**/usage.statistics.xml
|
|
||||||
.idea/**/dictionaries
|
|
||||||
.idea/**/shelf
|
|
||||||
|
|
||||||
# AWS User-specific
|
|
||||||
.idea/**/aws.xml
|
|
||||||
|
|
||||||
# Generated files
|
|
||||||
.idea/**/contentModel.xml
|
|
||||||
|
|
||||||
# Sensitive or high-churn files
|
|
||||||
.idea/**/dataSources/
|
|
||||||
.idea/**/dataSources.ids
|
|
||||||
.idea/**/dataSources.local.xml
|
|
||||||
.idea/**/sqlDataSources.xml
|
|
||||||
.idea/**/dynamic.xml
|
|
||||||
.idea/**/uiDesigner.xml
|
|
||||||
.idea/**/dbnavigator.xml
|
|
||||||
|
|
||||||
# Gradle
|
|
||||||
.idea/**/gradle.xml
|
|
||||||
.idea/**/libraries
|
|
||||||
|
|
||||||
# Gradle and Maven with auto-import
|
|
||||||
# When using Gradle or Maven with auto-import, you should exclude module files,
|
|
||||||
# since they will be recreated, and may cause churn. Uncomment if using
|
|
||||||
# auto-import.
|
|
||||||
# .idea/artifacts
|
|
||||||
# .idea/compiler.xml
|
|
||||||
# .idea/jarRepositories.xml
|
|
||||||
# .idea/modules.xml
|
|
||||||
# .idea/*.iml
|
|
||||||
# .idea/modules
|
|
||||||
# *.iml
|
|
||||||
# *.ipr
|
|
||||||
|
|
||||||
# CMake
|
|
||||||
cmake-build-*/
|
|
||||||
|
|
||||||
# Mongo Explorer plugin
|
|
||||||
.idea/**/mongoSettings.xml
|
|
||||||
|
|
||||||
# File-based project format
|
|
||||||
*.iws
|
|
||||||
|
|
||||||
# IntelliJ
|
|
||||||
out/
|
out/
|
||||||
|
|
||||||
# mpeltonen/sbt-idea plugin
|
# mpeltonen/sbt-idea plugin
|
||||||
@ -82,24 +32,12 @@ out/
|
|||||||
# JIRA plugin
|
# JIRA plugin
|
||||||
atlassian-ide-plugin.xml
|
atlassian-ide-plugin.xml
|
||||||
|
|
||||||
# Cursive Clojure plugin
|
|
||||||
.idea/replstate.xml
|
|
||||||
|
|
||||||
# SonarLint plugin
|
|
||||||
.idea/sonarlint/
|
|
||||||
|
|
||||||
# Crashlytics plugin (for Android Studio and IntelliJ)
|
# Crashlytics plugin (for Android Studio and IntelliJ)
|
||||||
com_crashlytics_export_strings.xml
|
com_crashlytics_export_strings.xml
|
||||||
crashlytics.properties
|
crashlytics.properties
|
||||||
crashlytics-build.properties
|
crashlytics-build.properties
|
||||||
fabric.properties
|
fabric.properties
|
||||||
|
|
||||||
# Editor-based Rest Client
|
|
||||||
.idea/httpRequests
|
|
||||||
|
|
||||||
# Android studio 3.1+ serialized cache file
|
|
||||||
.idea/caches/build_file_checksums.ser
|
|
||||||
|
|
||||||
# ---> VisualStudioCode
|
# ---> VisualStudioCode
|
||||||
.vscode/*
|
.vscode/*
|
||||||
!.vscode/settings.json
|
!.vscode/settings.json
|
||||||
@ -121,7 +59,8 @@ fabric.properties
|
|||||||
.LSOverride
|
.LSOverride
|
||||||
|
|
||||||
# Icon must end with two \r
|
# Icon must end with two \r
|
||||||
Icon
|
Icon
|
||||||
|
|
||||||
|
|
||||||
# Thumbnails
|
# Thumbnails
|
||||||
._*
|
._*
|
||||||
|
243
client.go
Normal file
243
client.go
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
package krpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"krwu.top/krpc.v1/codec"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Call struct {
|
||||||
|
Seq uint64
|
||||||
|
ServiceMethod string
|
||||||
|
Args interface{}
|
||||||
|
Reply interface{}
|
||||||
|
Error error
|
||||||
|
Done chan *Call
|
||||||
|
}
|
||||||
|
|
||||||
|
func (call *Call) done() {
|
||||||
|
call.Done <- call
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
cc codec.Codec
|
||||||
|
opt *Options
|
||||||
|
sending sync.Mutex
|
||||||
|
header codec.Header
|
||||||
|
mu sync.Mutex
|
||||||
|
seq uint64
|
||||||
|
pending map[uint64]*Call
|
||||||
|
closing bool
|
||||||
|
shutdown bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientResult struct {
|
||||||
|
client *Client
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type newClientFunc func(conn net.Conn, opts ...Option) (client *Client, err error)
|
||||||
|
|
||||||
|
var _ io.Closer = (*Client)(nil)
|
||||||
|
|
||||||
|
var ErrShutdown = errors.New("connection is shutdown")
|
||||||
|
|
||||||
|
func NewClient(conn net.Conn, opts ...Option) (*Client, error) {
|
||||||
|
for i := range opts {
|
||||||
|
opts[i](DefaultOptions)
|
||||||
|
}
|
||||||
|
f := codec.NewCodecFuncMap[DefaultOptions.CodecType]
|
||||||
|
if f == nil {
|
||||||
|
err := fmt.Errorf("invalid codec type %s", DefaultOptions.CodecType)
|
||||||
|
log.Println("rpc client: codec error: ", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(conn).Encode(DefaultOptions); err != nil {
|
||||||
|
log.Println("rpc client: options error:", err)
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newClientCodec(f(conn), DefaultOptions), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dial(network, address string, opts ...Option) (client *Client, err error) {
|
||||||
|
return dialTimeout(NewClient, network, address, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientCodec(cc codec.Codec, opts *Options) *Client {
|
||||||
|
client := &Client{
|
||||||
|
seq: 1,
|
||||||
|
cc: cc,
|
||||||
|
opt: opts,
|
||||||
|
pending: make(map[uint64]*Call),
|
||||||
|
}
|
||||||
|
go client.receive()
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) Close() error {
|
||||||
|
client.mu.Lock()
|
||||||
|
defer client.mu.Unlock()
|
||||||
|
if client.closing {
|
||||||
|
return ErrShutdown
|
||||||
|
}
|
||||||
|
client.closing = true
|
||||||
|
return client.cc.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) IsAvailable() bool {
|
||||||
|
client.mu.Lock()
|
||||||
|
defer client.mu.Unlock()
|
||||||
|
return !client.shutdown && !client.closing
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) registerCall(call *Call) (uint64, error) {
|
||||||
|
client.mu.Lock()
|
||||||
|
defer client.mu.Unlock()
|
||||||
|
if client.closing || client.shutdown {
|
||||||
|
return 0, ErrShutdown
|
||||||
|
}
|
||||||
|
call.Seq = client.seq
|
||||||
|
client.pending[call.Seq] = call
|
||||||
|
atomic.StoreUint64(&client.seq, client.seq+1)
|
||||||
|
return call.Seq, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) removeCall(seq uint64) *Call {
|
||||||
|
client.mu.Lock()
|
||||||
|
defer client.mu.Unlock()
|
||||||
|
call := client.pending[seq]
|
||||||
|
delete(client.pending, seq)
|
||||||
|
return call
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) terminateCalls(err error) {
|
||||||
|
client.sending.Lock()
|
||||||
|
defer client.sending.Unlock()
|
||||||
|
client.mu.Lock()
|
||||||
|
defer client.mu.Unlock()
|
||||||
|
client.shutdown = true
|
||||||
|
for i := range client.pending {
|
||||||
|
client.pending[i].Error = err
|
||||||
|
client.pending[i].done()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) receive() {
|
||||||
|
var err error
|
||||||
|
for err == nil {
|
||||||
|
var h codec.Header
|
||||||
|
if err = client.cc.ReadHeader(&h); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
call := client.removeCall(h.Seq)
|
||||||
|
switch {
|
||||||
|
case call == nil:
|
||||||
|
err = client.cc.ReadBody(nil)
|
||||||
|
case h.Error != "":
|
||||||
|
call.Error = fmt.Errorf(h.Error)
|
||||||
|
err = client.cc.ReadBody(nil)
|
||||||
|
call.done()
|
||||||
|
default:
|
||||||
|
err = client.cc.ReadBody(call.Reply)
|
||||||
|
if err != nil {
|
||||||
|
call.Error = fmt.Errorf("read body %s", err.Error())
|
||||||
|
}
|
||||||
|
call.done()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
client.terminateCalls(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) send(call *Call) {
|
||||||
|
client.sending.Lock()
|
||||||
|
defer client.sending.Unlock()
|
||||||
|
|
||||||
|
seq, err := client.registerCall(call)
|
||||||
|
if err != nil {
|
||||||
|
call.Error = err
|
||||||
|
call.done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client.header.ServiceMethod = call.ServiceMethod
|
||||||
|
client.header.Seq = seq
|
||||||
|
client.header.Error = ""
|
||||||
|
|
||||||
|
if err := client.cc.Write(&client.header, call.Args); err != nil {
|
||||||
|
call := client.removeCall(seq)
|
||||||
|
if call != nil {
|
||||||
|
call.Error = err
|
||||||
|
call.done()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) Go(serviceMethod string,
|
||||||
|
args, reply interface{}, done chan *Call) *Call {
|
||||||
|
if done == nil {
|
||||||
|
done = make(chan *Call, 10)
|
||||||
|
} else if cap(done) == 0 {
|
||||||
|
log.Panic("rpc client: done channel is unbuffered")
|
||||||
|
}
|
||||||
|
call := &Call{
|
||||||
|
ServiceMethod: serviceMethod,
|
||||||
|
Args: args,
|
||||||
|
Reply: reply,
|
||||||
|
Done: done,
|
||||||
|
}
|
||||||
|
go client.send(call)
|
||||||
|
return call
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
|
||||||
|
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
client.removeCall(call.Seq)
|
||||||
|
return fmt.Errorf("rpc: client call failed: %v", ctx.Err())
|
||||||
|
case c := <-call.Done:
|
||||||
|
return c.Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialTimeout(f newClientFunc, network, address string,
|
||||||
|
opts ...Option) (client *Client, err error) {
|
||||||
|
|
||||||
|
o := apply(DefaultOptions, opts...)
|
||||||
|
conn, err := net.DialTimeout(network, address, o.ConnectTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ch := make(chan clientResult)
|
||||||
|
go func() {
|
||||||
|
client, err := f(conn, opts...)
|
||||||
|
ch <- clientResult{client: client, err: err}
|
||||||
|
}()
|
||||||
|
if o.ConnectTimeout == 0 {
|
||||||
|
result := <-ch
|
||||||
|
return result.client, result.err
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-time.After(o.ConnectTimeout):
|
||||||
|
return nil, fmt.Errorf("rpc: client connection timeout after %s", o.ConnectTimeout)
|
||||||
|
case result := <-ch:
|
||||||
|
return result.client, result.err
|
||||||
|
}
|
||||||
|
}
|
@ -1,21 +0,0 @@
|
|||||||
package client
|
|
||||||
|
|
||||||
import "krwu.top/krpc.v1/codec"
|
|
||||||
|
|
||||||
type Options struct {
|
|
||||||
MagicNumber int
|
|
||||||
CodecType codec.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
var DefaultOptions = &Options{
|
|
||||||
MagicNumber: codec.MagicNumber,
|
|
||||||
CodecType: codec.GobType,
|
|
||||||
}
|
|
||||||
|
|
||||||
type Option func(*Options)
|
|
||||||
|
|
||||||
func WithCodecType(t codec.Type) Option {
|
|
||||||
return func(o *Options) {
|
|
||||||
o.CodecType = t
|
|
||||||
}
|
|
||||||
}
|
|
40
client_test.go
Normal file
40
client_test.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package krpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClient_dialTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
l, _ := net.Listen("tcp", ":0")
|
||||||
|
|
||||||
|
f := func(conn net.Conn, opts ...Option) (client *Client, err error) {
|
||||||
|
_ = conn.Close()
|
||||||
|
time.Sleep(time.Millisecond * 500)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := map[string]time.Duration{
|
||||||
|
"timeout": time.Millisecond * 100,
|
||||||
|
"unlimited": 0,
|
||||||
|
}
|
||||||
|
for name, timeout := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
_, err := dialTimeout(f, "tcp", l.Addr().String(),
|
||||||
|
WithConnectTimeout(timeout))
|
||||||
|
if strings.Contains(name, "timeout") {
|
||||||
|
_assert(err != nil &&
|
||||||
|
strings.Contains(err.Error(), "timeout"),
|
||||||
|
"expect a timeout error",
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
_assert(err == nil, "0 means no limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
51
example/main.go
Normal file
51
example/main.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"krwu.top/krpc.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
func startServer(addr chan string) {
|
||||||
|
var foo Foo
|
||||||
|
if err := krpc.Register(&foo); err != nil {
|
||||||
|
log.Fatal("register error: ", err)
|
||||||
|
}
|
||||||
|
// pick a free port
|
||||||
|
l, err := net.Listen("tcp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("network error:", err)
|
||||||
|
}
|
||||||
|
log.Println("start rpc server on", l.Addr())
|
||||||
|
addr <- l.Addr().String()
|
||||||
|
krpc.Accept(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
log.SetFlags(0)
|
||||||
|
addr := make(chan string)
|
||||||
|
go startServer(addr)
|
||||||
|
|
||||||
|
cli, _ := krpc.Dial("tcp", <-addr)
|
||||||
|
defer func() { _ = cli.Close() }()
|
||||||
|
|
||||||
|
// send options
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
// send request & receive response
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
args := &Args{Num1: i, Num2: i * i}
|
||||||
|
var reply int
|
||||||
|
if err := cli.Call(context.TODO(), "Foo.Sum", args, &reply); err != nil {
|
||||||
|
log.Fatal("call Foo.Sum error: ", err)
|
||||||
|
}
|
||||||
|
log.Printf("%d + %d = %d\n", args.Num1, args.Num2, reply)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
12
example/service.go
Normal file
12
example/service.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
|
||||||
|
type Args struct {
|
||||||
|
Num1, Num2 int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Foo) Sum(args Args, reply *int) error {
|
||||||
|
*reply = args.Num1 + args.Num2
|
||||||
|
return nil
|
||||||
|
}
|
9
go.mod
9
go.mod
@ -1,12 +1,3 @@
|
|||||||
module krwu.top/krpc.v1
|
module krwu.top/krpc.v1
|
||||||
|
|
||||||
go 1.17
|
go 1.17
|
||||||
|
|
||||||
require github.com/stretchr/testify v1.7.1
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/davecgh/go-spew v1.1.0 // indirect
|
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
|
||||||
github.com/stretchr/objx v0.1.0 // indirect
|
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
|
|
||||||
)
|
|
||||||
|
12
go.sum
12
go.sum
@ -1,12 +0,0 @@
|
|||||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
|
||||||
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
|
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
|
||||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
||||||
|
48
option.go
Normal file
48
option.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package krpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"krwu.top/krpc.v1/codec"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
MagicNumber int
|
||||||
|
CodecType codec.Type
|
||||||
|
ConnectTimeout time.Duration
|
||||||
|
HandleTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultOptions = &Options{
|
||||||
|
MagicNumber: codec.MagicNumber,
|
||||||
|
CodecType: codec.GobType,
|
||||||
|
ConnectTimeout: time.Second * 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
type Option func(*Options)
|
||||||
|
|
||||||
|
func WithCodecType(t codec.Type) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.CodecType = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithConnectTimeout(t time.Duration) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.ConnectTimeout = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHandleTimeout(t time.Duration) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.HandleTimeout = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func apply(o *Options, opts ...Option) Options {
|
||||||
|
newOpts := *o
|
||||||
|
for i := range opts {
|
||||||
|
opts[i](&newOpts)
|
||||||
|
}
|
||||||
|
return newOpts
|
||||||
|
}
|
58
request.go
58
request.go
@ -5,14 +5,16 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"krwu.top/krpc.v1/codec"
|
"krwu.top/krpc.v1/codec"
|
||||||
)
|
)
|
||||||
|
|
||||||
type request struct {
|
type request struct {
|
||||||
h *codec.Header
|
h *codec.Header
|
||||||
argv reflect.Value
|
argv, replyv reflect.Value
|
||||||
replyv reflect.Value
|
mtype *methodType
|
||||||
|
svc *service
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
|
func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
|
||||||
@ -32,16 +34,52 @@ func (s *Server) readRequest(cc codec.Codec) (*request, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req := &request{h: h}
|
req := &request{h: h}
|
||||||
req.argv = reflect.New(reflect.TypeOf(""))
|
req.svc, req.mtype, err = s.findService(h.ServiceMethod)
|
||||||
if err = cc.ReadBody(req.argv.Interface()); err != nil {
|
if err != nil {
|
||||||
fmt.Println("rpc server: read argv err: ", err)
|
return req, err
|
||||||
|
}
|
||||||
|
req.argv = req.mtype.newArgv()
|
||||||
|
req.replyv = req.mtype.newReplyv()
|
||||||
|
argvi := req.argv.Interface()
|
||||||
|
|
||||||
|
if req.argv.Type().Kind() != reflect.Ptr {
|
||||||
|
argvi = req.argv.Addr().Interface()
|
||||||
|
}
|
||||||
|
if err = cc.ReadBody(argvi); err != nil {
|
||||||
|
fmt.Println("rpc: server read body error: ", err)
|
||||||
|
return req, err
|
||||||
}
|
}
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
|
func (s *Server) handleRequest(cc codec.Codec,
|
||||||
|
req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
fmt.Println("rcp server: ", req.h, req.argv.Elem())
|
called := make(chan struct{}, 1)
|
||||||
req.replyv = reflect.ValueOf(fmt.Sprintf("krpc resp %d", req.h.Seq))
|
sent := make(chan struct{}, 1)
|
||||||
s.sendResponse(cc, req.h, req.replyv.Interface(), sending)
|
go func() {
|
||||||
|
err := req.svc.call(req.mtype, req.argv, req.replyv)
|
||||||
|
called <- struct{}{}
|
||||||
|
if err != nil {
|
||||||
|
req.h.Error = err.Error()
|
||||||
|
s.sendResponse(cc, req.h, invalidRequest, sending)
|
||||||
|
sent <- struct{}{}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.sendResponse(cc, req.h, req.replyv.Interface(), sending)
|
||||||
|
sent <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if timeout == 0 {
|
||||||
|
<-called
|
||||||
|
<-sent
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-time.After(timeout):
|
||||||
|
req.h.Error = fmt.Sprintf("prc: server request handle timeout with %s", timeout)
|
||||||
|
s.sendResponse(cc, req.h, invalidRequest, sending)
|
||||||
|
case <-called:
|
||||||
|
<-sent
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
43
server.go
43
server.go
@ -5,13 +5,16 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"krwu.top/krpc.v1/client"
|
|
||||||
"krwu.top/krpc.v1/codec"
|
"krwu.top/krpc.v1/codec"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct{}
|
type Server struct {
|
||||||
|
serviceMap sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
func NewServer() *Server {
|
func NewServer() *Server {
|
||||||
return &Server{}
|
return &Server{}
|
||||||
@ -34,7 +37,7 @@ func (s *Server) ServeConn(conn io.ReadWriteCloser) {
|
|||||||
defer func() {
|
defer func() {
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
}()
|
}()
|
||||||
var opts client.Options
|
var opts Options
|
||||||
if err := json.NewDecoder(conn).Decode(&opts); err != nil {
|
if err := json.NewDecoder(conn).Decode(&opts); err != nil {
|
||||||
fmt.Println("rpc server: options error: ", err)
|
fmt.Println("rpc server: options error: ", err)
|
||||||
return
|
return
|
||||||
@ -68,10 +71,42 @@ func (s *Server) ServeCodec(cc codec.Codec) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go s.handleRequest(cc, req, sending, wg)
|
go s.handleRequest(cc, req, sending, wg, time.Second*3)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
_ = cc.Close()
|
_ = cc.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) Register(rcvr interface{}) error {
|
||||||
|
svc := newService(rcvr)
|
||||||
|
if _, dup := s.serviceMap.LoadOrStore(svc.name, svc); dup {
|
||||||
|
return fmt.Errorf("rpc: service already defined: %s\n", svc.name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
|
||||||
|
dot := strings.LastIndex(serviceMethod, ".")
|
||||||
|
if dot < 0 {
|
||||||
|
err = fmt.Errorf("rpc: service/method request ill-formed: %s", serviceMethod)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
|
||||||
|
svci, ok := s.serviceMap.Load(serviceName)
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("rpc: server can't find service %s", serviceName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
svc = svci.(*service)
|
||||||
|
mtype = svc.method[methodName]
|
||||||
|
if mtype == nil {
|
||||||
|
err = fmt.Errorf("rpc: server can't find method %s", methodName)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func Register(rcvr interface{}) error {
|
||||||
|
return DefaultServer.Register(rcvr)
|
||||||
|
}
|
||||||
|
|
||||||
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
|
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
|
||||||
|
99
service.go
Normal file
99
service.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package krpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go/ast"
|
||||||
|
"log"
|
||||||
|
"reflect"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type methodType struct {
|
||||||
|
method reflect.Method
|
||||||
|
ArgType reflect.Type
|
||||||
|
ReplyType reflect.Type
|
||||||
|
numCalls uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *methodType) NumCalls() uint64 {
|
||||||
|
return atomic.LoadUint64(&m.numCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *methodType) newArgv() reflect.Value {
|
||||||
|
var argv reflect.Value
|
||||||
|
if m.ArgType.Kind() == reflect.Ptr {
|
||||||
|
argv = reflect.New(m.ArgType.Elem())
|
||||||
|
} else {
|
||||||
|
argv = reflect.New(m.ArgType).Elem()
|
||||||
|
}
|
||||||
|
return argv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *methodType) newReplyv() reflect.Value {
|
||||||
|
replyv := reflect.New(m.ReplyType.Elem())
|
||||||
|
switch m.ReplyType.Elem().Kind() {
|
||||||
|
case reflect.Map:
|
||||||
|
replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
|
||||||
|
case reflect.Slice:
|
||||||
|
replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
return replyv
|
||||||
|
}
|
||||||
|
|
||||||
|
type service struct {
|
||||||
|
name string
|
||||||
|
typ reflect.Type
|
||||||
|
rcvr reflect.Value
|
||||||
|
method map[string]*methodType
|
||||||
|
}
|
||||||
|
|
||||||
|
func newService(rcvr interface{}) *service {
|
||||||
|
s := new(service)
|
||||||
|
s.rcvr = reflect.ValueOf(rcvr)
|
||||||
|
s.name = reflect.Indirect(s.rcvr).Type().Name()
|
||||||
|
s.typ = reflect.TypeOf(rcvr)
|
||||||
|
if !ast.IsExported(s.name) {
|
||||||
|
log.Fatalf("rpc server: %s is not a valid service name", s.name)
|
||||||
|
}
|
||||||
|
s.registerMethods()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *service) registerMethods() {
|
||||||
|
s.method = make(map[string]*methodType, s.typ.NumMethod())
|
||||||
|
for i := 0; i < s.typ.NumMethod(); i++ {
|
||||||
|
method := s.typ.Method(i)
|
||||||
|
mType := method.Type
|
||||||
|
if mType.NumIn() != 3 || mType.NumOut() != 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
argType, replyType := mType.In(1), mType.In(2)
|
||||||
|
if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.method[method.Name] = &methodType{
|
||||||
|
method: method,
|
||||||
|
ArgType: argType,
|
||||||
|
ReplyType: replyType,
|
||||||
|
}
|
||||||
|
log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func isExportedOrBuiltinType(t reflect.Type) bool {
|
||||||
|
return ast.IsExported(t.Name()) || t.PkgPath() == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
|
||||||
|
atomic.AddUint64(&m.numCalls, 1)
|
||||||
|
f := m.method.Func
|
||||||
|
returnValues := f.Call([]reflect.Value{
|
||||||
|
s.rcvr, argv, replyv,
|
||||||
|
})
|
||||||
|
if errInter := returnValues[0].Interface(); errInter != nil {
|
||||||
|
return errInter.(error)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
48
service_test.go
Normal file
48
service_test.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package krpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
|
||||||
|
type Args struct{ Num1, Num2 int }
|
||||||
|
|
||||||
|
func (f Foo) Sum(args Args, reply *int) error {
|
||||||
|
*reply = args.Num1 + args.Num2
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Foo) sum(args Args, reply *int) error {
|
||||||
|
*reply = args.Num1 + args.Num2
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func _assert(condition bool, msg string, v ...interface{}) {
|
||||||
|
if !condition {
|
||||||
|
panic(fmt.Sprintf("assertion failed: "+msg+"\n", v...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewService(t *testing.T) {
|
||||||
|
var foo Foo
|
||||||
|
s := newService(&foo)
|
||||||
|
want := 1
|
||||||
|
_assert(len(s.method) == want, "wrong service Method, expect %d, got %d", want, len(s.method))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMethodType_Call(t *testing.T) {
|
||||||
|
var foo Foo
|
||||||
|
s := newService(&foo)
|
||||||
|
mType := s.method["Sum"]
|
||||||
|
|
||||||
|
argv := mType.newArgv()
|
||||||
|
replyv := mType.newReplyv()
|
||||||
|
argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3}))
|
||||||
|
err := s.call(mType, argv, replyv)
|
||||||
|
want := 4
|
||||||
|
_assert(err == nil && *replyv.Interface().(*int) == want &&
|
||||||
|
mType.NumCalls() == 1, "failed to call Foo.Sum")
|
||||||
|
}
|
Reference in New Issue
Block a user