Skip to content

Commit 4ad60b6

Browse files
committed
clean up rebase
1 parent 0da5a3c commit 4ad60b6

File tree

5 files changed

+96
-45
lines changed

5 files changed

+96
-45
lines changed

cassandra_test.go

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3669,7 +3669,9 @@ func TestQueryCompressionNotWorthIt(t *testing.T) {
36693669
// The driver should handle this by updating its prepared statement inside the cache
36703670
// when it receives RESULT/ROWS with Metadata_changed flag
36713671
func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
3672-
session := createSession(t)
3672+
session := createSession(t, func(config *ClusterConfig) {
3673+
config.NumConns = 1
3674+
})
36733675
defer session.Close()
36743676

36753677
if session.cfg.ProtoVersion < protoVersion5 {
@@ -3693,13 +3695,17 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
36933695
t.Fatal(err)
36943696
}
36953697

3696-
// We have to specify conn for all queries to ensure that
3698+
// We have to specify host for all queries to ensure that
36973699
// all queries are running on the same node
3698-
conn := session.getConn()
3700+
hosts := session.GetHosts()
3701+
if len(hosts) == 0 {
3702+
t.Fatal("no hosts found")
3703+
}
3704+
hostid := hosts[0].HostID()
36993705

37003706
const selectStmt = "SELECT * FROM gocql_test.metadata_changed"
37013707
queryBeforeTableAltering := session.Query(selectStmt)
3702-
queryBeforeTableAltering.conn = conn
3708+
queryBeforeTableAltering.SetHostID(hostid)
37033709
row := make(map[string]interface{})
37043710
err = queryBeforeTableAltering.MapScan(row)
37053711
if err != nil {
@@ -3709,13 +3715,16 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
37093715
require.Len(t, row, 1, "Expected to retrieve a single column")
37103716
require.Equal(t, 1, row["id"])
37113717

3712-
stmtCacheKey := session.stmtsLRU.keyFor(conn.host.HostID(), conn.currentKeyspace, queryBeforeTableAltering.stmt)
3713-
inflight, _ := session.stmtsLRU.get(stmtCacheKey)
3718+
stmtCacheKey := session.stmtsLRU.keyFor(hostid, "", queryBeforeTableAltering.stmt)
3719+
inflight, ok := session.stmtsLRU.get(stmtCacheKey)
3720+
if !ok {
3721+
t.Fatalf("failed to find inflight entry for key %v", stmtCacheKey)
3722+
}
37143723
preparedStatementBeforeTableAltering := inflight.preparedStatment
37153724

37163725
// Changing table schema in order to cause C* to return RESULT/ROWS Metadata_changed
37173726
alteringTableQuery := session.Query("ALTER TABLE gocql_test.metadata_changed ADD new_col int")
3718-
alteringTableQuery.conn = conn
3727+
alteringTableQuery.SetHostID(hostid)
37193728
err = alteringTableQuery.Exec()
37203729
if err != nil {
37213730
t.Fatal(err)
@@ -3767,7 +3776,7 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
37673776
// Expecting C* will return RESULT/ROWS Metadata_changed
37683777
// and it will be properly handled
37693778
queryAfterTableAltering := session.Query(selectStmt)
3770-
queryAfterTableAltering.conn = conn
3779+
queryAfterTableAltering.SetHostID(hostid)
37713780
iter := queryAfterTableAltering.Iter()
37723781
handleRows(iter)
37733782

@@ -3792,7 +3801,7 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
37923801
defer cancel()
37933802

37943803
queryAfterTableAltering2 := session.Query(selectStmt).WithContext(ctx)
3795-
queryAfterTableAltering2.conn = conn
3804+
queryAfterTableAltering2.SetHostID(hostid)
37963805
iter = queryAfterTableAltering2.Iter()
37973806
handleRows(iter)
37983807
err = iter.Close()
@@ -3809,7 +3818,7 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
38093818
// Executing prepared stmt and expecting that C* won't return
38103819
// Metadata_changed because the table is not being changed.
38113820
queryAfterTableAltering3 := session.Query(selectStmt).WithContext(ctx)
3812-
queryAfterTableAltering3.conn = conn
3821+
queryAfterTableAltering3.SetHostID(hostid)
38133822
iter = queryAfterTableAltering2.Iter()
38143823
handleRows(iter)
38153824

@@ -3913,7 +3922,10 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) {
39133922
getRoutingKeyInfo := func(key string) *routingKeyInfo {
39143923
t.Helper()
39153924
session.routingKeyInfoCache.mu.Lock()
3916-
value, _ := session.routingKeyInfoCache.lru.Get(key)
3925+
value, ok := session.routingKeyInfoCache.lru.Get(key)
3926+
if !ok {
3927+
t.Fatalf("routing key not found in cache for key %v", key)
3928+
}
39173929
session.routingKeyInfoCache.mu.Unlock()
39183930

39193931
inflight := value.(*inflightCachedEntry)
@@ -3923,20 +3935,22 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) {
39233935
const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)"
39243936

39253937
// Running batch in default ks gocql_test
3926-
b1 := session.NewBatch(LoggedBatch)
3938+
b1 := session.Batch(LoggedBatch)
39273939
b1.Query(insertQuery, 1)
3928-
_, err = b1.GetRoutingKey()
3940+
internalB := newInternalBatch(b1, nil)
3941+
_, err = internalB.GetRoutingKey()
39293942
require.NoError(t, err)
39303943

39313944
// Ensuring that the cache contains the query with default ks
39323945
routingKeyInfo1 := getRoutingKeyInfo("gocql_test" + b1.Entries[0].Stmt)
39333946
require.Equal(t, "gocql_test", routingKeyInfo1.keyspace)
39343947

39353948
// Running batch in gocql_test_routing_key_cache ks
3936-
b2 := session.NewBatch(LoggedBatch)
3949+
b2 := session.Batch(LoggedBatch)
39373950
b2.SetKeyspace("gocql_test_routing_key_cache")
39383951
b2.Query(insertQuery, 2)
3939-
_, err = b2.GetRoutingKey()
3952+
internalB2 := newInternalBatch(b2, nil)
3953+
_, err = internalB2.GetRoutingKey()
39403954
require.NoError(t, err)
39413955

39423956
// Ensuring that the cache contains the query with gocql_test_routing_key_cache ks
@@ -3947,15 +3961,18 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) {
39473961

39483962
// Running query in default ks gocql_test
39493963
q1 := session.Query(selectStmt, 1)
3950-
_, err = q1.GetRoutingKey()
3964+
iter := q1.Iter()
3965+
err = iter.Close()
39513966
require.NoError(t, err)
3952-
require.Equal(t, "gocql_test", q1.routingInfo.keyspace)
3967+
require.Equal(t, "gocql_test", iter.Keyspace())
39533968

39543969
// Running query in gocql_test_routing_key_cache ks
39553970
q2 := session.Query(selectStmt, 1)
3956-
_, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey()
3971+
q2.SetKeyspace("gocql_test_routing_key_cache")
3972+
iter = q2.Iter()
3973+
err = iter.Close()
39573974
require.NoError(t, err)
3958-
require.Equal(t, "gocql_test_routing_key_cache", q2.routingInfo.keyspace)
3975+
require.Equal(t, "gocql_test_routing_key_cache", iter.Keyspace())
39593976

39603977
session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec()
39613978
}

conn.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,7 +1495,7 @@ func (c *Conn) executeQuery(ctx context.Context, q *internalQuery) *Iter {
14951495
params := queryParams{
14961496
consistency: q.GetConsistency(),
14971497
}
1498-
iter := newIter(q.metrics, q.Keyspace(), q.Table())
1498+
iter := newIter(q.metrics, q.Keyspace(), q.routingInfo, q.qryOpts.getKeyspace)
14991499

15001500
// frame checks that it is not 0
15011501
params.serialConsistency = qryOpts.serialCons
@@ -1646,7 +1646,7 @@ func (c *Conn) executeQuery(ctx context.Context, q *internalQuery) *Iter {
16461646
iter.meta = info.response
16471647
iter.meta.pagingState = copyBytes(x.meta.pagingState)
16481648
} else {
1649-
iter = newErrIter(errors.New("gocql: did not receive metadata but prepared info is nil"), q.metrics, q.Keyspace(), q.Table())
1649+
iter = newErrIter(errors.New("gocql: did not receive metadata but prepared info is nil"), q.metrics, q.Keyspace(), q.routingInfo, q.qryOpts.getKeyspace)
16501650
iter.framer = framer
16511651
return iter
16521652
}
@@ -1749,10 +1749,10 @@ func (c *Conn) UseKeyspace(keyspace string) error {
17491749

17501750
func (c *Conn) executeBatch(ctx context.Context, b *internalBatch) *Iter {
17511751
if c.version == protoVersion1 {
1752-
return newErrIter(ErrUnsupported, b.metrics, b.Keyspace(), b.Table())
1752+
return newErrIter(ErrUnsupported, b.metrics, b.Keyspace(), b.routingInfo, nil)
17531753
}
17541754

1755-
iter := newIter(b.metrics, b.Keyspace(), b.Table())
1755+
iter := newIter(b.metrics, b.Keyspace(), b.routingInfo, nil)
17561756

17571757
n := len(b.batchOpts.entries)
17581758
req := &writeBatchFrame{

control.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
502502
return fn(ch)
503503
}
504504

505-
return newErrIter(errNoControl, newQueryMetrics(), "", "")
505+
return newErrIter(errNoControl, newQueryMetrics(), "", nil, nil)
506506
}
507507

508508
func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
@@ -537,7 +537,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
537537

538538
func (c *controlConn) awaitSchemaAgreement() error {
539539
return c.withConn(func(conn *Conn) *Iter {
540-
return newErrIter(conn.awaitSchemaAgreement(context.TODO()), newQueryMetrics(), "", "")
540+
return newErrIter(conn.awaitSchemaAgreement(context.TODO()), newQueryMetrics(), "", nil, nil)
541541
}).err
542542
}
543543

query_executor.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ type internalRequest interface {
6060
retryPolicy() RetryPolicy
6161
speculativeExecutionPolicy() SpeculativeExecutionPolicy
6262
getQueryMetrics() *queryMetrics
63+
getRoutingInfo() *queryRoutingInfo
64+
getKeyspaceFunc() func() string
6365
RetryableQuery
6466
ExecutableStatement
6567
}
@@ -89,7 +91,7 @@ func (q *queryExecutor) speculate(ctx context.Context, qry internalRequest, sp S
8991
case <-ticker.C:
9092
go q.run(ctx, qry, hostIter, results)
9193
case <-ctx.Done():
92-
return newErrIter(ctx.Err(), qry.getQueryMetrics(), qry.Keyspace(), qry.Table())
94+
return newErrIter(ctx.Err(), qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc())
9395
case iter := <-results:
9496
return iter
9597
}
@@ -157,7 +159,7 @@ func (q *queryExecutor) executeQuery(qry internalRequest) (*Iter, error) {
157159
case iter := <-results:
158160
return iter, nil
159161
case <-ctx.Done():
160-
return newErrIter(ctx.Err(), qry.getQueryMetrics(), qry.Keyspace(), qry.Table()), nil
162+
return newErrIter(ctx.Err(), qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()), nil
161163
}
162164
}
163165

@@ -224,7 +226,7 @@ func (q *queryExecutor) do(ctx context.Context, qry internalRequest, hostIter Ne
224226
stopRetries = true
225227
default:
226228
// Undefined? Return nil and error, this will panic in the requester
227-
return newErrIter(ErrUnknownRetryType, qry.getQueryMetrics(), qry.Keyspace(), qry.Table())
229+
return newErrIter(ErrUnknownRetryType, qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc())
228230
}
229231

230232
if stopRetries || attemptsReached {
@@ -236,10 +238,10 @@ func (q *queryExecutor) do(ctx context.Context, qry internalRequest, hostIter Ne
236238
}
237239

238240
if lastErr != nil {
239-
return newErrIter(lastErr, qry.getQueryMetrics(), qry.Keyspace(), qry.Table())
241+
return newErrIter(lastErr, qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc())
240242
}
241243

242-
return newErrIter(ErrNoConnections, qry.getQueryMetrics(), qry.Keyspace(), qry.Table())
244+
return newErrIter(ErrNoConnections, qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc())
243245
}
244246

245247
func (q *queryExecutor) run(ctx context.Context, qry internalRequest, hostIter NextHost, results chan<- *Iter) {
@@ -480,6 +482,14 @@ func (q *internalQuery) GetHostID() string {
480482
return q.qryOpts.hostID
481483
}
482484

485+
func (q *internalQuery) getRoutingInfo() *queryRoutingInfo {
486+
return q.routingInfo
487+
}
488+
489+
func (q *internalQuery) getKeyspaceFunc() func() string {
490+
return q.qryOpts.getKeyspace
491+
}
492+
483493
type batchOptions struct {
484494
trace Tracer
485495
observer BatchObserver
@@ -664,6 +674,14 @@ func (b *internalBatch) GetHostID() string {
664674
return ""
665675
}
666676

677+
func (b *internalBatch) getRoutingInfo() *queryRoutingInfo {
678+
return b.routingInfo
679+
}
680+
681+
func (b *internalBatch) getKeyspaceFunc() func() string {
682+
return nil
683+
}
684+
667685
func (b *internalBatch) execute(ctx context.Context, conn *Conn) *Iter {
668686
return conn.executeBatch(ctx, b)
669687
}

session.go

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ func (s *Session) AwaitSchemaAgreement(ctx context.Context) error {
374374
return errNoControl
375375
}
376376
return s.control.withConn(func(conn *Conn) *Iter {
377-
return newErrIter(conn.awaitSchemaAgreement(ctx), newQueryMetrics(), "", "")
377+
return newErrIter(conn.awaitSchemaAgreement(ctx), newQueryMetrics(), "", nil, nil)
378378
}).err
379379
}
380380

@@ -538,12 +538,12 @@ func (s *Session) initialized() bool {
538538
func (s *Session) executeQuery(qry *internalQuery) (it *Iter) {
539539
// fail fast
540540
if s.Closed() {
541-
return newErrIter(ErrSessionClosed, qry.metrics, qry.Keyspace(), qry.Table())
541+
return newErrIter(ErrSessionClosed, qry.metrics, qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc())
542542
}
543543

544544
iter, err := s.executor.executeQuery(qry)
545545
if err != nil {
546-
return newErrIter(err, qry.metrics, qry.Keyspace(), qry.Table())
546+
return newErrIter(err, qry.metrics, qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc())
547547
}
548548
if iter == nil {
549549
panic("nil iter")
@@ -764,19 +764,19 @@ func (s *Session) executeBatch(batch *Batch, ctx context.Context) *Iter {
764764
b := newInternalBatch(batch, ctx)
765765
// fail fast
766766
if s.Closed() {
767-
return newErrIter(ErrSessionClosed, b.metrics, b.Keyspace(), b.Table())
767+
return newErrIter(ErrSessionClosed, b.metrics, b.Keyspace(), b.getRoutingInfo(), b.getKeyspaceFunc())
768768
}
769769

770770
// Prevent the execution of the batch if greater than the limit
771771
// Currently batches have a limit of 65536 queries.
772772
// https://datastax-oss.atlassian.net/browse/JAVA-229
773773
if batch.Size() > BatchSizeMaximum {
774-
return newErrIter(ErrTooManyStmts, b.metrics, b.Keyspace(), b.Table())
774+
return newErrIter(ErrTooManyStmts, b.metrics, b.Keyspace(), b.getRoutingInfo(), b.getKeyspaceFunc())
775775
}
776776

777777
iter, err := s.executor.executeQuery(b)
778778
if err != nil {
779-
return newErrIter(err, b.metrics, b.Keyspace(), b.Table())
779+
return newErrIter(err, b.metrics, b.Keyspace(), b.getRoutingInfo(), b.getKeyspaceFunc())
780780
}
781781

782782
return iter
@@ -1282,7 +1282,7 @@ func (q *Query) Iter() *Iter {
12821282
// over all results.
12831283
func (q *Query) IterContext(ctx context.Context) *Iter {
12841284
if isUseStatement(q.stmt) {
1285-
return newErrIter(ErrUseStmt, newQueryMetrics(), q.Keyspace(), "")
1285+
return newErrIter(ErrUseStmt, newQueryMetrics(), q.Keyspace(), nil, q.getKeyspace)
12861286
}
12871287

12881288
internalQry := newInternalQuery(q, ctx)
@@ -1418,21 +1418,22 @@ type Iter struct {
14181418
host *HostInfo
14191419
metrics *queryMetrics
14201420

1421-
keyspace string
1422-
table string
1421+
getKeyspace func() string
1422+
keyspace string
1423+
routingInfo *queryRoutingInfo
14231424

14241425
framer *framer
14251426
closed int32
14261427
}
14271428

1428-
func newErrIter(err error, metrics *queryMetrics, keyspace string, table string) *Iter {
1429-
iter := newIter(metrics, keyspace, table)
1429+
func newErrIter(err error, metrics *queryMetrics, keyspace string, routingInfo *queryRoutingInfo, getKeyspace func() string) *Iter {
1430+
iter := newIter(metrics, keyspace, routingInfo, getKeyspace)
14301431
iter.err = err
14311432
return iter
14321433
}
14331434

1434-
func newIter(metrics *queryMetrics, keyspace string, table string) *Iter {
1435-
return &Iter{metrics: metrics, keyspace: keyspace, table: table}
1435+
func newIter(metrics *queryMetrics, keyspace string, routingInfo *queryRoutingInfo, getKeyspace func() string) *Iter {
1436+
return &Iter{metrics: metrics, keyspace: keyspace, routingInfo: routingInfo, getKeyspace: getKeyspace}
14361437
}
14371438

14381439
// Host returns the host which the statement was sent to.
@@ -1456,11 +1457,26 @@ func (iter *Iter) Latency() int64 {
14561457
}
14571458

14581459
// Keyspace returns the keyspace the statement was executed against if the driver could determine it.
1459-
func (iter *Iter) Keyspace() string { return iter.keyspace }
1460+
func (iter *Iter) Keyspace() string {
1461+
if iter.getKeyspace != nil {
1462+
return iter.getKeyspace()
1463+
}
1464+
1465+
if iter.routingInfo != nil {
1466+
if ks := iter.routingInfo.getKeyspace(); ks != "" {
1467+
return ks
1468+
}
1469+
}
1470+
1471+
return iter.keyspace
1472+
}
14601473

14611474
// Table returns name of the table the statement was executed against if the driver could determine it.
14621475
func (iter *Iter) Table() string {
1463-
return iter.table
1476+
if iter.routingInfo != nil {
1477+
return iter.routingInfo.getTable()
1478+
}
1479+
return ""
14641480
}
14651481

14661482
type Scanner interface {

0 commit comments

Comments
 (0)