package krpc

import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"krwu.top/krpc.v1/codec"
)

// Call represents a single RPC invocation created by Client.Go.
//
// Seq is assigned asynchronously by the sending goroutine; read it only after
// the call has been delivered on Done.
type Call struct {
	Seq           uint64      // request sequence number assigned by registerCall
	ServiceMethod string      // method name written into the request header
	Args          interface{} // request arguments written by the codec
	Reply         interface{} // destination the response body is decoded into
	Error         error       // non-nil when the call failed
	Done          chan *Call  // receives the call itself on completion
}

// done reports completion by sending the call on its Done channel. Go rejects
// unbuffered channels, so this only blocks if the caller's buffer is full.
func (call *Call) done() {
	call.Done <- call
}

// Client is an RPC client bound to one codec connection. Calls may be issued
// from multiple goroutines; writes are serialized by sending and bookkeeping
// is guarded by mu.
type Client struct {
	cc       codec.Codec
	opt      *Options
	sending  sync.Mutex   // serializes request writes on cc
	header   codec.Header // reused request header; only touched while holding sending
	mu       sync.Mutex   // guards seq, pending, closing and shutdown
	seq      uint64
	pending  map[uint64]*Call
	closing  bool // set by Close
	shutdown bool // set by terminateCalls once the receive loop stops
}

// clientResult carries the outcome of a client constructor run by dialTimeout.
type clientResult struct {
	client *Client
	err    error
}

// newClientFunc is the constructor signature shared by NewClient and
// NewHTTPClient so dialTimeout can wrap either one.
type newClientFunc func(conn net.Conn, opts ...Option) (client *Client, err error)

var _ io.Closer = (*Client)(nil)

// ErrShutdown is returned for calls on a client that is closing or has shut
// down, and by a repeated Close.
var ErrShutdown = errors.New("connection is shutdown")

// NewClient performs the client side of the handshake on conn: it writes the
// effective Options as JSON and then wraps conn in the selected codec.
//
// opts are applied to a private copy of DefaultOptions, so the package-level
// defaults are never modified. On any error conn is closed.
func NewClient(conn net.Conn, opts ...Option) (*Client, error) {
	// Copy (shallow) the defaults: applying options to the shared pointer
	// would leak one client's settings into every client created afterwards.
	opt := *DefaultOptions
	for _, apply := range opts {
		apply(&opt)
	}
	newCodec := codec.NewCodecFuncMap[opt.CodecType]
	if newCodec == nil {
		err := fmt.Errorf("invalid codec type %s", opt.CodecType)
		log.Println("rpc client: codec error: ", err)
		_ = conn.Close()
		return nil, err
	}
	if err := json.NewEncoder(conn).Encode(&opt); err != nil {
		log.Println("rpc client: options error:", err)
		_ = conn.Close()
		return nil, err
	}
	return newClientCodec(newCodec(conn), &opt), nil
}

// Dial connects to an RPC server at address on the given network. The
// connection attempt and the handshake are both bounded by the
// ConnectTimeout option (zero means no limit).
func Dial(network, address string, opts ...Option) (client *Client, err error) {
	return dialTimeout(NewClient, network, address, opts...)
}

// newClientCodec builds a Client around an already-negotiated codec and
// starts its receive loop in a background goroutine.
func newClientCodec(cc codec.Codec, opts *Options) *Client {
	client := &Client{
		seq:     1, // 0 is never used, so a zero Call.Seq means "not registered"
		cc:      cc,
		opt:     opts,
		pending: make(map[uint64]*Call),
	}
	go client.receive()
	return client
}

// Close closes the underlying codec. Calling Close again returns ErrShutdown.
func (client *Client) Close() error {
	client.mu.Lock()
	defer client.mu.Unlock()
	if client.closing {
		return ErrShutdown
	}
	client.closing = true
	return client.cc.Close()
}

// IsAvailable reports whether the client has neither been closed by the user
// nor shut down after its receive loop stopped.
func (client *Client) IsAvailable() bool {
	client.mu.Lock()
	defer client.mu.Unlock()
	return !client.shutdown && !client.closing
}

// registerCall assigns the next sequence number to call and records it as
// pending. It fails with ErrShutdown once the client is closing or shut down.
func (client *Client) registerCall(call *Call) (uint64, error) {
	client.mu.Lock()
	defer client.mu.Unlock()
	if client.closing || client.shutdown {
		return 0, ErrShutdown
	}
	call.Seq = client.seq
	client.pending[call.Seq] = call
	// The atomic store is redundant while holding mu, but harmless.
	atomic.StoreUint64(&client.seq, client.seq+1)
	return call.Seq, nil
}

// removeCall deletes and returns the pending call with the given sequence
// number, or nil if none is pending.
func (client *Client) removeCall(seq uint64) *Call {
	client.mu.Lock()
	defer client.mu.Unlock()
	call := client.pending[seq]
	delete(client.pending, seq)
	return call
}

// terminateCalls marks the client as shut down and fails every pending call
// with err. It locks sending before mu, the same order send uses, so it
// cannot deadlock against a concurrent send.
func (client *Client) terminateCalls(err error) {
	client.sending.Lock()
	defer client.sending.Unlock()
	client.mu.Lock()
	defer client.mu.Unlock()
	client.shutdown = true
	for i := range client.pending {
		client.pending[i].Error = err
		client.pending[i].done()
	}
}

// receive reads responses until the codec returns an error, completing the
// matching pending call for each one, then fails all remaining calls.
func (client *Client) receive() {
	var err error
	for err == nil {
		var h codec.Header
		if err = client.cc.ReadHeader(&h); err != nil {
			break
		}
		call := client.removeCall(h.Seq)
		switch {
		case call == nil:
			// No pending call (it was canceled by Call or its write failed in
			// send). ReadBody(nil) is used to consume the body — presumably
			// discarding it; confirm in the codec implementation.
			err = client.cc.ReadBody(nil)
		case h.Error != "":
			// errors.New, not fmt.Errorf: the server's message is not a
			// format string and may contain '%'.
			call.Error = errors.New(h.Error)
			err = client.cc.ReadBody(nil)
			call.done()
		default:
			err = client.cc.ReadBody(call.Reply)
			if err != nil {
				call.Error = fmt.Errorf("read body %w", err)
			}
			call.done()
		}
	}
	client.terminateCalls(err)
}

// send registers call and writes its request. Holding sending serializes
// writes on cc and protects the shared header. A write failure is reported on
// the call unless it was already removed (by receive or a canceled Call).
func (client *Client) send(call *Call) {
	client.sending.Lock()
	defer client.sending.Unlock()

	seq, err := client.registerCall(call)
	if err != nil {
		call.Error = err
		call.done()
		return
	}

	client.header.ServiceMethod = call.ServiceMethod
	client.header.Seq = seq
	client.header.Error = ""

	if err := client.cc.Write(&client.header, call.Args); err != nil {
		call := client.removeCall(seq)
		if call != nil {
			call.Error = err
			call.done()
		}
	}
}

// Go invokes serviceMethod asynchronously and returns the Call immediately;
// completion is signaled on done. If done is nil, a channel with a buffer of
// 10 is allocated. An unbuffered done channel causes a panic.
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
	if done == nil {
		done = make(chan *Call, 10)
	} else if cap(done) == 0 {
		log.Panic("rpc client: done channel is unbuffered")
	}
	call := &Call{
		ServiceMethod: serviceMethod,
		Args:          args,
		Reply:         reply,
		Done:          done,
	}
	go client.send(call)
	return call
}

// Call invokes serviceMethod and waits for it to complete or for ctx to be
// done. On cancellation the call is dropped from the pending set so a late
// reply is discarded, and the returned error wraps ctx.Err().
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
	call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
	select {
	case <-ctx.Done():
		// call.Seq is written by the send goroutine while holding mu, so it
		// must be read under mu too. If send has not registered the call yet,
		// Seq is still 0 (never a valid key) and this is a no-op.
		// NOTE(review): in that case the entry is registered afterwards and
		// stays pending until its reply arrives or the client shuts down.
		client.mu.Lock()
		delete(client.pending, call.Seq)
		client.mu.Unlock()
		return fmt.Errorf("rpc: client call failed: %w", ctx.Err())
	case c := <-call.Done:
		return c.Error
	}
}

// dialTimeout dials address and runs the client constructor f on the new
// connection. Both steps are bounded by the ConnectTimeout option (zero means
// wait indefinitely). The connection is closed if anything fails.
func dialTimeout(f newClientFunc, network, address string, opts ...Option) (client *Client, err error) {
	o := apply(DefaultOptions, opts...)
	conn, err := net.DialTimeout(network, address, o.ConnectTimeout)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			_ = conn.Close()
		}
	}()

	// Buffered so the handshake goroutine can always deliver its result and
	// exit, even after the timeout branch below has stopped listening.
	ch := make(chan clientResult, 1)
	go func() {
		c, err := f(conn, opts...)
		ch <- clientResult{client: c, err: err}
	}()

	if o.ConnectTimeout == 0 {
		result := <-ch
		return result.client, result.err
	}
	select {
	case <-time.After(o.ConnectTimeout):
		return nil, fmt.Errorf("rpc: client connection timeout after %s", o.ConnectTimeout)
	case result := <-ch:
		return result.client, result.err
	}
}

// NewHTTPClient sends an HTTP CONNECT for defaultRPCPath on conn and, if the
// server replies with the expected status, performs the NewClient handshake
// over the tunneled connection. On any error conn is closed.
func NewHTTPClient(conn net.Conn, opts ...Option) (*Client, error) {
	if _, err := fmt.Fprintf(conn, "CONNECT %s HTTP/1.0\n\n", defaultRPCPath); err != nil {
		_ = conn.Close()
		return nil, err
	}
	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: http.MethodConnect})
	if err != nil {
		_ = conn.Close()
		return nil, err
	}
	if resp.Status != connected {
		// Previously this path returned (nil, nil), handing callers a nil
		// client with no error.
		_ = conn.Close()
		return nil, fmt.Errorf("rpc client: unexpected HTTP response: %s", resp.Status)
	}
	return NewClient(conn, opts...)
}

// DialHTTP connects to an RPC server over an HTTP CONNECT tunnel, with the
// same ConnectTimeout handling as Dial.
func DialHTTP(network, address string, opts ...Option) (*Client, error) {
	return dialTimeout(NewHTTPClient, network, address, opts...)
}

// XDial connects using an address of the form protocol://addr. The "http"
// protocol dials DialHTTP over tcp; any other protocol is passed to Dial as
// the network name.
func XDial(rpcAddr string, opts ...Option) (*Client, error) {
	parts := strings.Split(rpcAddr, "://")
	if len(parts) != 2 {
		return nil, fmt.Errorf("rpc: client wrong format '%s', expect protocol://addr", rpcAddr)
	}
	protocol, addr := parts[0], parts[1]
	switch protocol {
	case "http":
		return DialHTTP("tcp", addr, opts...)
	default:
		return Dial(protocol, addr, opts...)
	}
}