diff --git a/cassandra_test.go b/cassandra_test.go index 54a54f426..beccbd1ee 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -44,7 +44,7 @@ import ( "time" "unicode" - inf "gopkg.in/inf.v0" + "gopkg.in/inf.v0" "github.com/stretchr/testify/require" ) @@ -3955,3 +3955,46 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec() } + +func TestTimeoutOverride(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < 3 { + t.Skip("named Values are not supported in protocol < 3") + } + + if err := createTable(session, "CREATE TABLE gocql_test.named_query(id int, value text, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + // normal case + err := session.Query("INSERT INTO gocql_test.named_query(id, value) VALUES(1, 'value')").Exec() + if err != nil { + t.Fatal(err) + } + + //decrease Conn.timeout + session.executor.pool.mu.Lock() + for _, conPool := range session.executor.pool.hostConnPools { + conPool.mu.Lock() + for _, conn := range conPool.conns { + conn.r.SetTimeout(50) + } + conPool.mu.Unlock() + } + session.executor.pool.mu.Unlock() + err = session.Query("INSERT INTO gocql_test.named_query(id, value) VALUES(2, 'value')").Exec() + if err != ErrTimeoutNoResponse { + t.Fatalf("expected: ErrTimeoutNoResponse, got: %v", err) + } + + // override timeout with context + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = session.Query("TRUNCATE TABLE gocql_test.named_query").WithContext(ctx).Exec() + if err != nil { + t.Fatal(err) + } + +} diff --git a/conn.go b/conn.go index d2f83d742..423850167 100644 --- a/conn.go +++ b/conn.go @@ -1179,6 +1179,9 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer, startupCompleted bool) (*framer, error) { if ctxErr := ctx.Err(); ctxErr != nil { + if ctxErr == context.DeadlineExceeded { + c.handleTimeout() + } return nil, ctxErr } @@ -1273,7 +1276,8 @@ func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer } var timeoutCh <-chan time.Time - if timeout := c.r.GetTimeout(); timeout > 0 { + _, isDeadline := ctx.Deadline() + if timeout := c.r.GetTimeout(); timeout > 0 && !isDeadline { if call.timer == nil { call.timer = time.NewTimer(0) <-call.timer.C @@ -1326,6 +1330,9 @@ func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer c.handleTimeout() return nil, ErrTimeoutNoResponse case <-ctxDone: + if ctx.Err() == context.DeadlineExceeded { + c.handleTimeout() + } close(call.timeout) return nil, ctx.Err() case <-c.ctx.Done(): diff --git a/session.go b/session.go index ed1a078d3..6e4cac99b 100644 --- a/session.go +++ b/session.go @@ -1132,6 +1132,13 @@ func (q *Query) withContext(ctx context.Context) ExecutableQuery { // The provided context controls the entire lifetime of executing a // query, queries will be canceled and return once the context is // canceled. +// +// You can set context.WithTimeout to override default timeout with custom (per query): +// Example: +// +// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) +// defer cancel() +// err = session.Query("SELECT * FROM my_table").WithContext(ctx).Exec() func (q *Query) WithContext(ctx context.Context) *Query { q2 := *q q2.context = ctx