From 7d3723acd65f9e239ed326b79b1701baa4906144 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Thu, 14 Nov 2024 13:47:27 +0530 Subject: [PATCH] wip --- .../vtgate/planbuilder/testdata/onecase.json | 23 ++++- go/vt/vtgate/semantics/analyzer.go | 10 +++ go/vt/vtgate/semantics/cte_table.go | 5 ++ go/vt/vtgate/semantics/derived_table.go | 85 +++++++++++++++++-- go/vt/vtgate/semantics/real_table.go | 10 +++ go/vt/vtgate/semantics/semantic_table.go | 34 ++++++++ go/vt/vtgate/semantics/table_collector.go | 7 ++ go/vt/vtgate/semantics/vindex_table.go | 5 ++ go/vt/vtgate/semantics/vtable.go | 5 ++ 9 files changed, 177 insertions(+), 7 deletions(-) diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index 9d653b2f6e9..97947b846a7 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,8 +1,29 @@ [ { "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "", + "query": "SELECT music_id FROM music_extra WHERE music_id IN (select * from (select music.id from music where music.user_id = 1234 AND music.foo = 'bar' union select music.id from music where music.user_id = 1234 AND music.foo = 'baz') as subquery)", "plan": { + "QueryType": "SELECT", + "Original": "SELECT music_id FROM music_extra WHERE music_id IN (select music.id from music where music.user_id = 1234 AND music.foo = 'bar' union select music.id from music where music.user_id = 1234 AND music.foo = 'baz')", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music_id from music_extra where 1 != 1", + "Query": "select music_id from music_extra where music_id in (select music.id from music where music.user_id = 1234 and music.foo = 'bar' union select music.id from music where music.user_id = 1234 and music.foo = 'baz')", + "Table": "music_extra", + "Values": [ + "1234" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "user.music", + "user.music_extra" + ] } } ] \ No newline at end of file diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 0a9d2480d9b..0e2da4e01ba 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -343,6 +343,7 @@ type originable interface { tableSetFor(t *sqlparser.AliasedTableExpr) TableSet depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, typ evalengine.Type) collationEnv() *collations.Environment + tableInfoFor(ts TableSet) (TableInfo, error) } func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, typ evalengine.Type) { @@ -352,6 +353,15 @@ func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, return } +// tableInfoFor returns the table info for the table set. It should contains only single table. +func (a *analyzer) tableInfoFor(id TableSet) (TableInfo, error) { + offset := id.TableOffset() + if offset < 0 { + return nil, ErrNotSingleTable + } + return a.tables.Tables[offset], nil +} + func (a *analyzer) collationEnv() *collations.Environment { return a.typer.collationEnv } diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go index 498fc5076c1..aaa5262524e 100644 --- a/go/vt/vtgate/semantics/cte_table.go +++ b/go/vt/vtgate/semantics/cte_table.go @@ -71,6 +71,11 @@ func (cte *CTETable) GetVindexTable() *vindexes.Table { return nil } +// getColumnVindexForColumn implements the TableInfo interface +func (cte *CTETable) getColumnVindexForColumn(org originable, columnName string) (vList []*vindexes.ColumnVindex, err error) { + return nil, nil +} + func (cte *CTETable) IsInfSchema() bool { return false } diff --git a/go/vt/vtgate/semantics/derived_table.go b/go/vt/vtgate/semantics/derived_table.go index fc7e1cb391c..0e49136001b 100644 --- a/go/vt/vtgate/semantics/derived_table.go +++ b/go/vt/vtgate/semantics/derived_table.go @@ -31,21 +31,78 @@ type DerivedTable struct { tableName string ASTNode *sqlparser.AliasedTableExpr columnNames []string - cols []sqlparser.Expr tables TableSet isAuthoritative bool + selectExprColumns +} + +type selectColumns struct { + cols []sqlparser.Expr recursive []TableSet types []evalengine.Type } -type unionInfo struct { - isAuthoritative bool - recursive []TableSet - types []evalengine.Type - exprs sqlparser.SelectExprs +func (s *selectColumns) getVindexForIndex(org originable, idx int) ([]*vindexes.ColumnVindex, error) { + expr := s.cols[idx] + colName, isCol := expr.(*sqlparser.ColName) + if !isCol { + return nil, nil + } + direct, _, _ := org.depsForExpr(colName) + ti, err := org.tableInfoFor(direct) + if err != nil { + return nil, err + } + vindexTable := ti.GetVindexTable() + var result []*vindexes.ColumnVindex + for _, vindex := range vindexTable.ColumnVindexes { + if vindex.Columns[0].Equal(colName.Name) { + result = append(result, vindex) + } + } + return result, nil } +func intersect(a, b []*vindexes.ColumnVindex) []*vindexes.ColumnVindex { + var result []*vindexes.ColumnVindex + for _, vindex := range a { + for _, vindex2 := range b { + if vindex == vindex2 { + result = append(result, vindex) + } + } + } + return result +} + +func (s *unionColumns) getVindexForIndex(org originable, offset int) ([]*vindexes.ColumnVindex, error) { + var vlist []*vindexes.ColumnVindex + for idx, columns := range s.sc { + these, err := columns.getVindexForIndex(org, offset) + if err != nil { + return nil, err + } + if idx == 0 { + vlist = these + continue + } + vlist = intersect(vlist, these) + } + return vlist, nil +} + +type unionColumns struct { + sc []selectColumns +} + +type selectExprColumns interface { + getVindexForIndex(org originable, idx int) ([]*vindexes.ColumnVindex, error) +} + +var _ selectExprColumns = (*selectColumns)(nil) +var _ selectExprColumns = (*unionColumns)(nil) + var _ TableInfo = (*DerivedTable)(nil) func createDerivedTableForExpressions( @@ -154,6 +211,22 @@ func (dt *DerivedTable) GetVindexTable() *vindexes.Table { return nil } +// getColumnVindexForColumn implements the TableInfo interface +func (dt *DerivedTable) getColumnVindexForColumn(org originable, columnName string) (vList []*vindexes.ColumnVindex, err error) { + offset := -1 + for idx, name := range dt.columnNames { + if name == columnName { + offset = idx + break + } + } + if offset == -1 { + return nil, nil + } + return dt.selectExprColumns.getVindexForIndex(org, offset) + +} + func (dt *DerivedTable) getColumns(bool) []ColumnInfo { cols := make([]ColumnInfo, 0, len(dt.columnNames)) for _, col := range dt.columnNames { diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index 64f3ac5f3f0..ecab78d0ea5 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -211,6 +211,16 @@ func (r *RealTable) GetVindexTable() *vindexes.Table { return r.Table } +// getColumnVindexForColumn implements the TableInfo interface +func (r *RealTable) getColumnVindexForColumn(org originable, columnName string) (vList []*vindexes.ColumnVindex, err error) { + for _, vindex := range r.Table.ColumnVindexes { + if len(vindex.Columns) == 1 && vindex.Columns[0].EqualString(columnName) { + vList = append(vList, vindex) + } + } + return vList, nil +} + // GetVindexHint implements the TableInfo interface func (r *RealTable) GetVindexHint() *sqlparser.IndexHint { return r.VindexHint diff --git a/go/vt/vtgate/semantics/semantic_table.go b/go/vt/vtgate/semantics/semantic_table.go index f9856a901a6..b51b702fc46 100644 --- a/go/vt/vtgate/semantics/semantic_table.go +++ b/go/vt/vtgate/semantics/semantic_table.go @@ -64,6 +64,9 @@ type ( getExprFor(s string) (sqlparser.Expr, error) getTableSet(org originable) TableSet + // getColumnVindexForColumn gets the vindex for the given column if available. + getColumnVindexForColumn(org originable, columnName string) ([]*vindexes.ColumnVindex, error) + // GetMirrorRule returns the vschema mirror rule for this TableInfo GetMirrorRule() *vindexes.MirrorRule } @@ -200,6 +203,37 @@ func (st *SemTable) CopyDependencies(from, to sqlparser.Expr) { } } +var _ originable = (*SemTable)(nil) + +func (st *SemTable) tableSetFor(t *sqlparser.AliasedTableExpr) TableSet { + return st.TableSetFor(t) +} + +func (st *SemTable) depsForExpr(expr sqlparser.Expr) (TableSet, TableSet, evalengine.Type) { + typ, _ := st.TypeForExpr(expr) + return st.DirectDeps(expr), st.RecursiveDeps(expr), typ +} + +func (st *SemTable) collationEnv() *collations.Environment { + return st.collEnv +} + +func (st *SemTable) tableInfoFor(ts TableSet) (TableInfo, error) { + return st.TableInfoFor(ts) +} + +func (st *SemTable) GetVindexesForExpr(expr sqlparser.Expr) ([]*vindexes.ColumnVindex, error) { + colName, isCol := expr.(*sqlparser.ColName) + if !isCol { + return nil, nil + } + ti, err := st.TableInfoForExpr(colName) + if err != nil { + return nil, err + } + return ti.getColumnVindexForColumn(st, colName.Name.String()) +} + // GetChildForeignKeysForTargets gets the child foreign keys as a list for all the target tables. func (st *SemTable) GetChildForeignKeysForTargets() (fks []vindexes.ChildFKInfo) { for _, ts := range st.Targets.Constituents() { diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 191d9c3b38e..286124df681 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -47,6 +47,13 @@ type ( // cte is a map of CTE definitions that are used in the query cte map[string]*CTE } + + unionInfo struct { + isAuthoritative bool + recursive []TableSet + types []evalengine.Type + exprs sqlparser.SelectExprs + } ) func newEarlyTableCollector(si SchemaInformation, currentDb string) *earlyTableCollector { diff --git a/go/vt/vtgate/semantics/vindex_table.go b/go/vt/vtgate/semantics/vindex_table.go index c8ef271af5d..2fedddd010f 100644 --- a/go/vt/vtgate/semantics/vindex_table.go +++ b/go/vt/vtgate/semantics/vindex_table.go @@ -51,6 +51,11 @@ func (v *VindexTable) GetVindexTable() *vindexes.Table { return v.Table.GetVindexTable() } +// getColumnVindexForColumn implements the TableInfo interface +func (v *VindexTable) getColumnVindexForColumn(org originable, columnName string) (vList []*vindexes.ColumnVindex, err error) { + return nil, nil +} + // Matches implements the TableInfo interface func (v *VindexTable) matches(name sqlparser.TableName) bool { return v.Table.matches(name) diff --git a/go/vt/vtgate/semantics/vtable.go b/go/vt/vtgate/semantics/vtable.go index 6cd7e34aecc..2886de46faf 100644 --- a/go/vt/vtgate/semantics/vtable.go +++ b/go/vt/vtgate/semantics/vtable.go @@ -104,6 +104,11 @@ func (v *vTableInfo) GetVindexTable() *vindexes.Table { return nil } +// getColumnVindexForColumn implements the TableInfo interface +func (v *vTableInfo) getColumnVindexForColumn(org originable, columnName string) (vList []*vindexes.ColumnVindex, err error) { + return nil, nil +} + func (v *vTableInfo) getColumns(bool) []ColumnInfo { cols := make([]ColumnInfo, 0, len(v.columnNames)) for _, col := range v.columnNames {