Skip to content

Commit

Permalink
Merge pull request #14 from supabase/da/avoid-state-mutation
Browse files Browse the repository at this point in the history
feat: avoid mutating shared state
  • Loading branch information
yusufpapurcu authored Jul 30, 2021
2 parents 37aa091 + 9028045 commit f432b53
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 95 deletions.
19 changes: 7 additions & 12 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ func NewClient(rawURL, schema string, headers map[string]string) *Client {
}

t := transport{
params: url.Values{},
header: http.Header{},
baseURL: *baseURL,
}
Expand Down Expand Up @@ -62,12 +61,10 @@ func (c *Client) ChangeSchema(schema string) *Client {
}

func (c *Client) From(table string) *QueryBuilder {
c.clientTransport.baseURL.Path += table
return &QueryBuilder{client: c}
return &QueryBuilder{client: c, tableName: table, headers: map[string]string{}, params: map[string]string{}}
}

func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {

// Get body if exist
var byteBody []byte = nil
if rpcBody != nil {
Expand All @@ -87,11 +84,7 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {
}

if count != "" && (count == `exact` || count == `planned` || count == `estimated`) {
if c.clientTransport.header.Get("Prefer") == "" {
c.clientTransport.header.Set("Prefer", "count="+count)
} else {
c.clientTransport.header.Set("Prefer", c.clientTransport.header.Get("Prefer")+",count="+count)
}
req.Header.Add("Prefer", "count="+count)
}

resp, err := c.session.Do(req)
Expand All @@ -118,14 +111,16 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {
}

type transport struct {
params url.Values
header http.Header
baseURL url.URL
}

func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header = t.header
for headerName, values := range t.header {
for _, val := range values {
req.Header.Add(headerName, val)
}
}
req.URL = t.baseURL.ResolveReference(req.URL)
req.URL.RawQuery = t.params.Encode()
return http.DefaultTransport.RoundTrip(req)
}
26 changes: 18 additions & 8 deletions execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"path"
)

// ExecuteError is the error response format from postgrest. We really
Expand All @@ -18,17 +19,26 @@ type ExecuteError struct {
Message string `json:"message"`
}

func executeHelper(client *Client, method string, body []byte) ([]byte, error) {
func executeHelper(client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string) ([]byte, error) {
if client.ClientError != nil {
return nil, client.ClientError
}

readerBody := bytes.NewBuffer(body)
req, err := http.NewRequest(method, client.clientTransport.baseURL.Path, readerBody)
baseUrl := path.Join(append([]string{client.clientTransport.baseURL.Path}, urlFragments...)...)
req, err := http.NewRequest(method, baseUrl, readerBody)
if err != nil {
return nil, err
}

for key, val := range headers {
req.Header.Add(key, val)
}
q := req.URL.Query()
for key, val := range params {
q.Add(key, val)
}
req.URL.RawQuery = q.Encode()
resp, err := client.session.Do(req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -56,17 +66,17 @@ func executeHelper(client *Client, method string, body []byte) ([]byte, error) {
return respbody, nil
}

func executeString(client *Client, method string, body []byte) (string, error) {
resp, err := executeHelper(client, method, body)
func executeString(client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string) (string, error) {
resp, err := executeHelper(client, method, body, urlFragments, headers, params)
return string(resp), err
}

func execute(client *Client, method string, body []byte) ([]byte, error) {
return executeHelper(client, method, body)
func execute(client *Client, method string, body []byte, urlFragments []string, headers map[string]string, params map[string]string) ([]byte, error) {
return executeHelper(client, method, body, urlFragments, headers, params)
}

func executeTo(client *Client, method string, body []byte, to interface{}) error {
resp, err := executeHelper(client, method, body)
func executeTo(client *Client, method string, body []byte, to interface{}, urlFragments []string, headers map[string]string, params map[string]string) error {
resp, err := executeHelper(client, method, body, urlFragments, headers, params)

if err != nil {
return err
Expand Down
67 changes: 35 additions & 32 deletions filterbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,24 @@ import (
)

type FilterBuilder struct {
client *Client
method string
body []byte
client *Client
method string
body []byte
tableName string
headers map[string]string
params map[string]string
}

func (f *FilterBuilder) ExecuteString() (string, error) {
return executeString(f.client, f.method, f.body)
return executeString(f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params)
}

func (f *FilterBuilder) Execute() ([]byte, error) {
return execute(f.client, f.method, f.body)
return execute(f.client, f.method, f.body, []string{f.tableName}, f.headers, f.params)
}

func (f *FilterBuilder) ExecuteTo(to interface{}) error {
return executeTo(f.client, f.method, f.body, to)
return executeTo(f.client, f.method, f.body, to, []string{f.tableName}, f.headers, f.params)
}

var filterOperators = []string{"eq", "neq", "gt", "gte", "lt", "lte", "like", "ilike", "is", "in", "cs", "cd", "sl", "sr", "nxl", "nxr", "adj", "ov", "fts", "plfts", "phfts", "wfts"}
Expand All @@ -41,15 +44,15 @@ func (f *FilterBuilder) Filter(column, operator, value string) *FilterBuilder {
f.client.ClientError = fmt.Errorf("invalid filter operator")
return f
}
f.client.clientTransport.params.Add(column, operator+"."+value)
f.params[column] = fmt.Sprintf("%s.%s", operator, value)
return f
}

func (f *FilterBuilder) Or(filters, foreignTable string) *FilterBuilder {
if foreignTable != "" {
f.client.clientTransport.params.Add(foreignTable+".or", fmt.Sprintf("(%s)", filters))
f.params[foreignTable+".or"] = fmt.Sprintf("(%s)", filters)
} else {
f.client.clientTransport.params.Add("or", fmt.Sprintf("(%s)", filters))
f.params[foreignTable+"or"] = fmt.Sprintf("(%s)", filters)
}
return f
}
Expand All @@ -58,59 +61,59 @@ func (f *FilterBuilder) Not(column, operator, value string) *FilterBuilder {
if !isOperator(operator) {
return f
}
f.client.clientTransport.params.Add(column, "not."+operator+"."+value)
f.params[column] = fmt.Sprintf("not.%s.%s", operator, value)
return f
}

func (f *FilterBuilder) Match(userQuery map[string]string) *FilterBuilder {
for key, value := range userQuery {
f.client.clientTransport.params.Add(key, "eq."+value)
f.params[key] = "eq." + value
}
return f
}

func (f *FilterBuilder) Eq(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "eq."+value)
f.params[column] = "eq." + value
return f
}

func (f *FilterBuilder) Neq(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "neq."+value)
f.params[column] = "neq." + value
return f
}

func (f *FilterBuilder) Gt(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "gt."+value)
f.params[column] = "gt." + value
return f
}

func (f *FilterBuilder) Gte(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "gte."+value)
f.params[column] = "gte." + value
return f
}

func (f *FilterBuilder) Lt(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "lt."+value)
f.params[column] = "lt." + value
return f
}

func (f *FilterBuilder) Lte(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "lte."+value)
f.params[column] = "lte." + value
return f
}

func (f *FilterBuilder) Like(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "like."+value)
f.params[column] = "like." + value
return f
}

func (f *FilterBuilder) Ilike(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "ilike."+value)
f.params[column] = "ilike." + value
return f
}

func (f *FilterBuilder) Is(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "is."+value)
f.params[column] = "is." + value
return f
}

Expand All @@ -125,17 +128,17 @@ func (f *FilterBuilder) In(column string, values []string) *FilterBuilder {
cleanedValues = append(cleanedValues, value)
}
}
f.client.clientTransport.params.Add(column, fmt.Sprintf("in.(%s)", strings.Join(cleanedValues, ",")))
f.params[column] = fmt.Sprintf("in.(%s)", strings.Join(cleanedValues, ","))
return f
}

func (f *FilterBuilder) Contains(column string, value []string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "cs."+strings.Join(value, ","))
f.params[column] = "cs." + strings.Join(value, ",")
return f
}

func (f *FilterBuilder) ContainedBy(column string, value []string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "cd."+strings.Join(value, ","))
f.params[column] = "cd." + strings.Join(value, ",")
return f
}

Expand All @@ -144,7 +147,7 @@ func (f *FilterBuilder) ContainsObject(column string, value interface{}) *Filter
if err != nil {
f.client.ClientError = err
}
f.client.clientTransport.params.Add(column, "cs."+string(sum))
f.params[column] = "cs." + string(sum)
return f
}

Expand All @@ -153,37 +156,37 @@ func (f *FilterBuilder) ContainedByObject(column string, value interface{}) *Fil
if err != nil {
f.client.ClientError = err
}
f.client.clientTransport.params.Add(column, "cs."+string(sum))
f.params[column] = "cs." + string(sum)
return f
}

func (f *FilterBuilder) RangeLt(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "sl."+value)
f.params[column] = "sl." + value
return f
}

func (f *FilterBuilder) RangeGt(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "sr."+value)
f.params[column] = "sr." + value
return f
}

func (f *FilterBuilder) RangeGte(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "nxl."+value)
f.params[column] = "nxl." + value
return f
}

func (f *FilterBuilder) RangeLte(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "nxr."+value)
f.params[column] = "nxr." + value
return f
}

func (f *FilterBuilder) RangeAdjacent(column, value string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "adj."+value)
f.params[column] = "adj." + value
return f
}

func (f *FilterBuilder) Overlaps(column string, value []string) *FilterBuilder {
f.client.clientTransport.params.Add(column, "ov."+strings.Join(value, ","))
f.params[column] = "ov." + strings.Join(value, ",")
return f
}

Expand All @@ -204,6 +207,6 @@ func (f *FilterBuilder) TextSearch(column, userQuery, config, tsType string) *Fi
if config != "" {
configPart = fmt.Sprintf("(%s)", config)
}
f.client.clientTransport.params.Add(column, typePart+"fts"+configPart+"."+userQuery)
f.params[column] = typePart + "fts" + configPart + "." + userQuery
return f
}
Loading

0 comments on commit f432b53

Please sign in to comment.