244 lines
5.1 KiB
Go
244 lines
5.1 KiB
Go
package krpc
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"krwu.top/krpc.v1/codec"
|
|
)
|
|
|
|
// Call describes one RPC invocation, from the moment it is created until its
// completion is signalled on Done.
type Call struct {
	// Seq is the sequence number given to the call when it is registered
	// with a client.
	Seq uint64

	// ServiceMethod is the method name placed in the request header.
	ServiceMethod string

	// Args is the request payload handed to the codec; Reply is the value
	// the response body is read into.
	Args, Reply interface{}

	// Error holds the failure, if any, once the call has finished.
	Error error

	// Done receives this Call when it has finished.
	Done chan *Call
}
|
|
|
|
func (call *Call) done() {
|
|
call.Done <- call
|
|
}
|
|
|
|
type Client struct {
|
|
cc codec.Codec
|
|
opt *Options
|
|
sending sync.Mutex
|
|
header codec.Header
|
|
mu sync.Mutex
|
|
seq uint64
|
|
pending map[uint64]*Call
|
|
closing bool
|
|
shutdown bool
|
|
}
|
|
|
|
type clientResult struct {
|
|
client *Client
|
|
err error
|
|
}
|
|
|
|
type newClientFunc func(conn net.Conn, opts ...Option) (client *Client, err error)
|
|
|
|
var _ io.Closer = (*Client)(nil)
|
|
|
|
var ErrShutdown = errors.New("connection is shutdown")
|
|
|
|
func NewClient(conn net.Conn, opts ...Option) (*Client, error) {
|
|
for i := range opts {
|
|
opts[i](DefaultOptions)
|
|
}
|
|
f := codec.NewCodecFuncMap[DefaultOptions.CodecType]
|
|
if f == nil {
|
|
err := fmt.Errorf("invalid codec type %s", DefaultOptions.CodecType)
|
|
log.Println("rpc client: codec error: ", err)
|
|
return nil, err
|
|
}
|
|
|
|
if err := json.NewEncoder(conn).Encode(DefaultOptions); err != nil {
|
|
log.Println("rpc client: options error:", err)
|
|
_ = conn.Close()
|
|
return nil, err
|
|
}
|
|
return newClientCodec(f(conn), DefaultOptions), nil
|
|
}
|
|
|
|
func Dial(network, address string, opts ...Option) (client *Client, err error) {
|
|
return dialTimeout(NewClient, network, address, opts...)
|
|
}
|
|
|
|
func newClientCodec(cc codec.Codec, opts *Options) *Client {
|
|
client := &Client{
|
|
seq: 1,
|
|
cc: cc,
|
|
opt: opts,
|
|
pending: make(map[uint64]*Call),
|
|
}
|
|
go client.receive()
|
|
return client
|
|
}
|
|
|
|
func (client *Client) Close() error {
|
|
client.mu.Lock()
|
|
defer client.mu.Unlock()
|
|
if client.closing {
|
|
return ErrShutdown
|
|
}
|
|
client.closing = true
|
|
return client.cc.Close()
|
|
}
|
|
|
|
func (client *Client) IsAvailable() bool {
|
|
client.mu.Lock()
|
|
defer client.mu.Unlock()
|
|
return !client.shutdown && !client.closing
|
|
}
|
|
|
|
func (client *Client) registerCall(call *Call) (uint64, error) {
|
|
client.mu.Lock()
|
|
defer client.mu.Unlock()
|
|
if client.closing || client.shutdown {
|
|
return 0, ErrShutdown
|
|
}
|
|
call.Seq = client.seq
|
|
client.pending[call.Seq] = call
|
|
atomic.StoreUint64(&client.seq, client.seq+1)
|
|
return call.Seq, nil
|
|
}
|
|
|
|
func (client *Client) removeCall(seq uint64) *Call {
|
|
client.mu.Lock()
|
|
defer client.mu.Unlock()
|
|
call := client.pending[seq]
|
|
delete(client.pending, seq)
|
|
return call
|
|
}
|
|
|
|
func (client *Client) terminateCalls(err error) {
|
|
client.sending.Lock()
|
|
defer client.sending.Unlock()
|
|
client.mu.Lock()
|
|
defer client.mu.Unlock()
|
|
client.shutdown = true
|
|
for i := range client.pending {
|
|
client.pending[i].Error = err
|
|
client.pending[i].done()
|
|
}
|
|
}
|
|
|
|
func (client *Client) receive() {
|
|
var err error
|
|
for err == nil {
|
|
var h codec.Header
|
|
if err = client.cc.ReadHeader(&h); err != nil {
|
|
break
|
|
}
|
|
call := client.removeCall(h.Seq)
|
|
switch {
|
|
case call == nil:
|
|
err = client.cc.ReadBody(nil)
|
|
case h.Error != "":
|
|
call.Error = fmt.Errorf(h.Error)
|
|
err = client.cc.ReadBody(nil)
|
|
call.done()
|
|
default:
|
|
err = client.cc.ReadBody(call.Reply)
|
|
if err != nil {
|
|
call.Error = fmt.Errorf("read body %s", err.Error())
|
|
}
|
|
call.done()
|
|
}
|
|
}
|
|
client.terminateCalls(err)
|
|
}
|
|
|
|
func (client *Client) send(call *Call) {
|
|
client.sending.Lock()
|
|
defer client.sending.Unlock()
|
|
|
|
seq, err := client.registerCall(call)
|
|
if err != nil {
|
|
call.Error = err
|
|
call.done()
|
|
return
|
|
}
|
|
|
|
client.header.ServiceMethod = call.ServiceMethod
|
|
client.header.Seq = seq
|
|
client.header.Error = ""
|
|
|
|
if err := client.cc.Write(&client.header, call.Args); err != nil {
|
|
call := client.removeCall(seq)
|
|
if call != nil {
|
|
call.Error = err
|
|
call.done()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (client *Client) Go(serviceMethod string,
|
|
args, reply interface{}, done chan *Call) *Call {
|
|
if done == nil {
|
|
done = make(chan *Call, 10)
|
|
} else if cap(done) == 0 {
|
|
log.Panic("rpc client: done channel is unbuffered")
|
|
}
|
|
call := &Call{
|
|
ServiceMethod: serviceMethod,
|
|
Args: args,
|
|
Reply: reply,
|
|
Done: done,
|
|
}
|
|
go client.send(call)
|
|
return call
|
|
}
|
|
|
|
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
|
|
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
|
|
select {
|
|
case <-ctx.Done():
|
|
client.removeCall(call.Seq)
|
|
return fmt.Errorf("rpc: client call failed: %v", ctx.Err())
|
|
case c := <-call.Done:
|
|
return c.Error
|
|
}
|
|
}
|
|
|
|
func dialTimeout(f newClientFunc, network, address string,
|
|
opts ...Option) (client *Client, err error) {
|
|
|
|
o := apply(DefaultOptions, opts...)
|
|
conn, err := net.DialTimeout(network, address, o.ConnectTimeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
ch := make(chan clientResult)
|
|
go func() {
|
|
client, err := f(conn, opts...)
|
|
ch <- clientResult{client: client, err: err}
|
|
}()
|
|
if o.ConnectTimeout == 0 {
|
|
result := <-ch
|
|
return result.client, result.err
|
|
}
|
|
select {
|
|
case <-time.After(o.ConnectTimeout):
|
|
return nil, fmt.Errorf("rpc: client connection timeout after %s", o.ConnectTimeout)
|
|
case result := <-ch:
|
|
return result.client, result.err
|
|
}
|
|
}
|