diff --git a/diam/message.go b/diam/message.go index 73c2d72..176db0f 100644 --- a/diam/message.go +++ b/diam/message.go @@ -64,10 +64,15 @@ func readerBufferSlice(buf *bytes.Buffer, l int) []byte { // ReadMessage reads a binary stream from the reader and uses the given // dictionary to parse it. -func ReadMessage(reader io.Reader, dictionary *dict.Parser) (*Message, error) { +func ReadMessage(reader io.Reader, dictionary *dict.Parser, hook func(*Message) error) (*Message, error) { buf := newReaderBuffer() defer putReaderBuffer(buf) m := &Message{dictionary: dictionary} + + if err := hook(m); err != nil { + return nil, err + } + cmd, stream, err := m.readHeader(reader, buf) if err != nil { return nil, err @@ -98,6 +103,7 @@ func (m *Message) readHeader(r io.Reader, buf *bytes.Buffer) (cmd *dict.Command, if err != nil { return nil, stream, err } + m.Header, err = DecodeHeader(b) if err != nil { return nil, stream, err diff --git a/diam/server.go b/diam/server.go index 44514bc..0e44d9c 100644 --- a/diam/server.go +++ b/diam/server.go @@ -162,12 +162,21 @@ func (c *conn) readMessage() (m *Message, err error) { if c.server.ReadTimeout > 0 { c.rwc.SetReadDeadline(time.Now().Add(c.server.ReadTimeout)) } + + wrappedMethod := func(m *Message) error { + if c.server.ReadMessageHook == nil { + return nil + } + + return c.server.ReadMessageHook(c.writer, m) + } + if msc, isMulti := c.rwc.(MultistreamConn); isMulti { // If it's a multi-stream association - reset the stream to "undefined" prior to reading next message msc.ResetCurrentStream() - m, err = ReadMessage(msc, c.dictionary()) // MultistreamConn has it's own buffering + m, err = ReadMessage(msc, c.dictionary(), wrappedMethod) // MultistreamConn has it's own buffering } else { - m, err = ReadMessage(c.buf.Reader, c.dictionary()) + m, err = ReadMessage(c.buf.Reader, c.dictionary(), wrappedMethod) } if err != nil { return nil, err @@ -557,16 +566,19 @@ func Serve(l net.Listener, handler Handler) error { return srv.Serve(l) } +type ReadMessageHook = func(Conn, *Message) error + // A Server defines parameters for running a diameter server. type Server struct { - Network string // network of the address - empty string defaults to tcp - Addr string // address to listen on, ":3868" if empty - Handler Handler // handler to invoke, DefaultServeMux if nil - Dict *dict.Parser // diameter dictionaries for this server - ReadTimeout time.Duration // maximum duration before timing out read of the request - WriteTimeout time.Duration // maximum duration before timing out write of the response - TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS - LocalAddr net.Addr // optional Local Address to bind dailer's (Dail...) socket to + Network string // network of the address - empty string defaults to tcp + Addr string // address to listen on, ":3868" if empty + Handler Handler // handler to invoke, DefaultServeMux if nil + Dict *dict.Parser // diameter dictionaries for this server + ReadTimeout time.Duration // maximum duration before timing out read of the request + WriteTimeout time.Duration // maximum duration before timing out write of the response + TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + LocalAddr net.Addr // optional Local Address to bind dailer's (Dail...) socket to + ReadMessageHook ReadMessageHook // optional Called right before ReadMessage method. } // serverHandler delegates to either the server's Handler or DefaultServeMux.