diff --git a/accepter_test.go b/accepter_test.go deleted file mode 100644 index 54bfff845..000000000 --- a/accepter_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) quickfixengine.org All rights reserved. -// -// This file may be distributed under the terms of the quickfixengine.org -// license as defined by quickfixengine.org and appearing in the file -// LICENSE included in the packaging of this file. -// -// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING -// THE WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A -// PARTICULAR PURPOSE. -// -// See http://www.quickfixengine.org/LICENSE for licensing information. -// -// Contact ask@quickfixengine.org if any conditions of this licensing -// are not clear to you. - -package quickfix - -import ( - "net" - "testing" - - "github.com/quickfixgo/quickfix/config" - - proxyproto "github.com/pires/go-proxyproto" - "github.com/stretchr/testify/assert" -) - -func TestAcceptor_Start(t *testing.T) { - sessionSettings := NewSessionSettings() - sessionSettings.Set(config.BeginString, BeginStringFIX42) - sessionSettings.Set(config.SenderCompID, "sender") - sessionSettings.Set(config.TargetCompID, "target") - - settingsWithTCPProxy := NewSettings() - settingsWithTCPProxy.GlobalSettings().Set("UseTCPProxy", "Y") - - settingsWithNoTCPProxy := NewSettings() - settingsWithNoTCPProxy.GlobalSettings().Set("UseTCPProxy", "N") - - genericSettings := NewSettings() - - const ( - GenericListener = iota - ProxyListener - ) - - acceptorStartTests := []struct { - name string - settings *Settings - listenerType int - }{ - {"with TCP proxy set", settingsWithTCPProxy, ProxyListener}, - {"with no TCP proxy set", settingsWithNoTCPProxy, GenericListener}, - {"no TCP proxy configuration set", genericSettings, GenericListener}, - } - - for _, tt := range acceptorStartTests { - t.Run(tt.name, func(t *testing.T) { - tt.settings.GlobalSettings().Set("SocketAcceptPort", "5001") - if _, err := tt.settings.AddSession(sessionSettings); err != nil { - assert.Nil(t, err) - } - - acceptor := &Acceptor{settings: tt.settings} - if err := acceptor.Start(); err != nil { - assert.NotNil(t, err) - } - assert.Len(t, acceptor.listeners, 1) - - for _, listener := range acceptor.listeners { - if tt.listenerType == ProxyListener { - _, ok := listener.(*proxyproto.Listener) - assert.True(t, ok) - } - - if tt.listenerType == GenericListener { - _, ok := listener.(*net.TCPListener) - assert.True(t, ok) - } - } - - acceptor.Stop() - }) - } -} diff --git a/acceptor.go b/acceptor.go index f5b9b281c..d7f5eda51 100644 --- a/acceptor.go +++ b/acceptor.go @@ -32,22 +32,25 @@ import ( // Acceptor accepts connections from FIX clients and manages the associated sessions. type Acceptor struct { - app Application - settings *Settings - logFactory LogFactory - storeFactory MessageStoreFactory - globalLog Log - sessions map[SessionID]*session - sessionGroup sync.WaitGroup - listenerShutdown sync.WaitGroup - dynamicSessions bool - dynamicQualifier bool - dynamicQualifierCount int - dynamicSessionChan chan *session - sessionAddr sync.Map - sessionHostPort map[SessionID]int - listeners map[string]net.Listener - connectionValidator ConnectionValidator + app Application + settings *Settings + logFactory LogFactory + storeFactory MessageStoreFactory + globalLog Log + sessions map[SessionID]*session + sessionsLock sync.RWMutex + sessionGroup sync.WaitGroup + listenerShutdown sync.WaitGroup + dynamicSessions bool + dynamicQualifier bool + dynamicQualifierCount int + dynamicSessionChan chan *session + sessionAddr sync.Map + sessionHostPort map[SessionID]int + listeners map[string]net.Listener + connectionValidator ConnectionValidator + templateIDProvider TemplateIDProvider + dynamicAcceptorSessionProvider *dynamicAcceptorSessionProvider sessionFactory } @@ -60,6 +63,11 @@ type ConnectionValidator interface { // Start accepting connections. func (a *Acceptor) Start() (err error) { + + if err = a.configureDyanmicSessionProvider(); err != nil { + return + } + socketAcceptHost := "" if a.settings.GlobalSettings().HasSetting(config.SocketAcceptHost) { if socketAcceptHost, err = a.settings.GlobalSettings().Setting(config.SocketAcceptHost); err != nil { @@ -104,14 +112,8 @@ func (a *Acceptor) Start() (err error) { a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]} } } + a.startSessions() - for _, s := range a.sessions { - a.sessionGroup.Add(1) - go func(s *session) { - s.run() - a.sessionGroup.Done() - }(s) - } if a.dynamicSessions { a.dynamicSessionChan = make(chan *session) a.sessionGroup.Add(1) @@ -127,6 +129,25 @@ func (a *Acceptor) Start() (err error) { return } +func (a *Acceptor) configureDyanmicSessionProvider() error { + if a.templateIDProvider == nil { + defaultTemplateIDProvider, err := NewDefaultTemplateIDProvider(a.settings) + if err != nil { + return err + } + if len(defaultTemplateIDProvider.templateMappings) == 0 { + // no templateMappings + return nil + } + a.templateIDProvider = defaultTemplateIDProvider + } + if setter, ok := a.storeFactory.(TemplateIDProviderSetter); ok { + setter.SetTemplateIDProvider(a.templateIDProvider) + } + a.dynamicAcceptorSessionProvider = newDynamicAcceptorSessionProvider(a.settings, a.storeFactory, a.logFactory, a.app, a.templateIDProvider) + return nil +} + // Stop logs out existing sessions, close their connections, and stop accepting new connections. func (a *Acceptor) Stop() { defer func() { @@ -140,17 +161,7 @@ func (a *Acceptor) Stop() { if a.dynamicSessions { close(a.dynamicSessionChan) } - for _, session := range a.sessions { - session.stop() - } - a.sessionGroup.Wait() - - for sessionID := range a.sessions { - err := UnregisterSession(sessionID) - if err != nil { - return - } - } + a.stopSessions() } // RemoteAddr gets remote IP address for a given session. @@ -191,6 +202,15 @@ func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Se } for sessionID, sessionSettings := range settings.SessionSettings() { + if sessionSettings.HasSetting(config.AcceptorTemplate) { + var acceptorTemplate bool + if acceptorTemplate, err = sessionSettings.BoolSetting(config.AcceptorTemplate); err != nil { + return + } + if acceptorTemplate { + continue + } + } sessID := sessionID sessID.Qualifier = "" @@ -331,18 +351,27 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { } session, ok := a.sessions[sessID] if !ok { - if !a.dynamicSessions { - a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes) - return - } - dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app) - if err != nil { - a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err) - return + if a.dynamicAcceptorSessionProvider != nil { + session, err = a.dynamicAcceptorSessionProvider.GetSession(sessID) + if err != nil { + a.globalLog.OnEventf("Failed to get session %v from provider: %v", sessID, err) + return + } + a.addMngdDynamicSession(sessID, session) + } else { + if !a.dynamicSessions { + a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes) + return + } + dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app) + if err != nil { + a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err) + return + } + a.dynamicSessionChan <- dynamicSession + session = dynamicSession + defer session.stop() } - a.dynamicSessionChan <- dynamicSession - session = dynamicSession - defer session.stop() } a.sessionAddr.Store(sessID, netConn.RemoteAddr()) @@ -412,6 +441,46 @@ LOOP: } } +func (a *Acceptor) startSessions() { + a.sessionsLock.RLock() + defer a.sessionsLock.RUnlock() + for _, s := range a.sessions { + a.sessionGroup.Add(1) + go func(s *session) { + s.run() + a.sessionGroup.Done() + }(s) + } +} + +func (a *Acceptor) stopSessions() { + a.sessionsLock.RLock() + defer a.sessionsLock.RUnlock() + for _, session := range a.sessions { + session.stop() + } + a.sessionGroup.Wait() + + for sessionID := range a.sessions { + err := UnregisterSession(sessionID) + if err != nil { + return + } + } +} + +func (a *Acceptor) addMngdDynamicSession(sessID SessionID, session *session) { + a.sessionsLock.Lock() + defer a.sessionsLock.Unlock() + + a.sessions[sessID] = session + a.sessionGroup.Add(1) + go func() { + session.run() + a.sessionGroup.Done() + }() +} + // SetConnectionValidator sets an optional connection validator. // Use it when you need a custom authentication logic that includes lower level interactions, // like mTLS auth or IP whitelistening. @@ -421,3 +490,9 @@ LOOP: func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) { a.connectionValidator = validator } + +// SetTemplateIDProvider sets an optional templateID provider. +// If not set and AcceptorTemplate=Y is configured for a session, the `DefaultTemplateIDProvider` will be used. +func (a *Acceptor) SetTemplateIDProvider(templateIDProvider TemplateIDProvider) { + a.templateIDProvider = templateIDProvider +} diff --git a/acceptor_session_provider.go b/acceptor_session_provider.go new file mode 100644 index 000000000..7945ebf2c --- /dev/null +++ b/acceptor_session_provider.go @@ -0,0 +1,140 @@ +package quickfix + +import "github.com/quickfixgo/quickfix/config" + +const ( + WildcardPattern string = "*" +) + +// TemplateIDProvider is an interface for obtaining templateIDs for inbound sessions. +// +// The SessionSettings for the template SessionID must be configured in the Acceptor. +// The sessionSettings of an inbound session inherits from the template SessionID. +// If no matching template is found, return nil, and no session will be created for the inbound logon request. +type TemplateIDProvider interface { + GetTemplateID(inbound SessionID) (templateID *SessionID) +} + +type TemplateIDProviderSetter interface { + SetTemplateIDProvider(TemplateIDProvider) +} + +type DefaultTemplateIDProvider struct { + templateMappings []*TemplateMapping +} + +func NewDefaultTemplateIDProvider(settings *Settings) (*DefaultTemplateIDProvider, error) { + templateMappings := make([]*TemplateMapping, 0) + for sid, ss := range settings.SessionSettings() { + var ( + acceptorTemplate bool + err error + ) + if ss.HasSetting(config.AcceptorTemplate) { + acceptorTemplate, err = ss.BoolSetting(config.AcceptorTemplate) + if err != nil { + return nil, err + } + } + if acceptorTemplate { + templateMappings = append(templateMappings, &TemplateMapping{ + Pattern: sid, + TemplateID: sid, + }) + } + } + return &DefaultTemplateIDProvider{templateMappings: templateMappings}, nil +} + +func (p *DefaultTemplateIDProvider) GetTemplateID(inbound SessionID) (templateID *SessionID) { + return p.lookupTemplateID(inbound) +} + +func (p *DefaultTemplateIDProvider) lookupTemplateID(sessionID SessionID) *SessionID { + for _, mapping := range p.templateMappings { + if isTemplateMatching(mapping.Pattern, sessionID) { + return &mapping.TemplateID + } + } + return nil +} + +type dynamicAcceptorSessionProvider struct { + settings *Settings + messageStoreFactory MessageStoreFactory + logFactory LogFactory + sessionFactory *sessionFactory + application Application + + templateIDProvider TemplateIDProvider +} + +func newDynamicAcceptorSessionProvider(settings *Settings, messageStoreFactory MessageStoreFactory, logFactory LogFactory, + application Application, templateIDProvider TemplateIDProvider, +) *dynamicAcceptorSessionProvider { + return &dynamicAcceptorSessionProvider{ + settings: settings, + messageStoreFactory: messageStoreFactory, + logFactory: logFactory, + sessionFactory: &sessionFactory{}, + application: application, + templateIDProvider: templateIDProvider, + } +} + +func (p *dynamicAcceptorSessionProvider) GetSession(sessionID SessionID) (*session, error) { + s, ok := lookupSession(sessionID) + if !ok && p.templateIDProvider != nil { + templateID := p.templateIDProvider.GetTemplateID(sessionID) + if templateID == nil { + return nil, errUnknownSession + } + dynamicSessionSettings := p.settings.globalSettings.clone() + templateSettings, ok := p.settings.sessionSettings[*templateID] + if !ok { + return nil, errUnknownSession + } + dynamicSessionSettings.overlay(templateSettings) + dynamicSessionSettings.Set(config.BeginString, sessionID.BeginString) + dynamicSessionSettings.Set(config.SenderCompID, sessionID.SenderCompID) + dynamicSessionSettings.Set(config.SenderSubID, sessionID.SenderSubID) + dynamicSessionSettings.Set(config.SenderLocationID, sessionID.SenderLocationID) + dynamicSessionSettings.Set(config.TargetCompID, sessionID.TargetCompID) + dynamicSessionSettings.Set(config.TargetSubID, sessionID.TargetSubID) + dynamicSessionSettings.Set(config.TargetLocationID, sessionID.TargetLocationID) + var err error + s, err = p.sessionFactory.createSession(sessionID, + p.messageStoreFactory, + dynamicSessionSettings, + p.logFactory, + p.application, + ) + if err != nil { + return nil, err + } + } + if s == nil { + return nil, errUnknownSession + } + return s, nil +} + +func isTemplateMatching(pattern SessionID, sessionID SessionID) bool { + return matches(pattern.BeginString, sessionID.BeginString) && + matches(pattern.SenderCompID, sessionID.SenderCompID) && + matches(pattern.SenderSubID, sessionID.SenderSubID) && + matches(pattern.SenderLocationID, sessionID.SenderLocationID) && + matches(pattern.TargetCompID, sessionID.TargetCompID) && + matches(pattern.TargetSubID, sessionID.TargetSubID) && + matches(pattern.TargetLocationID, sessionID.TargetLocationID) +} + +func matches(pattern string, value string) bool { + return WildcardPattern == pattern || pattern == value +} + +// TemplateMapping mapping from a sessionID pattern to a session template ID. +type TemplateMapping struct { + Pattern SessionID + TemplateID SessionID +} diff --git a/acceptor_session_provider_test.go b/acceptor_session_provider_test.go new file mode 100644 index 000000000..8a84e70ff --- /dev/null +++ b/acceptor_session_provider_test.go @@ -0,0 +1,252 @@ +package quickfix + +import ( + "strings" + "testing" + + "github.com/quickfixgo/quickfix/config" + "github.com/stretchr/testify/suite" +) + +var _ Application = &noopApp{} + +type noopApp struct { +} + +func (n *noopApp) FromAdmin(_ *Message, _ SessionID) MessageRejectError { + return nil +} + +func (n *noopApp) FromApp(_ *Message, _ SessionID) MessageRejectError { + return nil +} + +func (n *noopApp) OnCreate(_ SessionID) { +} + +func (n *noopApp) OnLogon(_ SessionID) { +} + +func (n *noopApp) OnLogout(_ SessionID) { +} + +func (n *noopApp) ToAdmin(_ *Message, _ SessionID) { +} + +func (n *noopApp) ToApp(_ *Message, _ SessionID) error { + return nil +} + +type DynamicAcceptorSessionProviderTestSuite struct { + suite.Suite + + dynamicAcceptorSessionProvider *dynamicAcceptorSessionProvider + + settings *Settings + messageStoreFactory MessageStoreFactory + logFactory LogFactory + app Application + sessionFactory *sessionFactory +} + +func (suite *DynamicAcceptorSessionProviderTestSuite) TestNewDefaultTemplateIDProvider() { + cfg := ` +[default] +ConnectionType=acceptor +SocketAcceptPort=9878 +BeginString=FIX.4.2 +TimeZone=America/New_York +StartTime=00:00:01 +EndTime=23:59:59 +HeartBtInt=30 + +[session] +AcceptorTemplate=Y +SenderCompID=test1 +TargetCompID=* +ResetOnLogon=Y + +[session] +AcceptorTemplate=Y +SenderCompID=test2 +TargetCompID=* +ResetOnLogon=Y + ` + stringReader := strings.NewReader(cfg) + settings, err := ParseSettings(stringReader) + if err != nil { + suite.FailNow("parse setting failed", err) + } + + provider, err := NewDefaultTemplateIDProvider(settings) + if err != nil { + suite.FailNow("create TemplateIDProvider failed", err) + } + + s1 := SessionID{BeginString: "FIX.4.2", SenderCompID: "test1", TargetCompID: "cli"} + templateID1 := provider.GetTemplateID(s1) + suite.Require().NotNil(templateID1, "expected template matched") + suite.Require().Equal( + *templateID1, + SessionID{BeginString: "FIX.4.2", SenderCompID: "test1", TargetCompID: "*"}, + "unexpected templateID", + ) + + s2 := SessionID{BeginString: "FIX.4.3", SenderCompID: "test1", TargetCompID: "cli"} + templateID2 := provider.GetTemplateID(s2) + suite.Require().Nilf(templateID2, "expected template not matched for %v", s2.String()) + + s3 := SessionID{BeginString: "FIX.4.2", SenderCompID: "X", TargetCompID: "cli"} + templateID3 := provider.GetTemplateID(s3) + suite.Require().Nilf(templateID3, "expected template not matched for %v", s3.String()) +} + +func (suite *DynamicAcceptorSessionProviderTestSuite) SetupTest() { + suite.settings = NewSettings() + suite.messageStoreFactory = NewMemoryStoreFactory() + suite.logFactory = nullLogFactory{} + suite.app = &noopApp{} + suite.sessionFactory = &sessionFactory{} + templateMappings := make([]*TemplateMapping, 0) + + templateID1 := SessionID{BeginString: "FIX.4.2", SenderCompID: "ANY", TargetCompID: "ANY"} + templateMappings = append( + templateMappings, + &TemplateMapping{Pattern: SessionID{BeginString: WildcardPattern, SenderCompID: "S1", TargetCompID: WildcardPattern}, TemplateID: templateID1}, + ) + suite.setUpSettings(templateID1, "ResetOnLogout", "Y") + + templateID2 := SessionID{BeginString: "FIX.4.4", SenderCompID: "S1", TargetCompID: "ANY"} + templateMappings = append( + templateMappings, + &TemplateMapping{Pattern: SessionID{BeginString: "FIX.4.4", SenderCompID: WildcardPattern, TargetCompID: WildcardPattern}, TemplateID: templateID2}, + ) + suite.setUpSettings(templateID2, "RefreshOnLogon", "Y") + + templateID3 := SessionID{BeginString: "FIX.4.4", SenderCompID: "ANY", TargetCompID: "ANY"} + templateMappings = append( + templateMappings, + &TemplateMapping{Pattern: SessionID{BeginString: "FIX.4.2", SenderCompID: WildcardPattern, SenderSubID: WildcardPattern, SenderLocationID: WildcardPattern, + TargetCompID: WildcardPattern, TargetSubID: WildcardPattern, TargetLocationID: WildcardPattern, Qualifier: WildcardPattern, + }, TemplateID: templateID3}, + ) + suite.setUpSettings(templateID3, "ResetOnDisconnect", "Y") + + templateIDProvider := &DefaultTemplateIDProvider{ + templateMappings: templateMappings, + } + + suite.dynamicAcceptorSessionProvider = newDynamicAcceptorSessionProvider(suite.settings, suite.messageStoreFactory, + suite.logFactory, suite.app, templateIDProvider) +} + +func (suite *DynamicAcceptorSessionProviderTestSuite) setUpSettings(TemplateID SessionID, key, value string) { + sessionSettings := NewSessionSettings() + sessionSettings.Set(config.BeginString, TemplateID.BeginString) + sessionSettings.Set(config.SenderCompID, TemplateID.SenderCompID) + sessionSettings.Set(config.SenderSubID, TemplateID.SenderSubID) + sessionSettings.Set(config.SenderLocationID, TemplateID.SenderLocationID) + sessionSettings.Set(config.TargetCompID, TemplateID.TargetCompID) + sessionSettings.Set(config.TargetSubID, TemplateID.TargetSubID) + sessionSettings.Set(config.TargetLocationID, TemplateID.TargetLocationID) + sessionSettings.Set(config.SessionQualifier, TemplateID.Qualifier) + + sessionSettings.Set("StartTime", "00:00:00") + sessionSettings.Set("EndTime", "00:00:00") + sessionSettings.Set(key, value) + suite.settings.AddSession(sessionSettings) +} + +func (suite *DynamicAcceptorSessionProviderTestSuite) TestSessionCreation() { + type expected struct { + sessionID SessionID + resetOnLogout bool + refreshOnLogon bool + resetOnDisconnect bool + } + var tests = []struct { + name string + input SessionID + expected expected + }{ + { + name: "session created - matched", + input: SessionID{ + BeginString: "FIX.4.2", SenderCompID: "SENDER", SenderSubID: "SENDERSUB", SenderLocationID: "SENDERLOC", + TargetCompID: "TARGET", TargetSubID: "TARGETSUB", TargetLocationID: "TARGETLOC", Qualifier: "", + }, + expected: expected{ + sessionID: SessionID{ + BeginString: "FIX.4.2", SenderCompID: "SENDER", SenderSubID: "SENDERSUB", SenderLocationID: "SENDERLOC", + TargetCompID: "TARGET", TargetSubID: "TARGETSUB", TargetLocationID: "TARGETLOC", Qualifier: "", + }, + resetOnLogout: false, + refreshOnLogon: false, + resetOnDisconnect: true, + }, + }, + { + name: "create session - matching the first", + input: SessionID{ + BeginString: "FIX.4.4", SenderCompID: "S1", TargetCompID: "T", + }, + expected: expected{ + sessionID: SessionID{ + BeginString: "FIX.4.4", SenderCompID: "S1", TargetCompID: "T", + }, + resetOnLogout: true, + refreshOnLogon: false, + resetOnDisconnect: false, + }, + }, + { + name: "create session - matching the second", + input: SessionID{ + BeginString: "FIX.4.4", SenderCompID: "X", TargetCompID: "Y", + }, + expected: expected{ + sessionID: SessionID{ + BeginString: "FIX.4.4", SenderCompID: "X", TargetCompID: "Y", + }, + resetOnLogout: false, + refreshOnLogon: true, + resetOnDisconnect: false, + }, + }, + } + + for _, test := range tests { + session, err := suite.dynamicAcceptorSessionProvider.GetSession(test.input) + suite.NoError(err) + suite.NotNil(session) + sessionID := session.sessionID + suite.Require().Equal(test.expected.sessionID, sessionID, test.name+": created sessionID not expected") + suite.Require().Equal(test.expected.resetOnLogout, session.ResetOnLogout, test.name+":ResetOnLogout not expected") + suite.Require().Equal(test.expected.refreshOnLogon, session.RefreshOnLogon, test.name+":RefreshOnLogon not expected") + suite.Require().Equal(test.expected.resetOnDisconnect, session.ResetOnDisconnect, test.name+":ResetOnDisconnect not expected") + UnregisterSession(sessionID) + } +} + +func (suite *DynamicAcceptorSessionProviderTestSuite) TestTemplateNotFound() { + var tests = []struct { + name string + input SessionID + }{ + { + name: "template not found", + input: SessionID{ + BeginString: "FIX.4.3", SenderCompID: "S", TargetCompID: "T", + }, + }, + } + + for _, test := range tests { + _, err := suite.dynamicAcceptorSessionProvider.GetSession(test.input) + suite.Error(err, test.name+": expected error for template not found") + } +} + +func TestDynamicAcceptorSessionProviderTestSuite(t *testing.T) { + suite.Run(t, new(DynamicAcceptorSessionProviderTestSuite)) +} diff --git a/acceptor_test.go b/acceptor_test.go new file mode 100644 index 000000000..fd862d784 --- /dev/null +++ b/acceptor_test.go @@ -0,0 +1,390 @@ +// Copyright (c) quickfixengine.org All rights reserved. +// +// This file may be distributed under the terms of the quickfixengine.org +// license as defined by quickfixengine.org and appearing in the file +// LICENSE included in the packaging of this file. +// +// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING +// THE WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A +// PARTICULAR PURPOSE. +// +// See http://www.quickfixengine.org/LICENSE for licensing information. +// +// Contact ask@quickfixengine.org if any conditions of this licensing +// are not clear to you. + +package quickfix + +import ( + "bytes" + "io" + "net" + "strings" + "testing" + "time" + + proxyproto "github.com/pires/go-proxyproto" + "github.com/quickfixgo/quickfix/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +func TestAcceptor_Start(t *testing.T) { + sessionSettings := NewSessionSettings() + sessionSettings.Set(config.BeginString, BeginStringFIX42) + sessionSettings.Set(config.SenderCompID, "sender") + sessionSettings.Set(config.TargetCompID, "target") + + settingsWithTCPProxy := NewSettings() + settingsWithTCPProxy.GlobalSettings().Set("UseTCPProxy", "Y") + + settingsWithNoTCPProxy := NewSettings() + settingsWithNoTCPProxy.GlobalSettings().Set("UseTCPProxy", "N") + + genericSettings := NewSettings() + + const ( + GenericListener = iota + ProxyListener + ) + + acceptorStartTests := []struct { + name string + settings *Settings + listenerType int + }{ + {"with TCP proxy set", settingsWithTCPProxy, ProxyListener}, + {"with no TCP proxy set", settingsWithNoTCPProxy, GenericListener}, + {"no TCP proxy configuration set", genericSettings, GenericListener}, + } + + for _, tt := range acceptorStartTests { + t.Run(tt.name, func(t *testing.T) { + tt.settings.GlobalSettings().Set("SocketAcceptPort", "5001") + if _, err := tt.settings.AddSession(sessionSettings); err != nil { + assert.Nil(t, err) + } + + acceptor := &Acceptor{settings: tt.settings} + if err := acceptor.Start(); err != nil { + assert.NotNil(t, err) + } + assert.Len(t, acceptor.listeners, 1) + + for _, listener := range acceptor.listeners { + if tt.listenerType == ProxyListener { + _, ok := listener.(*proxyproto.Listener) + assert.True(t, ok) + } + + if tt.listenerType == GenericListener { + _, ok := listener.(*net.TCPListener) + assert.True(t, ok) + } + } + + acceptor.Stop() + }) + } +} + +var _ net.Conn = &mockConn{} + +type mockConn struct { + closeChan chan struct{} + localAddr net.Addr + remoteAddr net.Addr + + onWriteback func([]byte) + inboundMessages []*Message +} + +func (c *mockConn) Read(b []byte) (n int, err error) { + if len(c.inboundMessages) > 0 { + messageBytes := c.inboundMessages[0].build() + copy(b, messageBytes) + c.inboundMessages = c.inboundMessages[1:] + return len(messageBytes), err + } + <-c.closeChan + return 0, io.EOF +} + +func (c *mockConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *mockConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *mockConn) SetDeadline(_ time.Time) error { + return nil +} + +func (c *mockConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (c *mockConn) SetWriteDeadline(_ time.Time) error { + return nil +} + +func (c *mockConn) Write(b []byte) (n int, err error) { + if c.onWriteback != nil { + c.onWriteback(b) + } + return len(b), nil +} + +func (c *mockConn) Close() error { + return nil +} + +func mockLogonMessage(sessionID SessionID, msgSeqNum int) *Message { + msg := NewMessage() + msg.Header.SetField(tagMsgType, FIXString("A")) + msg.Header.SetInt(tagMsgSeqNum, msgSeqNum) + msg.Header.SetString(tagBeginString, sessionID.BeginString) + msg.Header.SetString(tagSenderCompID, sessionID.SenderCompID) + msg.Header.SetString(tagSenderSubID, sessionID.SenderSubID) + msg.Header.SetString(tagSenderLocationID, sessionID.SenderLocationID) + msg.Header.SetString(tagTargetCompID, sessionID.TargetCompID) + msg.Header.SetString(tagTargetSubID, sessionID.TargetSubID) + msg.Header.SetString(tagTargetLocationID, sessionID.TargetLocationID) + msg.Header.SetField(tagSendingTime, FIXUTCTimestamp{Time: time.Now()}) + msg.Body.SetInt(tagHeartBtInt, 30) + return msg +} + +type AcceptorTemplateTestSuite struct { + suite.Suite + acceptor *Acceptor + + sessionID1 SessionID + sessionID2 SessionID + sessionID3 SessionID + + cliSessionID SessionID + logonSessionID SessionID + seqNum int +} + +func (suite *AcceptorTemplateTestSuite) BeforeTest(_, _ string) { + cfg := ` +[default] +ConnectionType=acceptor +SocketAcceptPort=5001 +TimeZone=America/New_York +StartTime=00:00:01 +EndTime=23:59:59 +HeartBtInt=30 + +[session] +BeginString=FIX.4.2 +SenderCompID=sender1 +TargetCompID=target1 + +[session] +BeginString=FIX.4.3 +SenderCompID=sender2 +TargetCompID=target2 +ResetOnLogon=Y + +[session] +AcceptorTemplate=Y +BeginString=FIX.4.3 +SenderCompID=* +SenderSubID=* +SenderLocationID=* +TargetCompID=target3 +TargetSubID=* +TargetLocationID=* +ResetOnLogout=Y +` + stringReader := strings.NewReader(cfg) + settings, err := ParseSettings(stringReader) + if err != nil { + suite.FailNow("parse setting failed", err) + } + suite.sessionID1 = SessionID{BeginString: BeginStringFIX42, SenderCompID: "sender1", TargetCompID: "target1"} + suite.sessionID2 = SessionID{BeginString: BeginStringFIX43, SenderCompID: "sender2", TargetCompID: "target2"} + suite.sessionID3 = SessionID{BeginString: BeginStringFIX43, SenderCompID: "*", SenderSubID: "*", SenderLocationID: "*", + TargetCompID: "target3", TargetSubID: "*", TargetLocationID: "*"} + + app := &noopApp{} + a, err := NewAcceptor(app, memoryStoreFactory{}, settings, NewNullLogFactory()) + if err != nil { + suite.Fail("Failed to create acceptor", err) + } + suite.acceptor = a + + suite.cliSessionID = SessionID{BeginString: BeginStringFIX43, SenderCompID: "target3", TargetCompID: "dynamicSender"} + suite.logonSessionID = SessionID{BeginString: BeginStringFIX43, SenderCompID: "dynamicSender", TargetCompID: "target3"} + suite.seqNum = 1 +} + +func (suite *AcceptorTemplateTestSuite) TearDownTest() { + suite.acceptor.Stop() + suite.acceptor = nil + suite.seqNum = 1 +} + +func (suite *AcceptorTemplateTestSuite) logonAndDisconnectAfterCheck(sessionID SessionID, + checkFuncAfterLogon func(), + wantLogonSuccess bool) { + inboundMessages := []*Message{mockLogonMessage(sessionID, suite.seqNum)} + suite.seqNum++ + var respondedLogonMessageReceived bool + mockConn1 := &mockConn{ + closeChan: make(chan struct{}), + inboundMessages: inboundMessages, + localAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5001}, + remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5002}, + } + mockConn1.onWriteback = func(b []byte) { + responseMsg := NewMessage() + err := ParseMessage(responseMsg, bytes.NewBuffer(b)) + suite.Require().NoError(err, "parse responding message failed") + msgType, err := responseMsg.Header.GetString(tagMsgType) + suite.Require().NoError(err, "unexpected mssage") + if wantLogonSuccess && msgType != "A" { + return + } + respondedLogonMessageReceived = true + if checkFuncAfterLogon != nil { + checkFuncAfterLogon() + } + close(mockConn1.closeChan) + } + suite.acceptor.handleConnection(mockConn1) + if wantLogonSuccess { + suite.Require().Equal(true, respondedLogonMessageReceived, "expected responding logon message") + } +} + +func (suite *AcceptorTemplateTestSuite) verifySessionCount(expectedSessionCount int) { + suite.Require().Equalf(expectedSessionCount, len(suite.acceptor.sessions), "expected %v sessions but found %v", expectedSessionCount, len(suite.acceptor.sessions)) + suite.Require().Equalf(expectedSessionCount, len(sessions), "expected %v sessions but found %v in registry", expectedSessionCount, sessions) +} + +func (suite *AcceptorTemplateTestSuite) TestCreateDynamicSessionBySessionProvider() { + if err := suite.acceptor.Start(); err != nil { + suite.FailNow("acceptor start failed", err) + } + suite.verifySessionCount(2) + + logonSessionID := suite.logonSessionID + suite.logonAndDisconnectAfterCheck(suite.cliSessionID, func() { + suite.verifySessionCount(3) + + createdSession, ok := suite.acceptor.sessions[logonSessionID] + suite.Require().Equal(true, ok, "expected dynamic session to be created") + suite.Require().Equal(logonSessionID, createdSession.sessionID, "expected session ID to match inbound session ID") + suite.Require().Equal(createdSession.ResetOnLogout, true, "expected ResetOnLogout=Y for createdSession") + + remoteAddr, ok := suite.acceptor.RemoteAddr(logonSessionID) + if !ok { + suite.Fail("Failed to get remote address for dynamic session") + } + suite.Require().Equal("127.0.0.1:5002", remoteAddr.String(), "expect remoteAddr for dynamic session to be 127.0.0.1:5002 but got %v", remoteAddr.String()) + }, true) +} + +func (suite *AcceptorTemplateTestSuite) TestSessionCreatedBySessionProviderShouldBeKept() { + if err := suite.acceptor.Start(); err != nil { + suite.FailNow("acceptor start failed", err) + } + suite.verifySessionCount(2) + + logonSessionID := suite.logonSessionID + suite.logonAndDisconnectAfterCheck(suite.cliSessionID, func() { + suite.verifySessionCount(3) + }, true) + err := SendToTarget(createFIX43NewOrderSingle(), logonSessionID) + suite.NoError(err, "expected message can still be sent after session disconnected") +} + +func (suite *AcceptorTemplateTestSuite) TestNoNewSessionCreatedWhenSameSessionIDLogons() { + if err := suite.acceptor.Start(); err != nil { + suite.FailNow("acceptor start failed", err) + } + suite.verifySessionCount(2) + + suite.logonAndDisconnectAfterCheck(suite.cliSessionID, func() { + suite.verifySessionCount(3) + }, true) + suite.logonAndDisconnectAfterCheck(suite.cliSessionID, func() { + suite.verifySessionCount(3) + }, true) + suite.logonAndDisconnectAfterCheck(suite.cliSessionID, func() { + suite.verifySessionCount(3) + }, true) +} + +func (suite *AcceptorTemplateTestSuite) TestSessionNotFoundBySessionProvider() { + if err := suite.acceptor.Start(); err != nil { + suite.FailNow("acceptor start failed", err) + } + suite.verifySessionCount(2) + + sessionID := SessionID{BeginString: BeginStringFIX43, SenderCompID: "unknownSender", TargetCompID: "unknownTarget"} + suite.logonAndDisconnectAfterCheck(sessionID, func() {}, false) + suite.verifySessionCount(2) +} + +type mockCustomTemplateIDProvider struct { + staticTemplateID SessionID +} + +// mockCustomTemplateIDProvider always returns the same templateID. +func (p *mockCustomTemplateIDProvider) GetTemplateID(_ SessionID) *SessionID { + return &p.staticTemplateID +} + +func (suite *AcceptorTemplateTestSuite) TestCustomTemplateIDProvider_NoSessionCreated() { + // this templateIDProvider selects session FIX.4.3:sender2->sender2 as the template + templateIDProvider := &mockCustomTemplateIDProvider{staticTemplateID: SessionID{ + BeginString: BeginStringFIX43, SenderCompID: "sender2", TargetCompID: "target2", + }} + suite.acceptor.SetTemplateIDProvider(templateIDProvider) + if err := suite.acceptor.Start(); err != nil { + suite.FailNow("acceptor start failed", err) + } + suite.verifySessionCount(2) + + // no session created + logon1 := SessionID{BeginString: BeginStringFIX42, SenderCompID: "target1", TargetCompID: "sender1"} + suite.logonAndDisconnectAfterCheck(logon1, func() { + suite.verifySessionCount(2) + }, true) +} + +func (suite *AcceptorTemplateTestSuite) TestCustomTemplateIDProvider_SessionCreated() { + // this templateIDProvider selects session FIX.4.3:sender2->sender2 as the template + templateIDProvider := &mockCustomTemplateIDProvider{staticTemplateID: SessionID{ + BeginString: BeginStringFIX43, SenderCompID: "sender2", TargetCompID: "target2", + }} + suite.acceptor.SetTemplateIDProvider(templateIDProvider) + if err := suite.acceptor.Start(); err != nil { + suite.FailNow("acceptor start failed", err) + } + suite.verifySessionCount(2) + + // session created + logonSessionID2 := SessionID{BeginString: BeginStringFIX42, SenderCompID: "any", TargetCompID: "any"} + suite.logonAndDisconnectAfterCheck(logonSessionID2, func() { + suite.verifySessionCount(3) + }, true) + // logon again + suite.logonAndDisconnectAfterCheck(logonSessionID2, func() { + suite.verifySessionCount(3) + }, true) + + session2 := suite.acceptor.sessions[logonSessionID2] + suite.Require().Equal(true, session2.ResetOnLogon, "expected session2 ResetOnLogon=Y") +} + +func TestAcceptorTemplateTestSuite(t *testing.T) { + suite.Run(t, new(AcceptorTemplateTestSuite)) +} diff --git a/config/configuration.go b/config/configuration.go index 754f587b1..881ea8794 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -697,6 +697,9 @@ const ( // - Y // - N DynamicQualifier string = "DynamicQualifier" + + // AcceptorTemplate designates a template Acceptor session. + AcceptorTemplate string = "AcceptorTemplate" ) const ( diff --git a/store/file/filestore.go b/store/file/filestore.go index ae44540b1..f5027e300 100644 --- a/store/file/filestore.go +++ b/store/file/filestore.go @@ -38,6 +38,12 @@ type msgDef struct { type fileStoreFactory struct { settings *quickfix.Settings + + templateIDProvider quickfix.TemplateIDProvider +} + +func (f *fileStoreFactory) SetTemplateIDProvider(templateIDProvider quickfix.TemplateIDProvider) { + f.templateIDProvider = templateIDProvider } type fileStore struct { @@ -69,10 +75,20 @@ func (f fileStoreFactory) Create(sessionID quickfix.SessionID) (msgStore quickfi sessionSettings, ok := f.settings.SessionSettings()[sessionID] if !ok { - if dynamicSessions { - sessionSettings = globalSettings + if f.templateIDProvider != nil { + templateID := f.templateIDProvider.GetTemplateID(sessionID) + if templateID != nil { + sessionSettings, ok = f.settings.SessionSettings()[*templateID] + if !ok { + return nil, fmt.Errorf("unknown session: %v", sessionID) + } + } } else { - return nil, fmt.Errorf("unknown session: %v", sessionID) + if dynamicSessions { + sessionSettings = globalSettings + } else { + return nil, fmt.Errorf("unknown session: %v", sessionID) + } } } diff --git a/store/mongo/mongostore.go b/store/mongo/mongostore.go index 42696388e..faaa19a05 100644 --- a/store/mongo/mongostore.go +++ b/store/mongo/mongostore.go @@ -33,6 +33,12 @@ type mongoStoreFactory struct { settings *quickfix.Settings messagesCollection string sessionsCollection string + + templateIDProvider quickfix.TemplateIDProvider +} + +func (f *mongoStoreFactory) SetTemplateIDProvider(templateIDProvider quickfix.TemplateIDProvider) { + f.templateIDProvider = templateIDProvider } type mongoStore struct { @@ -67,10 +73,20 @@ func (f mongoStoreFactory) Create(sessionID quickfix.SessionID) (msgStore quickf sessionSettings, ok := f.settings.SessionSettings()[sessionID] if !ok { - if dynamicSessions { - sessionSettings = globalSettings + if f.templateIDProvider != nil { + templateID := f.templateIDProvider.GetTemplateID(sessionID) + if templateID != nil { + sessionSettings, ok = f.settings.SessionSettings()[*templateID] + if !ok { + return nil, fmt.Errorf("unknown session: %v", sessionID) + } + } } else { - return nil, fmt.Errorf("unknown session: %v", sessionID) + if dynamicSessions { + sessionSettings = globalSettings + } else { + return nil, fmt.Errorf("unknown session: %v", sessionID) + } } } mongoConnectionURL, err := sessionSettings.Setting(config.MongoStoreConnection) diff --git a/store/sql/sqlstore.go b/store/sql/sqlstore.go index aeaeb6eb3..bea643d97 100644 --- a/store/sql/sqlstore.go +++ b/store/sql/sqlstore.go @@ -29,6 +29,12 @@ import ( type sqlStoreFactory struct { settings *quickfix.Settings + + templateIDProvider quickfix.TemplateIDProvider +} + +func (f *sqlStoreFactory) SetTemplateIDProvider(templateIDProvider quickfix.TemplateIDProvider) { + f.templateIDProvider = templateIDProvider } type sqlStore struct { @@ -73,10 +79,20 @@ func (f sqlStoreFactory) Create(sessionID quickfix.SessionID) (msgStore quickfix sessionSettings, ok := f.settings.SessionSettings()[sessionID] if !ok { - if dynamicSessions { - sessionSettings = globalSettings + if f.templateIDProvider != nil { + templateID := f.templateIDProvider.GetTemplateID(sessionID) + if templateID != nil { + sessionSettings, ok = f.settings.SessionSettings()[*templateID] + if !ok { + return nil, fmt.Errorf("unknown session: %v", sessionID) + } + } } else { - return nil, fmt.Errorf("unknown session: %v", sessionID) + if dynamicSessions { + sessionSettings = globalSettings + } else { + return nil, fmt.Errorf("unknown session: %v", sessionID) + } } }