diff --git a/client/option.go b/client/option.go new file mode 100644 index 0000000..95f1465 --- /dev/null +++ b/client/option.go @@ -0,0 +1,21 @@ +package client + +import "krwu.top/krpc.v1/codec" + +type Options struct { + MagicNumber int + CodecType codec.Type +} + +var DefaultOptions = &Options{ + MagicNumber: codec.MagicNumber, + CodecType: codec.GobType, +} + +type Option func(*Options) + +func WithCodecType(t codec.Type) Option { + return func(o *Options) { + o.CodecType = t + } +} diff --git a/codec/codec.go b/codec/codec.go new file mode 100644 index 0000000..64b5d71 --- /dev/null +++ b/codec/codec.go @@ -0,0 +1,34 @@ +package codec + +import "io" + +type Header struct { + ServiceMethod string + Seq uint64 + Error string +} + +type Codec interface { + io.Closer + ReadHeader(*Header) error + ReadBody(interface{}) error + Write(*Header, interface{}) error +} + +type NewCodecFunc func(closer io.ReadWriteCloser) Codec + +type Type string + +const ( + MagicNumber = 0x12E7165 + GobType Type = "application/gob" + JsonType Type = "application/json" +) + +var NewCodecFuncMap map[Type]NewCodecFunc + +func init() { + NewCodecFuncMap = make(map[Type]NewCodecFunc) + NewCodecFuncMap[GobType] = NewGobCodec + NewCodecFuncMap[JsonType] = NewJsonCodec +} diff --git a/codec/gob.go b/codec/gob.go new file mode 100644 index 0000000..3400726 --- /dev/null +++ b/codec/gob.go @@ -0,0 +1,56 @@ +package codec + +import ( + "bufio" + "encoding/gob" + "fmt" + "io" +) + +type GobCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *gob.Decoder + enc *gob.Encoder +} + +var _ Codec = (*GobCodec)(nil) + +func NewGobCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &GobCodec{ + conn: conn, + buf: buf, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + } +} + +func (c *GobCodec) ReadHeader(h *Header) error { + return c.dec.Decode(h) +} + +func (c *GobCodec) ReadBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *GobCodec) Write(h *Header, body interface{}) (err error) { + defer func() { + _ = c.buf.Flush() + if err != nil { + _ = c.Close() + } + }() + + if err = c.enc.Encode(h); err != nil { + return fmt.Errorf("rpc codec: gob error encoding header: %v", err) + } + if err = c.enc.Encode(body); err != nil { + return fmt.Errorf("rpc codec: gob error encoding body: %v", err) + } + return nil +} + +func (c *GobCodec) Close() error { + return c.conn.Close() +} diff --git a/codec/json.go b/codec/json.go new file mode 100644 index 0000000..92c8177 --- /dev/null +++ b/codec/json.go @@ -0,0 +1,55 @@ +package codec + +import ( + "bufio" + "encoding/json" + "fmt" + "io" +) + +type JsonCodec struct { + conn io.ReadWriteCloser + buf *bufio.Writer + dec *json.Decoder + enc *json.Encoder +} + +func (j *JsonCodec) Close() error { + return j.conn.Close() +} + +func (j *JsonCodec) ReadHeader(h *Header) error { + return j.dec.Decode(h) +} + +func (j *JsonCodec) ReadBody(body interface{}) error { + return j.dec.Decode(body) +} + +func (j *JsonCodec) Write(header *Header, body interface{}) (err error) { + defer func() { + _ = j.buf.Flush() + if err != nil { + _ = j.Close() + } + }() + if err = j.enc.Encode(header); err != nil { + return fmt.Errorf("rpc codec: json error encoding header: %v", err) + } + if err = j.enc.Encode(body); err != nil { + return fmt.Errorf("rpc codec: json error encoding body: %v", err) + } + return nil +} + +var _ Codec = (*JsonCodec)(nil) + +func NewJsonCodec(conn io.ReadWriteCloser) Codec { + buf := bufio.NewWriter(conn) + return &JsonCodec{ + conn: conn, + buf: buf, + dec: json.NewDecoder(conn), + enc: json.NewEncoder(buf), + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c931dba --- /dev/null +++ b/go.mod @@ -0,0 +1,12 @@ +module krwu.top/krpc.v1 + +go 1.17 + +require github.com/stretchr/testify v1.7.1 + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.1.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5786f2b --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/request.go b/request.go new file mode 100644 index 0000000..ffd1ebc --- /dev/null +++ b/request.go @@ -0,0 +1,47 @@ +package krpc + +import ( + "fmt" + "io" + "reflect" + "sync" + + "krwu.top/krpc.v1/codec" +) + +type request struct { + h *codec.Header + argv reflect.Value + replyv reflect.Value +} + +func (s *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) { + var h codec.Header + if err := cc.ReadHeader(&h); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + fmt.Println("rpc server: read header error:", err) + } + return nil, err + } + return &h, nil +} +func (s *Server) readRequest(cc codec.Codec) (*request, error) { + h, err := s.readRequestHeader(cc) + if err != nil { + fmt.Println("rpc server: read request error: ", err) + return nil, err + } + req := &request{h: h} + req.argv = reflect.New(reflect.TypeOf("")) + if err = cc.ReadBody(req.argv.Interface()); err != nil { + fmt.Println("rpc server: read argv err: ", err) + } + return req, nil +} + +func (s *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) { + defer wg.Done() + fmt.Println("rcp server: ", req.h, req.argv.Elem()) + req.replyv = reflect.ValueOf(fmt.Sprintf("krpc resp %d", req.h.Seq)) + s.sendResponse(cc, req.h, req.replyv.Interface(), sending) +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..9d0e8cc --- /dev/null +++ b/response.go @@ -0,0 +1,18 @@ +package krpc + +import ( + "fmt" + "sync" + + "krwu.top/krpc.v1/codec" +) + +func (s *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) { + sending.Lock() + defer sending.Unlock() + + if err := cc.Write(h, body); err != nil { + fmt.Println("rpc server: write response error: ", err) + } + +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..8382c26 --- /dev/null +++ b/server.go @@ -0,0 +1,77 @@ +package krpc + +import ( + "encoding/json" + "fmt" + "io" + "net" + "sync" + + "krwu.top/krpc.v1/client" + "krwu.top/krpc.v1/codec" +) + +type Server struct{} + +func NewServer() *Server { + return &Server{} +} + +var DefaultServer = NewServer() + +func (s *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + fmt.Println("rpc server: accept error: ", err) + return + } + go s.ServeConn(conn) + } +} + +func (s *Server) ServeConn(conn io.ReadWriteCloser) { + defer func() { + _ = conn.Close() + }() + var opts client.Options + if err := json.NewDecoder(conn).Decode(&opts); err != nil { + fmt.Println("rpc server: options error: ", err) + return + } + if opts.MagicNumber != codec.MagicNumber { + fmt.Printf("rpc server: invalid magic number %x\n", opts.MagicNumber) + return + } + f := codec.NewCodecFuncMap[opts.CodecType] + if f == nil { + fmt.Printf("rpc server: invalid codec type %s\n", opts.CodecType) + return + } + s.ServeCodec(f(conn)) +} + +var invalidRequest = struct{}{} + +func (s *Server) ServeCodec(cc codec.Codec) { + sending := new(sync.Mutex) + wg := new(sync.WaitGroup) + for { + req, err := s.readRequest(cc) + if err != nil { + fmt.Println("rpc server invalid request: ", err) + if req == nil { + break + } + req.h.Error = err.Error() + s.sendResponse(cc, req.h, invalidRequest, sending) + continue + } + wg.Add(1) + go s.handleRequest(cc, req, sending, wg) + } + wg.Wait() + _ = cc.Close() +} + +func Accept(lis net.Listener) { DefaultServer.Accept(lis) }