diff --git a/internal/integration/collection_test.go b/internal/integration/collection_test.go index deb9b0f7ba..f02876fbf6 100644 --- a/internal/integration/collection_test.go +++ b/internal/integration/collection_test.go @@ -8,6 +8,7 @@ package integration import ( "context" + "errors" "strings" "testing" @@ -522,7 +523,7 @@ func TestCollection(t *testing.T) { }) mt.Run("nil id", func(mt *mtest.T) { _, err := mt.Coll.UpdateByID(context.Background(), nil, bson.D{{"$inc", bson.D{{"x", 1}}}}) - assert.Equal(mt, err, mongo.ErrNilValue, "expected %v, got %v", mongo.ErrNilValue, err) + assert.True(mt, errors.Is(err, mongo.ErrNilValue), "expected %v, got %v", mongo.ErrNilValue, err) }) mt.RunOpts("found", noClientOpts, func(mt *mtest.T) { testCases := []struct { diff --git a/mongo/client.go b/mongo/client.go index 09535f2ba6..61e666b1fe 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -888,7 +888,7 @@ func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite, } if len(writes) == 0 { - return nil, ErrEmptySlice + return nil, fmt.Errorf("invalid writes: %w", ErrEmptySlice) } bwo, err := mongoutil.NewOptions(opts...) if err != nil { diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index ab2d8b6acc..ca6ecf5240 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -50,11 +50,11 @@ type clientBulkWrite struct { func (bw *clientBulkWrite) execute(ctx context.Context) error { if len(bw.writePairs) == 0 { - return ErrEmptySlice + return fmt.Errorf("invalid writes: %w", ErrEmptySlice) } - for _, m := range bw.writePairs { + for i, m := range bw.writePairs { if m.model == nil { - return ErrNilDocument + return fmt.Errorf("error from model at index %d: %w", i, ErrNilDocument) } } batches := &modelBatches{ diff --git a/mongo/client_test.go b/mongo/client_test.go index e4c79c26a2..204aa52a1e 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -81,10 +81,10 @@ func TestClient(t *testing.T) { assert.Equal(t, watchErr, err, "expected error %v, got %v", watchErr, err) _, err = client.ListDatabases(bgCtx, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = client.ListDatabaseNames(bgCtx, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) }) t.Run("read preference", func(t *testing.T) { t.Run("absent", func(t *testing.T) { diff --git a/mongo/collection.go b/mongo/collection.go index a9279f1381..b07b8b173c 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -193,7 +193,7 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (*BulkWriteResult, error) { if len(models) == 0 { - return nil, ErrEmptySlice + return nil, fmt.Errorf("invalid models: %w", ErrEmptySlice) } if ctx == nil { @@ -221,9 +221,9 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, selector := makePinnedSelector(sess, coll.writeSelector) - for _, model := range models { + for i, model := range models { if model == nil { - return nil, ErrNilDocument + return nil, fmt.Errorf("invalid model at index %d: %w", i, ErrNilDocument) } } @@ -407,10 +407,10 @@ func (coll *Collection) InsertMany( dv := reflect.ValueOf(documents) if dv.Kind() != reflect.Slice { - return nil, ErrNotSlice + return nil, fmt.Errorf("invalid documents: %w", ErrNotSlice) } if dv.Len() == 0 { - return nil, ErrEmptySlice + return nil, fmt.Errorf("invalid documents: %w", ErrEmptySlice) } docSlice := make([]interface{}, 0, dv.Len()) @@ -729,7 +729,7 @@ func (coll *Collection) UpdateByID( opts ...options.Lister[options.UpdateOneOptions], ) (*UpdateResult, error) { if id == nil { - return nil, ErrNilValue + return nil, fmt.Errorf("invalid id: %w", ErrNilValue) } return coll.UpdateOne(ctx, bson.D{{"_id", id}}, update, opts...) } diff --git a/mongo/collection_test.go b/mongo/collection_test.go index 268f1c8c0e..23a1209e11 100644 --- a/mongo/collection_test.go +++ b/mongo/collection_test.go @@ -150,76 +150,76 @@ func TestCollection(t *testing.T) { doc := bson.D{} _, err := coll.InsertOne(bgCtx, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.InsertMany(bgCtx, nil) - assert.Equal(t, ErrNotSlice, err, "expected error %v, got %v", ErrNotSlice, err) + assert.True(t, errors.Is(err, ErrNotSlice), "expected error %v, got %v", ErrNotSlice, err) _, err = coll.InsertMany(bgCtx, []interface{}{}) - assert.Equal(t, ErrEmptySlice, err, "expected error %v, got %v", ErrEmptySlice, err) + assert.True(t, errors.Is(err, ErrEmptySlice), "expected error %v, got %v", ErrEmptySlice, err) _, err = coll.InsertMany(bgCtx, "x") - assert.Equal(t, ErrNotSlice, err, "expected error %v, got %v", ErrNotSlice, err) + assert.True(t, errors.Is(err, ErrNotSlice), "expected error %v, got %v", ErrNotSlice, err) _, err = coll.DeleteOne(bgCtx, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.DeleteMany(bgCtx, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.UpdateOne(bgCtx, nil, doc) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.UpdateOne(bgCtx, doc, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.UpdateMany(bgCtx, nil, doc) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.UpdateMany(bgCtx, doc, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.ReplaceOne(bgCtx, nil, doc) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.ReplaceOne(bgCtx, doc, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.CountDocuments(bgCtx, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) err = coll.Distinct(bgCtx, "x", nil).Err() - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.Find(bgCtx, nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) err = coll.FindOne(bgCtx, nil).Err() - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) err = coll.FindOneAndDelete(bgCtx, nil).Err() - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) err = coll.FindOneAndReplace(bgCtx, nil, doc).Err() - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) err = coll.FindOneAndReplace(bgCtx, doc, nil).Err() - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) err = coll.FindOneAndUpdate(bgCtx, nil, doc).Err() - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) err = coll.FindOneAndUpdate(bgCtx, doc, nil).Err() - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = coll.BulkWrite(bgCtx, nil) - assert.Equal(t, ErrEmptySlice, err, "expected error %v, got %v", ErrEmptySlice, err) + assert.True(t, errors.Is(err, ErrEmptySlice), "expected error %v, got %v", ErrEmptySlice, err) _, err = coll.BulkWrite(bgCtx, []WriteModel{}) - assert.Equal(t, ErrEmptySlice, err, "expected error %v, got %v", ErrEmptySlice, err) + assert.True(t, errors.Is(err, ErrEmptySlice), "expected error %v, got %v", ErrEmptySlice, err) _, err = coll.BulkWrite(bgCtx, []WriteModel{nil}) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) aggErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") _, err = coll.Aggregate(bgCtx, nil) diff --git a/mongo/cursor.go b/mongo/cursor.go index 5134c4fed4..ee0e848c64 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -97,7 +97,7 @@ func NewCursorFromDocuments(documents []interface{}, preloadedErr error, registr for i, doc := range documents { switch t := doc.(type) { case nil: - return nil, ErrNilDocument + return nil, fmt.Errorf("invalid document at index %d: %w", i, ErrNilDocument) case []byte: // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. doc = bson.Raw(t) diff --git a/mongo/database_test.go b/mongo/database_test.go index 0c324b9b09..bce78d197e 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -141,16 +141,16 @@ func TestDatabase(t *testing.T) { db := setupDb("foo") err := db.RunCommand(bgCtx, nil).Err() - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = db.Watch(context.Background(), nil) watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") assert.Equal(t, watchErr, err, "expected error %v, got %v", watchErr, err) _, err = db.ListCollections(context.Background(), nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) _, err = db.ListCollectionNames(context.Background(), nil) - assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) + assert.True(t, errors.Is(err, ErrNilDocument), "expected error %v, got %v", ErrNilDocument, err) }) } diff --git a/mongo/errors.go b/mongo/errors.go index 07d713fd43..42f7281fb3 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -24,17 +24,35 @@ import ( // ErrClientDisconnected is returned when disconnected Client is used to run an operation. var ErrClientDisconnected = errors.New("client is disconnected") +// InvalidArgumentError wraps an invalid argument error. +type InvalidArgumentError struct { + wrapped error +} + +// Error implements the error interface. +func (e InvalidArgumentError) Error() string { + return e.wrapped.Error() +} + +// Unwrap returns the underlying error. +func (e InvalidArgumentError) Unwrap() error { + return e.wrapped +} + +// ErrMultipleIndexDrop is returned if multiple indexes would be dropped from a call to IndexView.DropOne. +var ErrMultipleIndexDrop error = InvalidArgumentError{errors.New("multiple indexes would be dropped")} + // ErrNilDocument is returned when a nil document is passed to a CRUD method. -var ErrNilDocument = errors.New("document is nil") +var ErrNilDocument error = InvalidArgumentError{errors.New("document is nil")} // ErrNilValue is returned when a nil value is passed to a CRUD method. -var ErrNilValue = errors.New("value is nil") +var ErrNilValue error = InvalidArgumentError{errors.New("value is nil")} // ErrEmptySlice is returned when an empty slice is passed to a CRUD method that requires a non-empty slice. -var ErrEmptySlice = errors.New("must provide at least one element in input slice") +var ErrEmptySlice error = InvalidArgumentError{errors.New("must provide at least one element in input slice")} // ErrNotSlice is returned when a type other than slice is passed to InsertMany. -var ErrNotSlice = errors.New("must provide a non-empty slice") +var ErrNotSlice error = InvalidArgumentError{errors.New("must provide a non-empty slice")} // ErrMapForOrderedArgument is returned when a map with multiple keys is passed to a CRUD method for an ordered parameter type ErrMapForOrderedArgument struct { diff --git a/mongo/index_view.go b/mongo/index_view.go index 748957da1b..de322cf60b 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -29,11 +29,10 @@ import ( var ErrInvalidIndexValue = errors.New("invalid index value") // ErrNonStringIndexName is returned if an index is created with a name that is not a string. +// +// Deprecated: it will be removed in the next major release var ErrNonStringIndexName = errors.New("index name must be a string") -// ErrMultipleIndexDrop is returned if multiple indexes would be dropped from a call to IndexView.DropOne. -var ErrMultipleIndexDrop = errors.New("multiple indexes would be dropped") - // IndexView is a type that can be used to create, drop, and list indexes on a collection. An IndexView for a collection // can be created by a call to Collection.Indexes(). type IndexView struct {