Skip to content

fix: automatically set default schema env var for type generation #3243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cmd/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,15 @@ var (
return err
}
}
return types.Run(ctx, flags.ProjectRef, flags.DbConfig, lang.Value, schema, postgrestV9Compat, swiftAccessControl.Value, afero.NewOsFs())
return types.Run(ctx, flags.ProjectRef, flags.DbConfig, lang.Value, schema, setDefault, postgrestV9Compat, swiftAccessControl.Value, afero.NewOsFs())
},
Example: ` supabase gen types --local
supabase gen types --linked --lang=go
supabase gen types --project-id abc-def-123 --schema public --schema private
supabase gen types --db-url 'postgresql://...' --schema public --schema auth`,
}

setDefault bool
)

func init() {
Expand All @@ -106,6 +108,7 @@ func init() {
typeFlags.StringSliceVarP(&schema, "schema", "s", []string{}, "Comma separated list of schema to include.")
typeFlags.Var(&swiftAccessControl, "swift-access-control", "Access control for Swift generated types.")
typeFlags.BoolVar(&postgrestV9Compat, "postgrest-v9-compat", false, "Generate types compatible with PostgREST v9 and below. Only use together with --db-url.")
typeFlags.BoolVar(&setDefault, "set-default", false, "Set the specified schema as the default for helper types when using a single non-public schema")
genCmd.AddCommand(genTypesCmd)
keyFlags := genKeysCmd.Flags()
keyFlags.StringVar(&flags.ProjectRef, "project-ref", "", "Project ref of the Supabase project.")
Expand Down
29 changes: 20 additions & 9 deletions internal/gen/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@ const (
SwiftInternalAccessControl = "internal"
)

func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang string, schemas []string, postgrestV9Compat bool, swiftAccessControl string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang string, schemas []string, setDefault bool, postgrestV9Compat bool, swiftAccessControl string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
originalURL := utils.ToPostgresURL(dbConfig)
// Add default schemas if --schema flag is not specified
if len(schemas) == 0 {
schemas = utils.RemoveDuplicates(append([]string{"public"}, utils.Config.Api.Schemas...))
}
included := strings.Join(schemas, ",")

var defaultSchemaEnv string
if setDefault && len(schemas) == 1 && schemas[0] != "public" {
defaultSchemaEnv = schemas[0]
}
Copy link
Contributor

@sweatybridge sweatybridge May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add a new flag so it doesn't break existing users that rely on the generated default schema. For eg. supabase gen types --schema private --set-default

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sweatybridge Thanks. I've implemented a set-default flag for backwards compatibility that makes the behavior now explicit and opt-in. The new behavior is now:

supabase gen types --schema private --set-default


if projectId != "" {
if lang != LangTypescript {
return errors.Errorf("Unable to generate %s types for selected project. Try using --db-url flag instead.", lang)
Expand Down Expand Up @@ -84,18 +89,24 @@ func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang str
escaped += "&sslmode=require"
}

envVars := []string{
"PG_META_DB_URL=" + escaped,
"PG_META_GENERATE_TYPES=" + lang,
"PG_META_GENERATE_TYPES_INCLUDED_SCHEMAS=" + included,
"PG_META_GENERATE_TYPES_SWIFT_ACCESS_CONTROL=" + swiftAccessControl,
fmt.Sprintf("PG_META_GENERATE_TYPES_DETECT_ONE_TO_ONE_RELATIONSHIPS=%v", !postgrestV9Compat),
}

if defaultSchemaEnv != "" {
envVars = append(envVars, "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA="+defaultSchemaEnv)
}

return utils.DockerRunOnceWithConfig(
ctx,
container.Config{
Image: utils.Config.Studio.PgmetaImage,
Env: []string{
"PG_META_DB_URL=" + escaped,
"PG_META_GENERATE_TYPES=" + lang,
"PG_META_GENERATE_TYPES_INCLUDED_SCHEMAS=" + included,
"PG_META_GENERATE_TYPES_SWIFT_ACCESS_CONTROL=" + swiftAccessControl,
fmt.Sprintf("PG_META_GENERATE_TYPES_DETECT_ONE_TO_ONE_RELATIONSHIPS=%v", !postgrestV9Compat),
},
Cmd: []string{"node", "dist/server/server.js"},
Env: envVars,
Cmd: []string{"node", "dist/server/server.js"},
},
hostConfig,
network.NetworkingConfig{},
Expand Down
150 changes: 142 additions & 8 deletions internal/gen/types/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestGenLocalCommand(t *testing.T) {
conn := pgtest.NewConn()
defer conn.Close(t)
// Run test
assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", fsys, conn.Intercept))
assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, false, true, "", fsys, conn.Intercept))
// Validate api
assert.Empty(t, apitest.ListUnmatchedRequests())
})
Expand All @@ -63,7 +63,7 @@ func TestGenLocalCommand(t *testing.T) {
Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId).
Reply(http.StatusServiceUnavailable)
// Run test
assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", fsys))
assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, false, true, "", fsys))
// Validate api
assert.Empty(t, apitest.ListUnmatchedRequests())
})
Expand All @@ -83,7 +83,7 @@ func TestGenLocalCommand(t *testing.T) {
Get("/v" + utils.Docker.ClientVersion() + "/images").
Reply(http.StatusServiceUnavailable)
// Run test
assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", fsys))
assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, false, true, "", fsys))
// Validate api
assert.Empty(t, apitest.ListUnmatchedRequests())
})
Expand All @@ -106,7 +106,7 @@ func TestGenLocalCommand(t *testing.T) {
conn := pgtest.NewConn()
defer conn.Close(t)
// Run test
assert.NoError(t, Run(context.Background(), "", dbConfig, LangSwift, []string{}, true, SwiftInternalAccessControl, fsys, conn.Intercept))
assert.NoError(t, Run(context.Background(), "", dbConfig, LangSwift, []string{}, false, true, SwiftInternalAccessControl, fsys, conn.Intercept))
// Validate api
assert.Empty(t, apitest.ListUnmatchedRequests())
})
Expand All @@ -129,7 +129,7 @@ func TestGenLinkedCommand(t *testing.T) {
Reply(200).
JSON(api.TypescriptResponse{Types: ""})
// Run test
assert.NoError(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", fsys))
assert.NoError(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, false, true, "", fsys))
// Validate api
assert.Empty(t, apitest.ListUnmatchedRequests())
})
Expand All @@ -144,7 +144,7 @@ func TestGenLinkedCommand(t *testing.T) {
Get("/v1/projects/" + projectId + "/types/typescript").
ReplyError(errNetwork)
// Run test
err := Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", fsys)
err := Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, false, true, "", fsys)
// Validate api
assert.ErrorIs(t, err, errNetwork)
assert.Empty(t, apitest.ListUnmatchedRequests())
Expand All @@ -159,7 +159,7 @@ func TestGenLinkedCommand(t *testing.T) {
Get("/v1/projects/" + projectId + "/types/typescript").
Reply(http.StatusServiceUnavailable)
// Run test
assert.Error(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", fsys))
assert.Error(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, false, true, "", fsys))
})
}

Expand All @@ -184,8 +184,142 @@ func TestGenRemoteCommand(t *testing.T) {
conn := pgtest.NewConn()
defer conn.Close(t)
// Run test
assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{"public"}, true, "", afero.NewMemMapFs(), conn.Intercept))
assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{"public"}, false, true, "", afero.NewMemMapFs(), conn.Intercept))
// Validate api
assert.Empty(t, apitest.ListUnmatchedRequests())
})
}

func TestGenWithSetDefault(t *testing.T) {
utils.DbId = "test-db"
utils.Config.Hostname = "localhost"
utils.Config.Db.Port = 5432

dbConfig := pgconn.Config{
Host: utils.Config.Hostname,
Port: utils.Config.Db.Port,
User: "admin",
Password: "password",
}

t.Run("sets default schema env var with single non-public schema", func(t *testing.T) {
const containerId = "test-pgmeta"
imageUrl := utils.GetRegistryImageUrl(utils.Config.Studio.PgmetaImage)
fsys := afero.NewMemMapFs()

require.NoError(t, apitest.MockDocker(utils.Docker))
defer gock.OffAll()

gock.New(utils.Docker.DaemonHost()).
Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId).
Reply(http.StatusOK).
JSON(container.InspectResponse{})

var capturedEnv []string
apitest.MockDockerStartWithEnvCapture(utils.Docker, imageUrl, containerId, &capturedEnv)
require.NoError(t, apitest.MockDockerLogs(utils.Docker, containerId, "hello world\n"))

conn := pgtest.NewConn()
defer conn.Close(t)

err := Run(context.Background(), "", dbConfig, LangTypescript, []string{"private"}, true, true, "", fsys, conn.Intercept)
assert.NoError(t, err)

found := false
for _, env := range capturedEnv {
if env == "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA=private" {
found = true
break
}
}
assert.True(t, found, "Expected PG_META_GENERATE_TYPES_DEFAULT_SCHEMA=private to be set in environment variables")
assert.Empty(t, apitest.ListUnmatchedRequests())
})

t.Run("does not set default schema env var without flag", func(t *testing.T) {
const containerId = "test-pgmeta"
imageUrl := utils.GetRegistryImageUrl(utils.Config.Studio.PgmetaImage)
fsys := afero.NewMemMapFs()

require.NoError(t, apitest.MockDocker(utils.Docker))
defer gock.OffAll()

gock.New(utils.Docker.DaemonHost()).
Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId).
Reply(http.StatusOK).
JSON(container.InspectResponse{})

var capturedEnv []string
apitest.MockDockerStartWithEnvCapture(utils.Docker, imageUrl, containerId, &capturedEnv)
require.NoError(t, apitest.MockDockerLogs(utils.Docker, containerId, "hello world\n"))

conn := pgtest.NewConn()
defer conn.Close(t)

err := Run(context.Background(), "", dbConfig, LangTypescript, []string{"private"}, false, true, "", fsys, conn.Intercept)
assert.NoError(t, err)

for _, env := range capturedEnv {
assert.NotContains(t, env, "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA", "Should not set default schema env var when flag is false")
}
assert.Empty(t, apitest.ListUnmatchedRequests())
})

t.Run("does not set default schema env var with multiple schemas", func(t *testing.T) {
const containerId = "test-pgmeta"
imageUrl := utils.GetRegistryImageUrl(utils.Config.Studio.PgmetaImage)
fsys := afero.NewMemMapFs()

require.NoError(t, apitest.MockDocker(utils.Docker))
defer gock.OffAll()

gock.New(utils.Docker.DaemonHost()).
Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId).
Reply(http.StatusOK).
JSON(container.InspectResponse{})

var capturedEnv []string
apitest.MockDockerStartWithEnvCapture(utils.Docker, imageUrl, containerId, &capturedEnv)
require.NoError(t, apitest.MockDockerLogs(utils.Docker, containerId, "hello world\n"))

conn := pgtest.NewConn()
defer conn.Close(t)

err := Run(context.Background(), "", dbConfig, LangTypescript, []string{"public", "private"}, true, true, "", fsys, conn.Intercept)
assert.NoError(t, err)

for _, env := range capturedEnv {
assert.NotContains(t, env, "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA", "Should not set default schema env var with multiple schemas")
}
assert.Empty(t, apitest.ListUnmatchedRequests())
})

t.Run("does not set default schema env var with public schema", func(t *testing.T) {
const containerId = "test-pgmeta"
imageUrl := utils.GetRegistryImageUrl(utils.Config.Studio.PgmetaImage)
fsys := afero.NewMemMapFs()

require.NoError(t, apitest.MockDocker(utils.Docker))
defer gock.OffAll()

gock.New(utils.Docker.DaemonHost()).
Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId).
Reply(http.StatusOK).
JSON(container.InspectResponse{})

var capturedEnv []string
apitest.MockDockerStartWithEnvCapture(utils.Docker, imageUrl, containerId, &capturedEnv)
require.NoError(t, apitest.MockDockerLogs(utils.Docker, containerId, "hello world\n"))

conn := pgtest.NewConn()
defer conn.Close(t)

err := Run(context.Background(), "", dbConfig, LangTypescript, []string{"public"}, true, true, "", fsys, conn.Intercept)
assert.NoError(t, err)

for _, env := range capturedEnv {
assert.NotContains(t, env, "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA", "Should not set default schema env var with public schema")
}
assert.Empty(t, apitest.ListUnmatchedRequests())
})
}
36 changes: 36 additions & 0 deletions internal/testing/apitest/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package apitest

import (
"bytes"
"encoding/json"
"fmt"
"net/http"

Expand Down Expand Up @@ -113,6 +114,41 @@ func MockDockerLogsExitCode(docker *client.Client, containerID string, exitCode
return setupDockerLogs(docker, containerID, "", exitCode)
}

// MockDockerStartWithEnvCapture extends MockDockerStart to capture environment variables
// passed to container creation. This is useful for testing environment variable logic.
func MockDockerStartWithEnvCapture(docker *client.Client, imageID, containerID string, capturedEnv *[]string) {
gock.New(docker.DaemonHost()).
Get("/v" + docker.ClientVersion() + "/images/" + imageID + "/json").
Reply(http.StatusOK).
JSON(image.InspectResponse{})
gock.New(docker.DaemonHost()).
Post("/v" + docker.ClientVersion() + "/networks/create").
Reply(http.StatusCreated).
JSON(network.CreateResponse{})
gock.New(docker.DaemonHost()).
Post("/v" + docker.ClientVersion() + "/volumes/create").
Persist().
Reply(http.StatusCreated).
JSON(volume.Volume{})
gock.New(docker.DaemonHost()).
Post("/v" + docker.ClientVersion() + "/containers/create").
AddMatcher(func(req *http.Request, ereq *gock.Request) (bool, error) {
var config struct {
Env []string `json:"Env"`
}
if err := json.NewDecoder(req.Body).Decode(&config); err != nil {
return false, err
}
*capturedEnv = config.Env
return true, nil
}).
Reply(http.StatusOK).
JSON(container.CreateResponse{ID: containerID})
gock.New(docker.DaemonHost()).
Post("/v" + docker.ClientVersion() + "/containers/" + containerID + "/start").
Reply(http.StatusAccepted)
}

func ListUnmatchedRequests() []string {
result := make([]string, len(gock.GetUnmatchedRequests()))
for i, r := range gock.GetUnmatchedRequests() {
Expand Down