Skip to content

Commit 57fdb03

Browse files
authored
fix: schema write from stdin (#557)
1 parent 445783d commit 57fdb03

File tree

4 files changed

+216
-19
lines changed

4 files changed

+216
-19
lines changed

internal/cmd/schema.go

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ import (
3030
"github.com/authzed/zed/internal/console"
3131
)
3232

33+
type termChecker interface {
34+
IsTerminal(fd int) bool
35+
}
36+
37+
type realTermChecker struct{}
38+
39+
func (rtc *realTermChecker) IsTerminal(fd int) bool {
40+
return term.IsTerminal(fd)
41+
}
42+
3343
func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {
3444
schemaCmd.AddCommand(schemaCopyCmd)
3545
schemaCopyCmd.Flags().Bool("json", false, "output as JSON")
@@ -50,7 +60,19 @@ var schemaWriteCmd = &cobra.Command{
5060
Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
5161
Short: "Write a schema file (.zed or stdin) to the current permissions system",
5262
ValidArgsFunction: commands.FileExtensionCompletions("zed"),
53-
RunE: schemaWriteCmdFunc,
63+
Example: `
64+
Write from a file:
65+
zed schema write schema.zed
66+
Write from stdin:
67+
cat schema.zed | zed schema write
68+
`,
69+
RunE: func(cmd *cobra.Command, args []string) error {
70+
client, err := client.NewClient(cmd)
71+
if err != nil {
72+
return err
73+
}
74+
return schemaWriteCmdImpl(cmd, args, client, &realTermChecker{})
75+
},
5476
}
5577

5678
var schemaCopyCmd = &cobra.Command{
@@ -79,7 +101,9 @@ var schemaCompileCmd = &cobra.Command{
79101
zed preview schema compile root.zed --out compiled.zed
80102
`,
81103
ValidArgsFunction: commands.FileExtensionCompletions("zed"),
82-
RunE: schemaCompileCmdFunc,
104+
RunE: func(cmd *cobra.Command, args []string) error {
105+
return schemaCompileCmdFunc(cmd, args, &realTermChecker{})
106+
},
83107
}
84108

85109
func schemaDiffCmdFunc(_ *cobra.Command, args []string) error {
@@ -196,19 +220,16 @@ func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error {
196220
return nil
197221
}
198222

199-
func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error {
200-
intFd, err := safecast.ToInt(uint(os.Stdout.Fd()))
223+
func schemaWriteCmdImpl(cmd *cobra.Command, args []string, client v1.SchemaServiceClient, terminalChecker termChecker) error {
224+
stdInFd, err := safecast.ToInt(uint(os.Stdin.Fd()))
201225
if err != nil {
202226
return err
203227
}
204-
if len(args) == 0 && term.IsTerminal(intFd) {
205-
return fmt.Errorf("must provide file path or contents via stdin")
206-
}
207228

208-
client, err := client.NewClient(cmd)
209-
if err != nil {
210-
return err
229+
if len(args) == 0 && terminalChecker.IsTerminal(stdInFd) {
230+
return errors.New("must provide file path or contents via stdin")
211231
}
232+
212233
var schemaBytes []byte
213234
switch len(args) {
214235
case 1:
@@ -246,18 +267,17 @@ func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error {
246267

247268
resp, err := client.WriteSchema(cmd.Context(), request)
248269
if err != nil {
249-
log.Fatal().Err(err).Msg("failed to write schema")
270+
return fmt.Errorf("failed to write schema: %w", err)
250271
}
251272
log.Trace().Interface("response", resp).Msg("wrote schema")
252273

253274
if cobrautil.MustGetBool(cmd, "json") {
254275
prettyProto, err := commands.PrettyProto(resp)
255276
if err != nil {
256-
log.Fatal().Err(err).Msg("failed to convert schema to JSON")
277+
return fmt.Errorf("failed to convert schema to JSON: %w", err)
257278
}
258279

259280
console.Println(string(prettyProto))
260-
return nil
261281
}
262282

263283
return nil
@@ -287,7 +307,7 @@ func rewriteSchema(existingSchemaText string, definitionPrefix string) (string,
287307
// If specifiedPrefix is non-empty, it is returned immediately.
288308
// If existingSchema is non-nil, it is parsed for the prefix.
289309
// Otherwise, the client is used to retrieve the existing schema (if any), and the prefix is retrieved from there.
290-
func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client client.Client, existingSchema *string) (string, error) {
310+
func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client v1.SchemaServiceClient, existingSchema *string) (string, error) {
291311
if specifiedPrefix != "" {
292312
return specifiedPrefix, nil
293313
}
@@ -340,14 +360,14 @@ func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, clien
340360

341361
// Compiles an input schema written in the new composable schema syntax
342362
// and produces it as a fully-realized schema
343-
func schemaCompileCmdFunc(cmd *cobra.Command, args []string) error {
363+
func schemaCompileCmdFunc(cmd *cobra.Command, args []string, termChecker termChecker) error {
344364
stdOutFd, err := safecast.ToInt(uint(os.Stdout.Fd()))
345365
if err != nil {
346366
return err
347367
}
348368
outputFilepath := cobrautil.MustGetString(cmd, "out")
349-
if outputFilepath == "" && !term.IsTerminal(stdOutFd) {
350-
return fmt.Errorf("must provide stdout or output file path")
369+
if outputFilepath == "" && !termChecker.IsTerminal(stdOutFd) {
370+
return errors.New("must provide stdout or output file path")
351371
}
352372

353373
inputFilepath := args[0]

internal/cmd/schema_test.go

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
package cmd
22

33
import (
4+
"context"
5+
"errors"
46
"io/fs"
57
"os"
68
"path/filepath"
79
"testing"
810

911
"github.com/stretchr/testify/require"
12+
"google.golang.org/grpc"
1013

14+
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
1115
"github.com/authzed/spicedb/pkg/composableschemadsl/compiler"
1216

1317
zedtesting "github.com/authzed/zed/internal/testing"
@@ -166,16 +170,184 @@ definition resource {
166170
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
167171
zedtesting.StringFlag{FlagName: "out", FlagValue: tempOutFile})
168172

169-
err := schemaCompileCmdFunc(cmd, tc.files)
173+
mockTermCheckerr := &mockTermChecker{returnVal: false}
174+
err := schemaCompileCmdFunc(cmd, tc.files, mockTermCheckerr)
170175
if tc.expectErr == nil {
171176
require.NoError(err)
172177
tempOutString, err := os.ReadFile(tempOutFile)
173178
require.NoError(err)
174179
require.Equal(tc.expectStr, string(tempOutString))
180+
// TODO re-enable after adding a test that uses stdout
181+
// require.Equal(int(os.Stdout.Fd()), mockTermCheckerr.capturedFd, "expected stdout to be checked for terminal")
175182
} else {
176183
require.Error(err)
177184
require.ErrorAs(err, &tc.expectErr)
178185
}
179186
})
180187
}
181188
}
189+
190+
func TestSchemaWrite(t *testing.T) {
191+
t.Parallel()
192+
193+
// Save original stdin
194+
oldStdin := os.Stdin
195+
t.Cleanup(func() {
196+
os.Stdin = oldStdin
197+
})
198+
199+
testCases := map[string]struct {
200+
schemaMakerFn func() ([]string, error)
201+
terminalChecker *mockTermChecker
202+
expectErr string
203+
expectSchemaWritten string
204+
}{
205+
`schema_from_file`: {
206+
schemaMakerFn: func() ([]string, error) {
207+
return []string{
208+
filepath.Join("write-schema-test", "basic.zed"),
209+
}, nil
210+
},
211+
expectSchemaWritten: `definition user {}
212+
definition resource {
213+
relation view: user
214+
permission viewer = view
215+
}`,
216+
terminalChecker: &mockTermChecker{returnVal: false},
217+
},
218+
`schema_from_stdin`: {
219+
schemaMakerFn: func() ([]string, error) {
220+
schemaContent := "definition user{}\ndefinition document { relation read: user }"
221+
pipeRead, pipeWrite, err := os.Pipe()
222+
require.NoError(t, err)
223+
os.Stdin = pipeRead
224+
_, err = pipeWrite.WriteString(schemaContent)
225+
require.NoError(t, err)
226+
err = pipeWrite.Close()
227+
require.NoError(t, err)
228+
return []string{}, nil
229+
},
230+
terminalChecker: &mockTermChecker{returnVal: false},
231+
expectSchemaWritten: "definition user{}\ndefinition document { relation read: user }",
232+
},
233+
`schema_from_stdin_but_terminal`: {
234+
schemaMakerFn: func() ([]string, error) {
235+
schemaContent := "definition user{}\ndefinition document { relation read: user }"
236+
pipeRead, pipeWrite, err := os.Pipe()
237+
require.NoError(t, err)
238+
os.Stdin = pipeRead
239+
_, err = pipeWrite.WriteString(schemaContent)
240+
require.NoError(t, err)
241+
err = pipeWrite.Close()
242+
require.NoError(t, err)
243+
return []string{}, nil
244+
},
245+
terminalChecker: &mockTermChecker{returnVal: true},
246+
expectErr: "must provide file path or contents via stdin",
247+
},
248+
`empty_schema_errors`: {
249+
schemaMakerFn: func() ([]string, error) {
250+
pipeRead, pipeWrite, err := os.Pipe()
251+
require.NoError(t, err)
252+
os.Stdin = pipeRead
253+
_, err = pipeWrite.WriteString("")
254+
require.NoError(t, err)
255+
err = pipeWrite.Close()
256+
require.NoError(t, err)
257+
return []string{}, nil
258+
},
259+
terminalChecker: &mockTermChecker{returnVal: false},
260+
expectErr: "attempted to write empty schema",
261+
},
262+
`write_failure_errors`: {
263+
schemaMakerFn: func() ([]string, error) {
264+
return []string{
265+
filepath.Join("write-schema-test", "basic.zed"),
266+
}, errors.New("write error")
267+
},
268+
terminalChecker: &mockTermChecker{returnVal: false},
269+
expectErr: "error writing schema",
270+
},
271+
}
272+
273+
for name, tc := range testCases {
274+
t.Run(name, func(t *testing.T) {
275+
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
276+
zedtesting.StringFlag{FlagName: "schema-definition-prefix", FlagValue: ""},
277+
zedtesting.BoolFlag{FlagName: "json", FlagValue: true},
278+
)
279+
280+
args, writeErr := tc.schemaMakerFn()
281+
mockWriteSchemaClientt := &mockWriteSchemaClient{}
282+
if writeErr != nil {
283+
mockWriteSchemaClientt.writeReturnsError = true
284+
}
285+
286+
err := schemaWriteCmdImpl(cmd, args, mockWriteSchemaClientt, tc.terminalChecker)
287+
288+
if tc.expectErr != "" {
289+
require.Error(t, err)
290+
require.Contains(t, err.Error(), tc.expectErr)
291+
return
292+
}
293+
294+
require.NoError(t, err)
295+
require.Equal(t, tc.expectSchemaWritten, mockWriteSchemaClientt.receivedSchema)
296+
if tc.terminalChecker.captured {
297+
require.Equal(t, int(os.Stdin.Fd()), tc.terminalChecker.capturedFd, "expected stdin to be checked for terminal")
298+
}
299+
})
300+
}
301+
}
302+
303+
type mockWriteSchemaClient struct {
304+
existingSchema string
305+
receivedSchema string
306+
writeReturnsError bool
307+
}
308+
309+
var _ v1.SchemaServiceClient = (*mockWriteSchemaClient)(nil)
310+
311+
func (m *mockWriteSchemaClient) WriteSchema(_ context.Context, in *v1.WriteSchemaRequest, _ ...grpc.CallOption) (*v1.WriteSchemaResponse, error) {
312+
if m.writeReturnsError {
313+
return nil, errors.New("error writing schema")
314+
}
315+
m.receivedSchema = in.Schema
316+
return &v1.WriteSchemaResponse{}, nil
317+
}
318+
319+
func (m *mockWriteSchemaClient) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest, _ ...grpc.CallOption) (*v1.ReadSchemaResponse, error) {
320+
return &v1.ReadSchemaResponse{
321+
SchemaText: m.existingSchema,
322+
}, nil
323+
}
324+
325+
func (m *mockWriteSchemaClient) ReflectSchema(_ context.Context, _ *v1.ReflectSchemaRequest, _ ...grpc.CallOption) (*v1.ReflectSchemaResponse, error) {
326+
panic("not implemented")
327+
}
328+
329+
func (m *mockWriteSchemaClient) ComputablePermissions(_ context.Context, _ *v1.ComputablePermissionsRequest, _ ...grpc.CallOption) (*v1.ComputablePermissionsResponse, error) {
330+
panic("not implemented")
331+
}
332+
333+
func (m *mockWriteSchemaClient) DependentRelations(_ context.Context, _ *v1.DependentRelationsRequest, _ ...grpc.CallOption) (*v1.DependentRelationsResponse, error) {
334+
panic("not implemented")
335+
}
336+
337+
func (m *mockWriteSchemaClient) DiffSchema(_ context.Context, _ *v1.DiffSchemaRequest, _ ...grpc.CallOption) (*v1.DiffSchemaResponse, error) {
338+
panic("not implemented")
339+
}
340+
341+
type mockTermChecker struct {
342+
returnVal bool
343+
captured bool
344+
capturedFd int
345+
}
346+
347+
var _ termChecker = (*mockTermChecker)(nil)
348+
349+
func (m *mockTermChecker) IsTerminal(fd int) bool {
350+
m.captured = true
351+
m.capturedFd = fd
352+
return m.returnVal
353+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
definition user {}
2+
definition resource {
3+
relation view: user
4+
permission viewer = view
5+
}

internal/commands/schema.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func schemaReadCmdFunc(cmd *cobra.Command, _ []string) error {
6868
}
6969

7070
// ReadSchema calls read schema for the client and returns the schema found.
71-
func ReadSchema(ctx context.Context, client client.Client) (string, error) {
71+
func ReadSchema(ctx context.Context, client v1.SchemaServiceClient) (string, error) {
7272
request := &v1.ReadSchemaRequest{}
7373
log.Trace().Interface("request", request).Msg("requesting schema read")
7474

0 commit comments

Comments
 (0)