Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import (
"time"
"unicode"

inf "gopkg.in/inf.v0"
"gopkg.in/inf.v0"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -3955,3 +3955,46 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) {

session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec()
}

func TestTimeoutOverride(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < 3 {
t.Skip("named Values are not supported in protocol < 3")
}

if err := createTable(session, "CREATE TABLE gocql_test.named_query(id int, value text, PRIMARY KEY (id))"); err != nil {
t.Fatal(err)
}

// normal case
err := session.Query("INSERT INTO gocql_test.named_query(id, value) VALUES(1, 'value')").Exec()
if err != nil {
t.Fatal(err)
}

//decrease Conn.timeout
session.executor.pool.mu.Lock()
for _, conPool := range session.executor.pool.hostConnPools {
conPool.mu.Lock()
for _, conn := range conPool.conns {
conn.r.SetTimeout(50)
}
conPool.mu.Unlock()
}
session.executor.pool.mu.Unlock()
err = session.Query("INSERT INTO gocql_test.named_query(id, value) VALUES(2, 'value')").Exec()
if err != ErrTimeoutNoResponse {
t.Fatalf("expected: ErrTimeoutNoResponse, got: %v", err)
}

// override timeout with context
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = session.Query("TRUNCATE TABLE gocql_test.named_query").WithContext(ctx).Exec()
if err != nil {
t.Fatal(err)
}

}
9 changes: 8 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,9 @@ func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*fram

func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer, startupCompleted bool) (*framer, error) {
if ctxErr := ctx.Err(); ctxErr != nil {
if ctxErr == context.DeadlineExceeded {
c.handleTimeout()
}
return nil, ctxErr
}

Expand Down Expand Up @@ -1273,7 +1276,8 @@ func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer
}

var timeoutCh <-chan time.Time
if timeout := c.r.GetTimeout(); timeout > 0 {
_, isDeadline := ctx.Deadline()
if timeout := c.r.GetTimeout(); timeout > 0 && !isDeadline {
if call.timer == nil {
call.timer = time.NewTimer(0)
<-call.timer.C
Expand Down Expand Up @@ -1326,6 +1330,9 @@ func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer
c.handleTimeout()
return nil, ErrTimeoutNoResponse
case <-ctxDone:
if ctx.Err() == context.DeadlineExceeded {
c.handleTimeout()
}
close(call.timeout)
return nil, ctx.Err()
case <-c.ctx.Done():
Expand Down
7 changes: 7 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,13 @@ func (q *Query) withContext(ctx context.Context) ExecutableQuery {
// The provided context controls the entire lifetime of executing a
// query, queries will be canceled and return once the context is
// canceled.
//
// You can set context.WithTimeout to override default timeout with custom (per query):
// Example:
//
// ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
// defer cancel()
// err = session.Query("SELECT * FROM my_table").WithContext(ctx).Exec()
func (q *Query) WithContext(ctx context.Context) *Query {
q2 := *q
q2.context = ctx
Expand Down
Loading