Skip to content

Commit

Permalink
implement driver for database/sql
Browse files Browse the repository at this point in the history
  • Loading branch information
sleepymole committed Mar 26, 2023
1 parent 23959a5 commit 29aa018
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 33 deletions.
69 changes: 38 additions & 31 deletions cmd/zgraph/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package main

import (
"context"
"database/sql"
"errors"
"fmt"
"io"
Expand All @@ -26,8 +27,7 @@ import (
"github.com/jedib0t/go-pretty/v6/text"
"github.com/knz/bubbline"
"github.com/spf13/cobra"
"github.com/vescale/zgraph"
"github.com/vescale/zgraph/session"
_ "github.com/vescale/zgraph"
)

type options struct {
Expand All @@ -39,15 +39,19 @@ func main() {
cmd := cobra.Command{
Use: "zgraph --data <dirname>",
RunE: func(cmd *cobra.Command, args []string) error {
db, err := zgraph.Open(opt.dataDir, nil)
db, err := sql.Open("zgraph", opt.dataDir)
if err != nil {
return err
}
defer db.Close()

session := db.NewSession()
defer session.Close()
conn, err := db.Conn(context.Background())
if err != nil {
return err
}
defer conn.Close()

interact(session)
interact(conn)

return nil
},
Expand All @@ -60,7 +64,7 @@ func main() {
cobra.CheckErr(err)
}

func interact(session *session.Session) {
func interact(conn *sql.Conn) {
fmt.Println("Welcome to zGraph interactive command line.")

m := bubbline.New()
Expand Down Expand Up @@ -103,19 +107,21 @@ func interact(session *session.Session) {
if stmt == "" {
continue
}
runQuery(session, stmt)
runQuery(conn, stmt)
}
lastStmt = stmts[len(stmts)-1]
}
}

func runQuery(session *session.Session, query string) {
rs, err := session.Execute(context.Background(), query)
func runQuery(conn *sql.Conn, query string) {
rows, err := conn.QueryContext(context.Background(), query)
if err != nil {
outputError(err)
return
}
output, err := render(rs)
defer rows.Close()

output, err := render(rows)
if err != nil {
outputError(err)
return
Expand All @@ -125,46 +131,47 @@ func runQuery(session *session.Session, query string) {
}
}

func render(rs session.ResultSet) (string, error) {
defer rs.Close()
func render(rows *sql.Rows) (string, error) {
w := table.NewWriter()
w.Style().Format = table.FormatOptions{
Footer: text.FormatDefault,
Header: text.FormatDefault,
Row: text.FormatDefault,
}

if len(rs.Fields()) > 0 {
cols, err := rows.Columns()
if err != nil {
return "", err
}

if len(cols) > 0 {
var header []any
for _, field := range rs.Fields() {
header = append(header, field.Name)
for _, col := range cols {
header = append(header, col)
}
w.AppendHeader(header)
}

fields := make([]any, 0, len(rs.Fields()))
for range rs.Fields() {
var s string
fields = append(fields, &s)
dest := make([]any, len(cols))
for i := range cols {
var anyStr sql.NullString
dest[i] = &anyStr
}

for {
if err := rs.Next(context.Background()); err != nil {
return "", err
}
if !rs.Valid() {
break
}
if err := rs.Scan(fields...); err != nil {
for rows.Next() {
if err := rows.Scan(dest...); err != nil {
return "", err
}
var row []any
for _, f := range fields {
row = append(row, *f.(*string))
for _, d := range dest {
anyStr := *d.(*sql.NullString)
row = append(row, anyStr.String)
}
w.AppendRow(row)
}

if rows.Err() != nil {
return "", rows.Err()
}
return w.Render(), nil
}

Expand Down
4 changes: 2 additions & 2 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (tk *TestKit) MustExec(ctx context.Context, query string) {
require.NoError(tk.t, rs.Next(ctx))
}

func TestDB_DDL(t *testing.T) {
func TestDBDDL(t *testing.T) {
db, err := Open(t.TempDir(), nil)
require.NoError(t, err)
require.NotNil(t, db)
Expand Down Expand Up @@ -86,7 +86,7 @@ func TestDB_DDL(t *testing.T) {
require.Nil(t, catalog.Graph("graph101"))
}

func TestDB_Select(t *testing.T) {
func TestDBSelect(t *testing.T) {
db, err := Open(t.TempDir(), nil)
require.NoError(t, err)
require.NotNil(t, db)
Expand Down
186 changes: 186 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Copyright 2023 zGraph Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package zgraph

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"

"github.com/vescale/zgraph/session"
)

const driverName = "zgraph"

func init() {
sql.Register(driverName, &Driver{})
}

var (
_ driver.Driver = &Driver{}
_ driver.DriverContext = &Driver{}
_ driver.Connector = &connector{}
_ io.Closer = &connector{}
_ driver.Conn = &conn{}
_ driver.Stmt = &stmt{}
_ driver.StmtExecContext = &stmt{}
_ driver.StmtQueryContext = &stmt{}
_ driver.Rows = &rows{}
)

type Driver struct{}

func (d *Driver) Open(_ string) (driver.Conn, error) {
return nil, errors.New("Driver.Open should not be called as Driver.OpenConnector is implemented")
}

func (d *Driver) OpenConnector(dsn string) (driver.Connector, error) {
db, err := Open(dsn, nil)
if err != nil {
return nil, err
}
return &connector{db: db}, nil
}

type connector struct {
db *DB
}

func (c *connector) Connect(_ context.Context) (driver.Conn, error) {
return &conn{session: c.db.NewSession()}, nil
}

func (c *connector) Driver() driver.Driver {
return &Driver{}
}

func (c *connector) Close() error {
return c.db.Close()
}

type conn struct {
session *session.Session
}

func (c *conn) Ping(_ context.Context) error {
return nil
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}

func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
return &stmt{session: c.session, query: query}, nil
}

func (c *conn) Close() error {
c.session.Close()
return nil
}

func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}

func (c *conn) BeginTx(_ context.Context, _ driver.TxOptions) (driver.Tx, error) {
return nil, errors.New("transactions are not supported")
}

type stmt struct {
session *session.Session
query string
}

func (s *stmt) Close() error {
return nil
}

func (s *stmt) NumInput() int {
return -1
}

func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
if len(args) > 0 {
return nil, fmt.Errorf("placeholder arguments not supported")
}
return s.ExecContext(context.Background(), nil)
}

func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if len(args) > 0 {
return nil, fmt.Errorf("placeholder arguments not supported")
}
rs, err := s.session.Execute(ctx, s.query)
if err != nil {
return nil, err
}
if err := rs.Next(ctx); err != nil {
return nil, err
}
return driver.ResultNoRows, nil
}

func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
if len(args) > 0 {
return nil, fmt.Errorf("placeholder arguments not supported")
}
return s.QueryContext(context.Background(), nil)
}

func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if len(args) > 0 {
return nil, fmt.Errorf("placeholder arguments not supported")
}
rs, err := s.session.Execute(ctx, s.query)
if err != nil {
return nil, err
}
return &rows{ctx: ctx, rs: rs}, nil
}

type rows struct {
ctx context.Context
rs session.ResultSet
}

func (r *rows) Columns() []string {
cols := make([]string, 0, len(r.rs.Fields()))
for _, f := range r.rs.Fields() {
cols = append(cols, f.Name)
}
return cols
}

func (r *rows) Close() error {
return r.rs.Close()
}

func (r *rows) Next(dest []driver.Value) error {
if err := r.rs.Next(r.ctx); err != nil {
return err
}
if !r.rs.Valid() {
return io.EOF
}
var destPtrs []any
for i := range dest {
destPtrs = append(destPtrs, &dest[i])
}
return r.rs.Scan(destPtrs...)
}
46 changes: 46 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2023 zGraph Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package zgraph_test

import (
"context"
"database/sql"
"testing"

"github.com/samber/lo"
"github.com/stretchr/testify/require"
)

func TestDriver(t *testing.T) {
db, err := sql.Open("zgraph", t.TempDir())
require.NoError(t, err)
defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

conn := lo.Must1(db.Conn(ctx))
_ = lo.Must1(conn.ExecContext(ctx, "CREATE GRAPH g"))
_ = lo.Must1(conn.ExecContext(ctx, "USE g"))
_ = lo.Must1(conn.ExecContext(ctx, "INSERT VERTEX x PROPERTIES (x.a = 123)"))
rows := lo.Must1(conn.QueryContext(ctx, "SELECT x.a FROM MATCH (x)"))

var a int
require.True(t, rows.Next())
require.NoError(t, rows.Scan(&a))
require.Equal(t, 123, a)
require.False(t, rows.Next())
require.NoError(t, rows.Err())
}
Loading

0 comments on commit 29aa018

Please sign in to comment.