Files
krpc/client.go

244 lines
5.1 KiB
Go

package krpc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"krwu.top/krpc.v1/codec"
)
// Call represents one RPC invocation: the request that was issued and,
// once it completes, its outcome.
type Call struct {
Seq uint64 // sequence number assigned by registerCall; 0 until the call is registered
ServiceMethod string // method name written into the request header
Args interface{} // request arguments written after the header
Reply interface{} // destination the response body is decoded into
Error error // set on registration, write, server, body-read or connection failure
Done chan *Call // receives this Call when it completes
}
// done signals completion by sending call on call.Done.
// NOTE(review): this send blocks while call.Done is full. It is called from
// the receive loop (and from terminateCalls while holding client locks), so a
// consumer that never drains a shared Done channel can stall the client.
func (call *Call) done() {
call.Done <- call
}
// Client is an RPC client bound to a single codec/connection. Request writes
// are serialised by sending; call bookkeeping and lifecycle flags are guarded
// by mu.
type Client struct {
cc codec.Codec // codec used to write requests and read responses
opt *Options // options the client was created with (sent to the server by NewClient)
sending sync.Mutex // serialises request writes; also guards header
header codec.Header // request header reused for every write, only touched while sending is held
mu sync.Mutex // guards seq, pending, closing and shutdown
seq uint64 // next sequence number to assign; starts at 1
pending map[uint64]*Call // registered calls not yet completed, keyed by Seq
closing bool // set by Close
shutdown bool // set by terminateCalls when the receive loop stops on an error
}
type (
	// clientResult carries the outcome of a client handshake from the
	// goroutine that performs it back to dialTimeout.
	clientResult struct {
		client *Client
		err    error
	}

	// newClientFunc builds a client on an already-established connection;
	// NewClient is the implementation used by Dial.
	newClientFunc func(conn net.Conn, opts ...Option) (client *Client, err error)
)

var (
	// Compile-time assertion that *Client satisfies io.Closer.
	_ io.Closer = (*Client)(nil)

	// ErrShutdown is returned when a call is attempted after the client has
	// been closed or has shut down because its receive loop hit an error,
	// and by a second call to Close.
	ErrShutdown = errors.New("connection is shutdown")
)
// NewClient creates a client on an already-established connection.
//
// The options in opts are applied to a private copy of DefaultOptions, so the
// package-level defaults are never modified and concurrent callers do not
// race on them. The resulting options are sent to the server as JSON, the
// codec selected by CodecType is wrapped around conn, and the receive loop is
// started.
//
// On any error conn is closed and a nil client is returned.
func NewClient(conn net.Conn, opts ...Option) (*Client, error) {
	// Work on a copy: applying opts to DefaultOptions itself would leak this
	// client's settings into every client created afterwards.
	opt := *DefaultOptions
	for _, o := range opts {
		o(&opt)
	}
	f := codec.NewCodecFuncMap[opt.CodecType]
	if f == nil {
		err := fmt.Errorf("invalid codec type %s", opt.CodecType)
		log.Println("rpc client: codec error: ", err)
		// Close conn on every failure path, matching the options-error path.
		_ = conn.Close()
		return nil, err
	}
	if err := json.NewEncoder(conn).Encode(&opt); err != nil {
		log.Println("rpc client: options error:", err)
		_ = conn.Close()
		return nil, err
	}
	return newClientCodec(f(conn), &opt), nil
}
// Dial connects to an RPC server at address on the named network and performs
// the client handshake with NewClient, honouring the ConnectTimeout option.
func Dial(network, address string, opts ...Option) (client *Client, err error) {
	client, err = dialTimeout(NewClient, network, address, opts...)
	return client, err
}
// newClientCodec builds a Client around a ready codec and starts its receive
// loop in a new goroutine. Sequence numbers start at 1.
func newClientCodec(cc codec.Codec, opts *Options) *Client {
	c := new(Client)
	c.cc = cc
	c.opt = opts
	c.seq = 1
	c.pending = make(map[uint64]*Call)
	go c.receive()
	return c
}
// Close marks the client as closing and closes its codec, returning the
// codec's Close error. A second call returns ErrShutdown.
func (client *Client) Close() error {
	client.mu.Lock()
	defer client.mu.Unlock()
	if !client.closing {
		client.closing = true
		return client.cc.Close()
	}
	return ErrShutdown
}
// IsAvailable reports whether new calls can still be issued: the client has
// been neither closed nor shut down by its receive loop.
func (client *Client) IsAvailable() bool {
	client.mu.Lock()
	defer client.mu.Unlock()
	unusable := client.shutdown || client.closing
	return !unusable
}
// registerCall assigns the next sequence number to call, records it as
// pending and returns that number. It fails with ErrShutdown once the client
// is closing or shut down.
func (client *Client) registerCall(call *Call) (uint64, error) {
	client.mu.Lock()
	defer client.mu.Unlock()
	if client.shutdown || client.closing {
		return 0, ErrShutdown
	}
	seq := client.seq
	call.Seq = seq
	client.pending[seq] = call
	atomic.StoreUint64(&client.seq, seq+1)
	return seq, nil
}
// removeCall detaches and returns the pending call with sequence number seq,
// or nil if no such call is pending.
func (client *Client) removeCall(seq uint64) *Call {
	client.mu.Lock()
	defer client.mu.Unlock()
	call, ok := client.pending[seq]
	if ok {
		delete(client.pending, seq)
	}
	return call
}
// terminateCalls marks the client as shut down and completes every pending
// call with err. The sending lock is taken first so no request write is in
// progress while the pending calls are failed.
func (client *Client) terminateCalls(err error) {
	client.sending.Lock()
	defer client.sending.Unlock()
	client.mu.Lock()
	defer client.mu.Unlock()

	client.shutdown = true
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
}
// receive is the client's read loop, started by newClientCodec in its own
// goroutine.
//
// For each response it reads the header, removes the pending call with the
// matching sequence number and decodes the body into call.Reply (or records
// the server's error). Responses whose call is no longer pending, e.g. one
// abandoned by Call after its context finished, still have their body read
// so the stream stays aligned. The first read error ends the loop and fails
// all remaining pending calls via terminateCalls.
func (client *Client) receive() {
	var err error
	for err == nil {
		var h codec.Header
		if err = client.cc.ReadHeader(&h); err != nil {
			break
		}
		call := client.removeCall(h.Seq)
		switch {
		case call == nil:
			// Nobody is waiting; consume the body anyway.
			// NOTE(review): assumes ReadBody(nil) discards the payload.
			err = client.cc.ReadBody(nil)
		case h.Error != "":
			// The server's message is plain text, not a format string:
			// errors.New keeps any '%' in it intact (fmt.Errorf would not).
			call.Error = errors.New(h.Error)
			err = client.cc.ReadBody(nil)
			call.done()
		default:
			err = client.cc.ReadBody(call.Reply)
			if err != nil {
				// %w keeps the codec error inspectable with errors.Is/As.
				call.Error = fmt.Errorf("read body %w", err)
			}
			call.done()
		}
	}
	client.terminateCalls(err)
}
// send registers call and writes its request. The whole operation runs under
// client.sending, which serialises writes and guards the reused
// client.header. Registration and write failures are reported on the call
// itself rather than returned.
func (client *Client) send(call *Call) {
	client.sending.Lock()
	defer client.sending.Unlock()

	seq, err := client.registerCall(call)
	if err != nil {
		call.Error = err
		call.done()
		return
	}

	h := &client.header
	h.ServiceMethod, h.Seq, h.Error = call.ServiceMethod, seq, ""
	if werr := client.cc.Write(h, call.Args); werr != nil {
		// Only fail the call if nobody else has already taken it off the
		// pending set.
		if pending := client.removeCall(seq); pending != nil {
			pending.Error = werr
			pending.done()
		}
	}
}
// Go starts an asynchronous invocation of serviceMethod and returns the Call
// immediately; the request is sent from a new goroutine and the Call is
// delivered on done when it completes. A nil done is replaced by a channel
// with capacity 10; an unbuffered done causes a panic.
func (client *Client) Go(serviceMethod string,
	args, reply interface{}, done chan *Call) *Call {
	switch {
	case done == nil:
		done = make(chan *Call, 10)
	case cap(done) == 0:
		log.Panic("rpc client: done channel is unbuffered")
	}
	call := new(Call)
	call.ServiceMethod = serviceMethod
	call.Args = args
	call.Reply = reply
	call.Done = done
	go client.send(call)
	return call
}
// Call invokes serviceMethod and waits until it completes or ctx is done,
// whichever happens first. It returns the call's error, or an error wrapping
// ctx.Err() (so errors.Is(err, context.DeadlineExceeded) works) when the
// context finishes first.
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
	call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
	select {
	case <-ctx.Done():
		// call.Seq is written by registerCall, under client.mu, from the
		// goroutine started in Go, so it must be read under the same lock.
		// Before registration Seq is 0, which is never assigned (numbering
		// starts at 1), so the delete is then a no-op. In that case the request is
		// still sent later, and its completion lands in the buffered done channel.
		client.mu.Lock()
		delete(client.pending, call.Seq)
		client.mu.Unlock()
		return fmt.Errorf("rpc: client call failed: %w", ctx.Err())
	case c := <-call.Done:
		return c.Error
	}
}
// dialTimeout dials address and builds a client on the connection with f.
// Both the dial and the client handshake are bounded by the ConnectTimeout
// option; a zero ConnectTimeout means the handshake is waited for without a
// limit. The connection is closed on every error path, including timeout.
func dialTimeout(f newClientFunc, network, address string,
	opts ...Option) (client *Client, err error) {
	o := apply(DefaultOptions, opts...)
	conn, err := net.DialTimeout(network, address, o.ConnectTimeout)
	if err != nil {
		return nil, err
	}
	// Closing conn on failure also unblocks a handshake still running in the
	// goroutine below (net.Conn.Close unblocks pending Read/Write calls).
	defer func() {
		if err != nil {
			_ = conn.Close()
		}
	}()
	// Buffered so the handshake goroutine can always deliver its result and
	// exit, even after a timeout when nobody receives from ch any more.
	ch := make(chan clientResult, 1)
	go func() {
		c, cerr := f(conn, opts...)
		ch <- clientResult{client: c, err: cerr}
	}()
	if o.ConnectTimeout == 0 {
		result := <-ch
		return result.client, result.err
	}
	select {
	case <-time.After(o.ConnectTimeout):
		return nil, fmt.Errorf("rpc: client connection timeout after %s", o.ConnectTimeout)
	case result := <-ch:
		return result.client, result.err
	}
}