diff --git a/client_impl.go b/client_impl.go index 074244d..a0766b0 100644 --- a/client_impl.go +++ b/client_impl.go @@ -528,18 +528,30 @@ func (c *runnerImpl) executeTransaction( // - the query is a write query AccessMode: neo4j.AccessModeRead, } + c.ensureCausalConsistency(ctx, &sessConfig) if sess == nil { if conf := c.execConfig.SessionConfig; conf != nil { sessConfig = *conf } if cy.IsWrite || sessConfig.AccessMode == neo4j.AccessModeWrite { sessConfig.AccessMode = neo4j.AccessModeWrite - sess = c.db.NewSession(ctx, sessConfig) } else { sessConfig.AccessMode = neo4j.AccessModeRead - sess = c.db.NewSession(ctx, sessConfig) } + sess = c.db.NewSession(ctx, sessConfig) defer func() { + if sessConfig.AccessMode == neo4j.AccessModeWrite { + bookmarks := sess.LastBookmarks() + key := c.causalConsistencyKey(ctx) + if cur, ok := causalConsistencyCache[key]; ok { + causalConsistencyCache[key] = neo4j.CombineBookmarks(cur, bookmarks) + } else { + causalConsistencyCache[key] = bookmarks + go func(key string) { + causalConsistencyCache[key] = nil + }(key) + } + } if closeErr := sess.Close(ctx); closeErr != nil { err = errors.Join(err, closeErr) } diff --git a/driver.go b/driver.go index 02a946f..1c47776 100644 --- a/driver.go +++ b/driver.go @@ -101,9 +101,11 @@ type ( type ( driver struct { registry - db neo4j.DriverWithContext + db neo4j.DriverWithContext + causalConsistencyKey func(ctx context.Context) string } session struct { + *driver registry db neo4j.DriverWithContext execConfig execConfig @@ -116,6 +118,14 @@ type ( } ) +var causalConsistencyCache map[string]neo4j.Bookmarks + +func WithCausalConsistency(when func(ctx context.Context) string) Config { + return func(d *driver) { + d.causalConsistencyKey = when + } +} + // WithTxConfig configures the transaction used by Exec(). func WithTxConfig(configurers ...func(*neo4j.TransactionConfig)) func(ec *execConfig) { return func(ec *execConfig) { @@ -160,12 +170,28 @@ func (d *driver) Exec(configurers ...func(*execConfig)) Query { return session.newClient(internal.NewCypherClient()) } +func (d *driver) ensureCausalConsistency(ctx context.Context, sc *neo4j.SessionConfig) { + if d.causalConsistencyKey == nil { + return + } + var key string + if key = d.causalConsistencyKey(ctx); key == "" { + return + } + bookmarks := causalConsistencyCache[key] + if bookmarks == nil { + return + } + sc.Bookmarks = bookmarks +} + func (d *driver) ReadSession(ctx context.Context, configurers ...func(*neo4j.SessionConfig)) readSession { config := neo4j.SessionConfig{} for _, c := range configurers { c(&config) } config.AccessMode = neo4j.AccessModeRead + d.ensureCausalConsistency(ctx, &config) sess := d.db.NewSession(ctx, config) return &session{ registry: d.registry, @@ -180,6 +206,7 @@ func (d *driver) WriteSession(ctx context.Context, configurers ...func(*neo4j.Se c(&config) } config.AccessMode = neo4j.AccessModeWrite + d.ensureCausalConsistency(ctx, &config) sess := d.db.NewSession(ctx, config) return &session{ registry: d.registry,