From 50f8562d79e260c7e7feeaa32beeeee2ea8c3939 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Wed, 27 Nov 2024 15:01:58 -0500 Subject: [PATCH] Ensure caveats are read in bulk import --- internal/services/v1/experimental.go | 8 ++ internal/services/v1/experimental_test.go | 112 +++++++++++--------- internal/services/v1/permissions.go | 4 +- internal/services/v1/permissions_test.go | 120 +++++++++++++--------- 4 files changed, 149 insertions(+), 95 deletions(-) diff --git a/internal/services/v1/experimental.go b/internal/services/v1/experimental.go index 8cd1058844..d2dddb4115 100644 --- a/internal/services/v1/experimental.go +++ b/internal/services/v1/experimental.go @@ -172,6 +172,14 @@ func (a *bulkLoadAdapter) Next(_ context.Context) (*tuple.Relationship, error) { return nil, nil } + if a.currentBatch[a.numSent].OptionalCaveat != nil { + a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName + a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context + a.current.OptionalCaveat = &a.caveat + } else { + a.current.OptionalCaveat = nil + } + if a.caveat.CaveatName != "" { a.current.OptionalCaveat = &a.caveat } else { diff --git a/internal/services/v1/experimental_test.go b/internal/services/v1/experimental_test.go index b098909f51..ecb123cf04 100644 --- a/internal/services/v1/experimental_test.go +++ b/internal/services/v1/experimental_test.go @@ -53,63 +53,79 @@ func TestBulkImportRelationships(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - t.Parallel() - require := require.New(t) + for _, withCaveats := range []bool{true, false} { + withCaveats := withCaveats + t.Run(fmt.Sprintf("withCaveats=%t", withCaveats), func(t *testing.T) { + require := require.New(t) - conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema) - client := v1.NewExperimentalServiceClient(conn) - t.Cleanup(cleanup) + conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema) + client := v1.NewExperimentalServiceClient(conn) + t.Cleanup(cleanup) - ctx := context.Background() + ctx := context.Background() - writer, err := client.BulkImportRelationships(ctx) - require.NoError(err) - - var expectedTotal uint64 - for batchNum := 0; batchNum < tc.numBatches; batchNum++ { - batchSize := tc.batchSize() - batch := make([]*v1.Relationship, 0, batchSize) - - for i := uint64(0); i < batchSize; i++ { - batch = append(batch, rel( - tf.DocumentNS.Name, - strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10), - "viewer", - tf.UserNS.Name, - strconv.FormatUint(i, 10), - "", - )) - } + writer, err := client.BulkImportRelationships(ctx) + require.NoError(err) - err := writer.Send(&v1.BulkImportRelationshipsRequest{ - Relationships: batch, - }) - require.NoError(err) + var expectedTotal uint64 + for batchNum := 0; batchNum < tc.numBatches; batchNum++ { + batchSize := tc.batchSize() + batch := make([]*v1.Relationship, 0, batchSize) + + for i := uint64(0); i < batchSize; i++ { + if withCaveats { + batch = append(batch, relWithCaveat( + tf.DocumentNS.Name, + strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10), + "caveated_viewer", + tf.UserNS.Name, + strconv.FormatUint(i, 10), + "", + "test", + )) + } else { + batch = append(batch, rel( + tf.DocumentNS.Name, + strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10), + "viewer", + tf.UserNS.Name, + strconv.FormatUint(i, 10), + "", + )) + } + } + + err := writer.Send(&v1.BulkImportRelationshipsRequest{ + Relationships: batch, + }) + require.NoError(err) - expectedTotal += batchSize - } + expectedTotal += batchSize + } - resp, err := writer.CloseAndRecv() - require.NoError(err) - require.Equal(expectedTotal, resp.NumLoaded) + resp, err := writer.CloseAndRecv() + require.NoError(err) + require.Equal(expectedTotal, resp.NumLoaded) - readerClient := v1.NewPermissionsServiceClient(conn) - stream, err := readerClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{ - RelationshipFilter: &v1.RelationshipFilter{ - ResourceType: tf.DocumentNS.Name, - }, - Consistency: &v1.Consistency{ - Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true}, - }, - }) - require.NoError(err) + readerClient := v1.NewPermissionsServiceClient(conn) + stream, err := readerClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{ + RelationshipFilter: &v1.RelationshipFilter{ + ResourceType: tf.DocumentNS.Name, + }, + Consistency: &v1.Consistency{ + Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true}, + }, + }) + require.NoError(err) - var readBack uint64 - for _, err = stream.Recv(); err == nil; _, err = stream.Recv() { - readBack++ + var readBack uint64 + for _, err = stream.Recv(); err == nil; _, err = stream.Recv() { + readBack++ + } + require.ErrorIs(err, io.EOF) + require.Equal(expectedTotal, readBack) + }) } - require.ErrorIs(err, io.EOF) - require.Equal(expectedTotal, readBack) }) } } diff --git a/internal/services/v1/permissions.go b/internal/services/v1/permissions.go index 8dd1b2cff5..c5a471770f 100644 --- a/internal/services/v1/permissions.go +++ b/internal/services/v1/permissions.go @@ -768,7 +768,9 @@ func (a *loadBulkAdapter) Next(_ context.Context) (*tuple.Relationship, error) { return nil, nil } - if a.caveat.CaveatName != "" { + if a.currentBatch[a.numSent].OptionalCaveat != nil { + a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName + a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context a.current.OptionalCaveat = &a.caveat } else { a.current.OptionalCaveat = nil diff --git a/internal/services/v1/permissions_test.go b/internal/services/v1/permissions_test.go index 5c8e991a8d..12b150dc9e 100644 --- a/internal/services/v1/permissions_test.go +++ b/internal/services/v1/permissions_test.go @@ -2066,62 +2066,90 @@ func TestImportBulkRelationships(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - require := require.New(t) + for _, withCaveats := range []bool{true, false} { + withCaveats := withCaveats + t.Run(fmt.Sprintf("withCaveats=%t", withCaveats), func(t *testing.T) { + require := require.New(t) - conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema) - client := v1.NewPermissionsServiceClient(conn) - t.Cleanup(cleanup) + conn, cleanup, _, _ := testserver.NewTestServer(require, 0, memdb.DisableGC, true, tf.StandardDatastoreWithSchema) + client := v1.NewPermissionsServiceClient(conn) + t.Cleanup(cleanup) - ctx := context.Background() + ctx := context.Background() - writer, err := client.ImportBulkRelationships(ctx) - require.NoError(err) + writer, err := client.ImportBulkRelationships(ctx) + require.NoError(err) - var expectedTotal uint64 - for batchNum := 0; batchNum < tc.numBatches; batchNum++ { - batchSize := tc.batchSize() - batch := make([]*v1.Relationship, 0, batchSize) - - for i := uint64(0); i < batchSize; i++ { - batch = append(batch, rel( - tf.DocumentNS.Name, - strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10), - "viewer", - tf.UserNS.Name, - strconv.FormatUint(i, 10), - "", - )) - } + var expectedTotal uint64 + for batchNum := 0; batchNum < tc.numBatches; batchNum++ { + batchSize := tc.batchSize() + batch := make([]*v1.Relationship, 0, batchSize) + + for i := uint64(0); i < batchSize; i++ { + if withCaveats { + batch = append(batch, relWithCaveat( + tf.DocumentNS.Name, + strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10), + "caveated_viewer", + tf.UserNS.Name, + strconv.FormatUint(i, 10), + "", + "test", + )) + } else { + batch = append(batch, rel( + tf.DocumentNS.Name, + strconv.Itoa(batchNum)+"_"+strconv.FormatUint(i, 10), + "viewer", + tf.UserNS.Name, + strconv.FormatUint(i, 10), + "", + )) + } + } - err := writer.Send(&v1.ImportBulkRelationshipsRequest{ - Relationships: batch, - }) - require.NoError(err) + err := writer.Send(&v1.ImportBulkRelationshipsRequest{ + Relationships: batch, + }) + require.NoError(err) - expectedTotal += batchSize - } + expectedTotal += batchSize + } - resp, err := writer.CloseAndRecv() - require.NoError(err) - require.Equal(expectedTotal, resp.NumLoaded) + resp, err := writer.CloseAndRecv() + require.NoError(err) + require.Equal(expectedTotal, resp.NumLoaded) - readerClient := v1.NewPermissionsServiceClient(conn) - stream, err := readerClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{ - RelationshipFilter: &v1.RelationshipFilter{ - ResourceType: tf.DocumentNS.Name, - }, - Consistency: &v1.Consistency{ - Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true}, - }, - }) - require.NoError(err) + readerClient := v1.NewPermissionsServiceClient(conn) + stream, err := readerClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{ + RelationshipFilter: &v1.RelationshipFilter{ + ResourceType: tf.DocumentNS.Name, + }, + Consistency: &v1.Consistency{ + Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true}, + }, + }) + require.NoError(err) - var readBack uint64 - for _, err = stream.Recv(); err == nil; _, err = stream.Recv() { - readBack++ + var readBack uint64 + var res *v1.ReadRelationshipsResponse + for _, err = stream.Recv(); err == nil; res, err = stream.Recv() { + readBack++ + if res == nil { + continue + } + + if withCaveats { + require.NotNil(res.Relationship.OptionalCaveat) + require.Equal("test", res.Relationship.OptionalCaveat.CaveatName) + } else { + require.Nil(res.Relationship.OptionalCaveat) + } + } + require.ErrorIs(err, io.EOF) + require.Equal(expectedTotal, readBack) + }) } - require.ErrorIs(err, io.EOF) - require.Equal(expectedTotal, readBack) }) } }