diff --git a/bson/bson.go b/bson/bson.go index 7fb7f8cae..733e94a65 100644 --- a/bson/bson.go +++ b/bson/bson.go @@ -34,6 +34,7 @@ package bson import ( "bytes" + "context" "crypto/md5" "crypto/rand" "encoding/binary" @@ -64,6 +65,10 @@ type Getter interface { GetBSON() (interface{}, error) } +type GetterCtx interface { + GetBSONWithContext(context.Context) (interface{}, error) +} + // A value implementing the bson.Setter interface will receive the BSON // value via the SetBSON method during unmarshaling, and the object // itself will not be changed as usual. @@ -95,6 +100,10 @@ type Setter interface { SetBSON(raw Raw) error } +type SetterCtx interface { + SetBSONWithContext(ctx context.Context, raw Raw) error +} + // SetZero may be returned from a SetBSON method to have the value set to // its respective zero value. When used in pointer values, this will set the // field to nil rather than to the pre-allocated value. @@ -279,7 +288,7 @@ var nullBytes = []byte("null") func (id *ObjectId) UnmarshalJSON(data []byte) error { if len(data) > 0 && (data[0] == '{' || data[0] == 'O') { var v struct { - Id json.RawMessage `json:"$oid"` + Id json.RawMessage `json:"$oid"` Func struct { Id json.RawMessage } `json:"$oidFunc"` @@ -505,13 +514,17 @@ func handleErr(err *error) { // F int64 "myf,omitempty,minsize" // } // -func Marshal(in interface{}) (out []byte, err error) { +func MarshalWithContext(ctx context.Context, in interface{}) (out []byte, err error) { defer handleErr(&err) e := &encoder{make([]byte, 0, initialBufferSize)} - e.addDoc(reflect.ValueOf(in)) + e.addDoc(ctx, reflect.ValueOf(in)) return e.out, nil } +func Marshal(in interface{}) (out []byte, err error) { + return MarshalWithContext(context.TODO(), in) +} + // Unmarshal deserializes data from in into the out value. The out value // must be a map, a pointer to a struct, or a pointer to a bson.D value. // In the case of struct values, only exported fields will be deserialized. 
@@ -547,7 +560,7 @@ func Marshal(in interface{}) (out []byte, err error) { // silently skipped. // // Pointer values are initialized when necessary. -func Unmarshal(in []byte, out interface{}) (err error) { +func UnmarshalWithContext(ctx context.Context, in []byte, out interface{}) (err error) { if raw, ok := out.(*Raw); ok { raw.Kind = 3 raw.Data = in @@ -560,7 +573,7 @@ func Unmarshal(in []byte, out interface{}) (err error) { fallthrough case reflect.Map: d := newDecoder(in) - d.readDocTo(v) + d.readDocTo(ctx, v) case reflect.Struct: return errors.New("Unmarshal can't deal with struct values. Use a pointer.") default: @@ -569,12 +582,16 @@ func Unmarshal(in []byte, out interface{}) (err error) { return nil } +func Unmarshal(in []byte, out interface{}) (err error) { + return UnmarshalWithContext(context.TODO(), in, out) +} + // Unmarshal deserializes raw into the out value. If the out value type // is not compatible with raw, a *bson.TypeError is returned. // // See the Unmarshal function documentation for more details on the // unmarshalling process. 
-func (raw Raw) Unmarshal(out interface{}) (err error) { +func (raw Raw) UnmarshalWithContext(ctx context.Context, out interface{}) (err error) { defer handleErr(&err) v := reflect.ValueOf(out) switch v.Kind() { @@ -583,7 +600,7 @@ func (raw Raw) Unmarshal(out interface{}) (err error) { fallthrough case reflect.Map: d := newDecoder(raw.Data) - good := d.readElemTo(v, raw.Kind) + good := d.readElemTo(ctx, v, raw.Kind) if !good { return &TypeError{v.Type(), raw.Kind} } @@ -595,6 +612,10 @@ func (raw Raw) Unmarshal(out interface{}) (err error) { return nil } +func (raw Raw) Unmarshal(out interface{}) (err error) { + return raw.UnmarshalWithContext(context.TODO(), out) +} + type TypeError struct { Type reflect.Type Kind byte diff --git a/bson/context.go b/bson/context.go new file mode 100644 index 000000000..c81118e8e --- /dev/null +++ b/bson/context.go @@ -0,0 +1,89 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. +// + +package bson + +import ( + "context" + "fmt" + "reflect" + "strings" +) + +type bsonOptions struct { + // skipCustom is used by + // - decode.go to skip looking for a custom SetBSON() or SetBSONWithContext() + // - encode.go to skip looking for a custom GetBSON() or GetBSONWithContext() + // for any type with the base type name specified by 'skipCustom'. + // This is useful to avoid infinite loop caused by: + // - calling Unmarshal from custom SetBSON function (decode.go) + // - encode.go calling custom GetBSON after just calling custom GetBSON for a given type + skipCustom string +} + +type key int + +var bsonKey key = 0 + +// Returns the topmost bsonOptions value stored in ctx, if any. +func fromContext(ctx context.Context) (*bsonOptions, bool) { + if ctx == nil { + return nil, false + } + opts, ok := ctx.Value(bsonKey).(*bsonOptions) + return opts, ok +} + +// Returns the base type name (type name without a prefix that contains any combination of * or []). +func baseTypeName(typ reflect.Type) string { + return strings.Trim(fmt.Sprintf("%v", typ), "*[]") +} + +// Creates a new context with a value for skipCustom based on base type name of valu. 
+func NewContextWithSkipCustom(ctx context.Context, valu interface{}) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, bsonKey, &bsonOptions{skipCustom: baseTypeName(reflect.TypeOf(valu))}) +} + +// IsSkipCustom is useful to avoid infinite loop caused by: +// - calling Unmarshal from custom SetBSON function (decode.go) +// - encode.go calling custom GetBSON after just calling custom GetBSON for a given type +// +// Returns true if base type name of typ is the same as skipCustom. +// +// This method is used to skip all custom SetBSON/GetBSON functions of all types with the same base type. +// Note: if goal is to skip the custom functions of certain variations of a base type, +// skipCustom will not work (it will skip all variants). +func IsSkipCustom(ctx context.Context, typ reflect.Type) bool { + if opts, _ := fromContext(ctx); opts != nil { + return opts.skipCustom == baseTypeName(typ) + } + return false +} diff --git a/bson/decode.go b/bson/decode.go index 7c2d8416a..a8324897a 100644 --- a/bson/decode.go +++ b/bson/decode.go @@ -28,10 +28,12 @@ package bson import ( + "context" "fmt" "math" "net/url" "reflect" + "runtime" "strconv" "sync" "time" @@ -71,26 +73,38 @@ const ( setterNone setterType setterAddr + setterCtxType + setterCtxAddr ) var setterStyles map[reflect.Type]int var setterIface reflect.Type +var setterCtxIface reflect.Type var setterMutex sync.RWMutex func init() { var iface Setter + var ifaceCtx SetterCtx setterIface = reflect.TypeOf(&iface).Elem() + setterCtxIface = reflect.TypeOf(&ifaceCtx).Elem() setterStyles = make(map[reflect.Type]int) } -func setterStyle(outt reflect.Type) int { +func setterStyle(ctx context.Context, outt reflect.Type) int { + if IsSkipCustom(ctx, outt) { + return setterNone + } setterMutex.RLock() style := setterStyles[outt] setterMutex.RUnlock() if style == setterUnknown { setterMutex.Lock() defer setterMutex.Unlock() - if outt.Implements(setterIface) { + if 
outt.Implements(setterCtxIface) { + setterStyles[outt] = setterCtxType + } else if reflect.PtrTo(outt).Implements(setterCtxIface) { + setterStyles[outt] = setterCtxAddr + } else if outt.Implements(setterIface) { setterStyles[outt] = setterType } else if reflect.PtrTo(outt).Implements(setterIface) { setterStyles[outt] = setterAddr @@ -102,9 +116,9 @@ func setterStyle(outt reflect.Type) int { return style } -func getSetter(outt reflect.Type, out reflect.Value) Setter { - style := setterStyle(outt) - if style == setterNone { +func getSetter(ctx context.Context, outt reflect.Type, out reflect.Value) Setter { + style := setterStyle(ctx, outt) + if style != setterType && style != setterAddr { return nil } if style == setterAddr { @@ -118,6 +132,22 @@ func getSetter(outt reflect.Type, out reflect.Value) Setter { return out.Interface().(Setter) } +func getSetterCtx(ctx context.Context, outt reflect.Type, out reflect.Value) SetterCtx { + style := setterStyle(ctx, outt) + if style != setterCtxType && style != setterCtxAddr { + return nil + } + if style == setterCtxAddr { + if !out.CanAddr() { + return nil + } + out = out.Addr() + } else if outt.Kind() == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + return out.Interface().(SetterCtx) +} + func clearMap(m reflect.Value) { var none reflect.Value for _, k := range m.MapKeys() { @@ -125,7 +155,7 @@ func clearMap(m reflect.Value) { } } -func (d *decoder) readDocTo(out reflect.Value) { +func (d *decoder) readDocTo(ctx context.Context, out reflect.Value) { var elemType reflect.Type outt := out.Type() outk := outt.Kind() @@ -134,9 +164,18 @@ func (d *decoder) readDocTo(out reflect.Value) { if outk == reflect.Ptr && out.IsNil() { out.Set(reflect.New(outt.Elem())) } - if setter := getSetter(outt, out); setter != nil { + if setterCtx := getSetterCtx(ctx, outt, out); setterCtx != nil { + var raw Raw + d.readDocTo(ctx, reflect.ValueOf(&raw)) + err := setterCtx.SetBSONWithContext(ctx, raw) + if _, ok := 
err.(*TypeError); err != nil && !ok { + panic(err) + } + return + } + if setter := getSetter(ctx, outt, out); setter != nil { var raw Raw - d.readDocTo(reflect.ValueOf(&raw)) + d.readDocTo(ctx, reflect.ValueOf(&raw)) err := setter.SetBSON(raw) if _, ok := err.(*TypeError); err != nil && !ok { panic(err) @@ -214,10 +253,10 @@ func (d *decoder) readDocTo(out reflect.Value) { case reflect.Slice: switch outt.Elem() { case typeDocElem: - origout.Set(d.readDocElems(outt)) + origout.Set(d.readDocElems(ctx, outt)) return case typeRawDocElem: - origout.Set(d.readRawDocElems(outt)) + origout.Set(d.readRawDocElems(ctx, outt)) return } fallthrough @@ -240,7 +279,7 @@ func (d *decoder) readDocTo(out reflect.Value) { switch outk { case reflect.Map: e := reflect.New(elemType).Elem() - if d.readElemTo(e, kind) { + if d.readElemTo(ctx, e, kind) { k := reflect.ValueOf(name) if convertKey { k = k.Convert(keyType) @@ -249,24 +288,24 @@ func (d *decoder) readDocTo(out reflect.Value) { } case reflect.Struct: if outt == typeRaw { - d.dropElem(kind) + d.dropElem(ctx, kind) } else { if info, ok := fieldsMap[name]; ok { if info.Inline == nil { - d.readElemTo(out.Field(info.Num), kind) + d.readElemTo(ctx, out.Field(info.Num), kind) } else { - d.readElemTo(out.FieldByIndex(info.Inline), kind) + d.readElemTo(ctx, out.FieldByIndex(info.Inline), kind) } } else if inlineMap.IsValid() { if inlineMap.IsNil() { inlineMap.Set(reflect.MakeMap(inlineMap.Type())) } e := reflect.New(elemType).Elem() - if d.readElemTo(e, kind) { + if d.readElemTo(ctx, e, kind) { inlineMap.SetMapIndex(reflect.ValueOf(name), e) } } else { - d.dropElem(kind) + d.dropElem(ctx, kind) } } case reflect.Slice: @@ -287,7 +326,7 @@ func (d *decoder) readDocTo(out reflect.Value) { } } -func (d *decoder) readArrayDocTo(out reflect.Value) { +func (d *decoder) readArrayDocTo(ctx context.Context, out reflect.Value) { end := int(d.readInt32()) end += d.i - 4 if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { @@ -307,7 +346,7 @@ 
func (d *decoder) readArrayDocTo(out reflect.Value) { corrupted() } d.i++ - d.readElemTo(out.Index(i), kind) + d.readElemTo(ctx, out.Index(i), kind) if d.i >= end { corrupted() } @@ -322,11 +361,11 @@ func (d *decoder) readArrayDocTo(out reflect.Value) { } } -func (d *decoder) readSliceDoc(t reflect.Type) interface{} { +func (d *decoder) readSliceDoc(ctx context.Context, t reflect.Type) interface{} { tmp := make([]reflect.Value, 0, 8) elemType := t.Elem() if elemType == typeRawDocElem { - d.dropElem(0x04) + d.dropElem(ctx, 0x04) return reflect.Zero(t).Interface() } @@ -345,7 +384,7 @@ func (d *decoder) readSliceDoc(t reflect.Type) interface{} { } d.i++ e := reflect.New(elemType).Elem() - if d.readElemTo(e, kind) { + if d.readElemTo(ctx, e, kind) { tmp = append(tmp, e) } if d.i >= end { @@ -368,14 +407,14 @@ func (d *decoder) readSliceDoc(t reflect.Type) interface{} { var typeSlice = reflect.TypeOf([]interface{}{}) var typeIface = typeSlice.Elem() -func (d *decoder) readDocElems(typ reflect.Type) reflect.Value { +func (d *decoder) readDocElems(ctx context.Context, typ reflect.Type) reflect.Value { docType := d.docType d.docType = typ slice := make([]DocElem, 0, 8) d.readDocWith(func(kind byte, name string) { e := DocElem{Name: name} v := reflect.ValueOf(&e.Value) - if d.readElemTo(v.Elem(), kind) { + if d.readElemTo(ctx, v.Elem(), kind) { slice = append(slice, e) } }) @@ -385,14 +424,14 @@ func (d *decoder) readDocElems(typ reflect.Type) reflect.Value { return slicev } -func (d *decoder) readRawDocElems(typ reflect.Type) reflect.Value { +func (d *decoder) readRawDocElems(ctx context.Context, typ reflect.Type) reflect.Value { docType := d.docType d.docType = typ slice := make([]RawDocElem, 0, 8) d.readDocWith(func(kind byte, name string) { e := RawDocElem{Name: name} v := reflect.ValueOf(&e.Value) - if d.readElemTo(v.Elem(), kind) { + if d.readElemTo(ctx, v.Elem(), kind) { slice = append(slice, e) } }) @@ -430,14 +469,14 @@ func (d *decoder) readDocWith(f func(kind 
byte, name string)) { var blackHole = settableValueOf(struct{}{}) -func (d *decoder) dropElem(kind byte) { - d.readElemTo(blackHole, kind) +func (d *decoder) dropElem(ctx context.Context, kind byte) { + d.readElemTo(ctx, blackHole, kind) } // Attempt to decode an element from the document and put it into out. // If the types are not compatible, the returned ok value will be // false and out will be unchanged. -func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { +func (d *decoder) readElemTo(ctx context.Context, out reflect.Value, kind byte) (good bool) { start := d.i @@ -447,25 +486,25 @@ func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { outk := out.Kind() switch outk { case reflect.Interface, reflect.Ptr, reflect.Struct, reflect.Map: - d.readDocTo(out) + d.readDocTo(ctx, out) return true } - if setterStyle(outt) != setterNone { - d.readDocTo(out) + if setterStyle(ctx, outt) != setterNone { + d.readDocTo(ctx, out) return true } if outk == reflect.Slice { switch outt.Elem() { case typeDocElem: - out.Set(d.readDocElems(outt)) + out.Set(d.readDocElems(ctx, outt)) case typeRawDocElem: - out.Set(d.readRawDocElems(outt)) + out.Set(d.readRawDocElems(ctx, outt)) default: - d.readDocTo(blackHole) + d.readDocTo(ctx, blackHole) } return true } - d.readDocTo(blackHole) + d.readDocTo(ctx, blackHole) return true } @@ -480,9 +519,9 @@ func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { panic("Can't happen. Handled above.") case 0x04: // Array outt := out.Type() - if setterStyle(outt) != setterNone { + if setterStyle(ctx, outt) != setterNone { // Skip the value so its data is handed to the setter below. 
- d.dropElem(kind) + d.dropElem(ctx, kind) break } for outt.Kind() == reflect.Ptr { @@ -490,12 +529,12 @@ func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { } switch outt.Kind() { case reflect.Array: - d.readArrayDocTo(out) + d.readArrayDocTo(ctx, out) return true case reflect.Slice: - in = d.readSliceDoc(outt) + in = d.readSliceDoc(ctx, outt) default: - in = d.readSliceDoc(typeSlice) + in = d.readSliceDoc(ctx, typeSlice) } case 0x05: // Binary b := d.readBinary() @@ -531,7 +570,7 @@ func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { case 0x0F: // JavaScript with scope d.i += 4 // Skip length js := JavaScript{d.readStr(), make(M)} - d.readDocTo(reflect.ValueOf(js.Scope)) + d.readDocTo(ctx, reflect.ValueOf(js.Scope)) in = js case 0x10: // Int32 in = int(d.readInt32()) @@ -549,7 +588,9 @@ func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { case 0xFF: // Min key in = MinKey default: - panic(fmt.Sprintf("Unknown element kind (0x%02X)", kind)) + var st []byte = make([]byte, 4096) + w := runtime.Stack(st, false) + panic(fmt.Sprintf("Unknown element kind (0x%02X) BT: %s", kind, string(st[:w]))) } outt := out.Type() @@ -559,7 +600,21 @@ func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { return true } - if setter := getSetter(outt, out); setter != nil { + if setterCtx := getSetterCtx(ctx, outt, out); setterCtx != nil { + err := setterCtx.SetBSONWithContext(ctx, Raw{kind, d.in[start:d.i]}) + if err == SetZero { + out.Set(reflect.Zero(outt)) + return true + } + if err == nil { + return true + } + if _, ok := err.(*TypeError); !ok { + panic(err) + } + return false + } + if setter := getSetter(ctx, outt, out); setter != nil { err := setter.SetBSON(Raw{kind, d.in[start:d.i]}) if err == SetZero { out.Set(reflect.Zero(outt)) diff --git a/bson/encode.go b/bson/encode.go index add39e865..29966018d 100644 --- a/bson/encode.go +++ b/bson/encode.go @@ -28,6 +28,7 @@ package bson import ( + "context" 
"encoding/json" "fmt" "math" @@ -81,15 +82,28 @@ type encoder struct { out []byte } -func (e *encoder) addDoc(v reflect.Value) { +func (e *encoder) addDoc(ctx context.Context, v reflect.Value) { for { - if vi, ok := v.Interface().(Getter); ok { - getv, err := vi.GetBSON() - if err != nil { - panic(err) + //cnt += 1 + if !IsSkipCustom(ctx, v.Type()) { + if vi, ok := v.Interface().(GetterCtx); ok { + getv, err := vi.GetBSONWithContext(ctx) + if err != nil { + panic(err) + } + v = reflect.ValueOf(getv) + ctx = NewContextWithSkipCustom(ctx, getv) + continue + } + if vi, ok := v.Interface().(Getter); ok { + getv, err := vi.GetBSON() + if err != nil { + panic(err) + } + v = reflect.ValueOf(getv) + ctx = NewContextWithSkipCustom(ctx, getv) + continue } - v = reflect.ValueOf(getv) - continue } if v.Kind() == reflect.Ptr { v = v.Elem() @@ -114,11 +128,11 @@ func (e *encoder) addDoc(v reflect.Value) { switch v.Kind() { case reflect.Map: - e.addMap(v) + e.addMap(ctx, v) case reflect.Struct: - e.addStruct(v) + e.addStruct(ctx, v) case reflect.Array, reflect.Slice: - e.addSlice(v) + e.addSlice(ctx, v) default: panic("Can't marshal " + v.Type().String() + " as a BSON document") } @@ -127,13 +141,13 @@ func (e *encoder) addDoc(v reflect.Value) { e.setInt32(start, int32(len(e.out)-start)) } -func (e *encoder) addMap(v reflect.Value) { +func (e *encoder) addMap(ctx context.Context, v reflect.Value) { for _, k := range v.MapKeys() { - e.addElem(k.String(), v.MapIndex(k), false) + e.addElem(ctx, k.String(), v.MapIndex(k), false) } } -func (e *encoder) addStruct(v reflect.Value) { +func (e *encoder) addStruct(ctx context.Context, v reflect.Value) { sinfo, err := getStructInfo(v.Type()) if err != nil { panic(err) @@ -147,7 +161,7 @@ func (e *encoder) addStruct(v reflect.Value) { if _, found := sinfo.FieldsMap[ks]; found { panic(fmt.Sprintf("Can't have key %q in inlined map; conflicts with struct field", ks)) } - e.addElem(ks, m.MapIndex(k), false) + e.addElem(ctx, ks, m.MapIndex(k), 
false) } } } @@ -160,7 +174,7 @@ func (e *encoder) addStruct(v reflect.Value) { if info.OmitEmpty && isZero(value) { continue } - e.addElem(info.Key, value, info.MinSize) + e.addElem(ctx, info.Key, value, info.MinSize) } } @@ -200,17 +214,17 @@ func isZero(v reflect.Value) bool { return false } -func (e *encoder) addSlice(v reflect.Value) { +func (e *encoder) addSlice(ctx context.Context, v reflect.Value) { vi := v.Interface() if d, ok := vi.(D); ok { for _, elem := range d { - e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + e.addElem(ctx, elem.Name, reflect.ValueOf(elem.Value), false) } return } if d, ok := vi.(RawD); ok { for _, elem := range d { - e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + e.addElem(ctx, elem.Name, reflect.ValueOf(elem.Value), false) } return } @@ -219,19 +233,19 @@ func (e *encoder) addSlice(v reflect.Value) { if et == typeDocElem { for i := 0; i < l; i++ { elem := v.Index(i).Interface().(DocElem) - e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + e.addElem(ctx, elem.Name, reflect.ValueOf(elem.Value), false) } return } if et == typeRawDocElem { for i := 0; i < l; i++ { elem := v.Index(i).Interface().(RawDocElem) - e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + e.addElem(ctx, elem.Name, reflect.ValueOf(elem.Value), false) } return } for i := 0; i < l; i++ { - e.addElem(itoa(i), v.Index(i), false) + e.addElem(ctx, itoa(i), v.Index(i), false) } } @@ -244,29 +258,40 @@ func (e *encoder) addElemName(kind byte, name string) { e.addBytes(0) } -func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { - +func (e *encoder) addElem(ctx context.Context, name string, v reflect.Value, minSize bool) { if !v.IsValid() { e.addElemName(0x0A, name) return } - if getter, ok := v.Interface().(Getter); ok { - getv, err := getter.GetBSON() - if err != nil { - panic(err) + if !IsSkipCustom(ctx, v.Type()) { + if getter, ok := v.Interface().(GetterCtx); ok { + getv, err := getter.GetBSONWithContext(ctx) + 
if err != nil { + panic(err) + } + v = reflect.ValueOf(getv) + e.addElem(NewContextWithSkipCustom(ctx, getv), name, v, minSize) + return + } + if getter, ok := v.Interface().(Getter); ok { + getv, err := getter.GetBSON() + if err != nil { + panic(err) + } + v = reflect.ValueOf(getv) + e.addElem(NewContextWithSkipCustom(ctx, getv), name, v, minSize) + return } - e.addElem(name, reflect.ValueOf(getv), minSize) - return } switch v.Kind() { case reflect.Interface: - e.addElem(name, v.Elem(), minSize) + e.addElem(ctx, name, v.Elem(), minSize) case reflect.Ptr: - e.addElem(name, v.Elem(), minSize) + e.addElem(ctx, name, v.Elem(), minSize) case reflect.String: s := v.String() @@ -348,7 +373,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { case reflect.Map: e.addElemName(0x03, name) - e.addDoc(v) + e.addDoc(ctx, v) case reflect.Slice: vt := v.Type() @@ -358,10 +383,10 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { e.addBinary(0x00, v.Bytes()) } else if et == typeDocElem || et == typeRawDocElem { e.addElemName(0x03, name) - e.addDoc(v) + e.addDoc(ctx, v) } else { e.addElemName(0x04, name) - e.addDoc(v) + e.addDoc(ctx, v) } case reflect.Array: @@ -381,7 +406,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { } } else { e.addElemName(0x04, name) - e.addDoc(v) + e.addDoc(ctx, v) } case reflect.Struct: @@ -429,7 +454,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { e.addElemName(0x0F, name) start := e.reserveInt32() e.addStr(s.Code) - e.addDoc(reflect.ValueOf(s.Scope)) + e.addDoc(ctx, reflect.ValueOf(s.Scope)) e.setInt32(start, int32(len(e.out)-start)) } @@ -447,7 +472,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { default: e.addElemName(0x03, name) - e.addDoc(v) + e.addDoc(ctx, v) } default: diff --git a/bulk_test.go b/bulk_test.go index cb280bbfa..7a5dbecc0 100644 --- a/bulk_test.go +++ b/bulk_test.go @@ -317,6 +317,28 @@ func (s *S) 
TestBulkUpdate(c *C) { c.Assert(res, DeepEquals, []doc{{10}, {20}, {30}}) } +func (s *S) TestBulkUpdateOver1000(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + bulk := coll.Bulk() + for i := 0; i < 1010; i++ { + bulk.Insert(M{"n": i}) + } + _, err = bulk.Run() + c.Assert(err, IsNil) + bulk = coll.Bulk() + for i := 0; i < 1010; i++ { + bulk.Update(M{"n": i}, M{"$set": M{"m": i}}) + } + // if not handled well, mongo will return an error here + _, err = bulk.Run() + c.Assert(err, IsNil) +} + func (s *S) TestBulkUpdateError(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) @@ -502,3 +524,26 @@ func (s *S) TestBulkRemoveAll(c *C) { c.Assert(err, IsNil) c.Assert(res, DeepEquals, []doc{{3}}) } + +func (s *S) TestBulkDeleteOver1000(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + bulk := coll.Bulk() + for i := 0; i < 1010; i++ { + bulk.Insert(M{"n": i}) + } + _, err = bulk.Run() + c.Assert(err, IsNil) + bulk = coll.Bulk() + for i := 0; i < 1010; i++ { + bulk.Remove(M{"n": i}) + } + // if not handled well, mongo will return an error here + _, err = bulk.Run() + c.Assert(err, IsNil) +} + diff --git a/dbtest/dbserver.go b/dbtest/dbserver.go index 16b7b5841..29b93e0a7 100644 --- a/dbtest/dbserver.go +++ b/dbtest/dbserver.go @@ -2,17 +2,30 @@ package dbtest import ( "bytes" + "encoding/hex" "fmt" + "log" + "math/rand" "net" "os" "os/exec" "strconv" + "strings" "time" "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" "gopkg.in/tomb.v2" ) + +// Constants to define how the DB test instance should be executed +const ( + // Run MongoDB as local process + LocalProcess = 0 + // Run MongoDB within a docker container + Docker = 1 +) + // DBServer controls a MongoDB server process to be used within test suites. 
// // The test server is started when Session is called the first time and should @@ -23,12 +36,20 @@ import ( // Before the DBServer is used the SetPath method must be called to define // the location for the database files to be stored. type DBServer struct { - session *mgo.Session - output bytes.Buffer - server *exec.Cmd - dbpath string - host string - tomb tomb.Tomb + session *mgo.Session + output bytes.Buffer + server *exec.Cmd + dbpath string + hostPort string // The IP address and port number of the mgo instance. + hostname string // The IP address or hostname of the container. + version string // The request MongoDB version, when running within a container + eType int // Specify whether mongo should run as a container or regular process + network string // The name of the docker network to which the UT container should be attached + exposePort bool // Specify whether container port should be exposed to the host OS. + debug bool // Log debug statements + containerName string // The container name, when running mgo within a container + tomb tomb.Tomb + rsName string // ReplicaSet Name. If not empty- the mongod will be started as a Replica Set Server } // SetPath defines the path to the directory where the database files will be @@ -38,7 +59,241 @@ func (dbs *DBServer) SetPath(dbpath string) { dbs.dbpath = dbpath } +// SetVersion defines the desired MongoDB version to run within a container. +// The attribute is ignored when running MongoDB outside a container. +func (dbs *DBServer) SetVersion(version string) { + dbs.version = version +} + +func (dbs *DBServer) SetDebug(enableDebug bool) { + dbs.debug = enableDebug +} + +// SetExecType specifies if the DB instance should run locally or as a container. +func (dbs *DBServer) SetExecType(execType int) { + dbs.eType = execType +} + +// SetNetwork sets the name of the docker network to which the UT container should be attached. 
+func (dbs *DBServer) SetNetwork(network string) { + dbs.network = network +} + +// SetExposePort sets whether the container port should be exposed to the host OS. +func (dbs *DBServer) SetExposePort(exposePort bool) { + dbs.exposePort = exposePort +} + +// SetReplicaSetName sets whether the mongod is started as a replica set +func (dbs *DBServer) SetReplicaSetName(rsName string) { + dbs.rsName = rsName +} + +// SetContainerName sets the name of the docker container when the DB instance is started within a container. +func (dbs *DBServer) SetContainerName(containerName string) { + dbs.containerName = containerName +} + +func (dbs *DBServer) pullDockerImage(dockerImage string) { + // Check if the docker image exists in the local registry. + args := []string{ + "images", + "-q", + dockerImage, + } + cmd := exec.Command("docker", args...) + err := cmd.Run() + if err == nil { + // The image is already present locally. + // Do not invoke docker pull because: + // 1. Every network operations counts towards the dockerhub API rate limiting. + // 2. Reduce the chance of intermittent network issues. + log.Printf("Docker image '%s' is already present in the local registry", dockerImage) + return + } + + // It may take a long time to download the mongo image if the docker image is not installed. + // Execute 'docker pull' now to pull the image before executing it. Otherwise Dial() may fail + // with a timeout after 10 seconds. + args = []string{ + "pull", + dockerImage, + } + start := time.Now() + var stdout, stderr bytes.Buffer + // Seeing intermittent issues such as: + // Error response from daemon: Get https://registry-1.docker.io/v2/: net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headers) + for time.Since(start) < 60*time.Second { + cmd := exec.Command("docker", args...) 
+ log.Printf("Pulling Mongo docker image %s", dockerImage) + if dbs.debug { + cmd.Stdout = &stdout + cmd.Stderr = &stderr + } + err = cmd.Run() + if err == nil { + break + } else { + log.Printf("Failed to pull Mongo container image. err=%s\n%s\n%s", + err.Error(), stdout.String(), stderr.String()) + time.Sleep(5 * time.Second) + } + } + if err != nil { + panic(err) + } + log.Printf("Pulled Mongo docker image") +} + +// Start Mongo DB within Docker container on host. +// It assumes Docker is already installed. +func (dbs *DBServer) execContainer(network string, exposePort bool) *exec.Cmd { + if dbs.version == "" { + dbs.version = "latest" + } + + dockerImage := fmt.Sprintf("mongo:%s", dbs.version) + dbs.pullDockerImage(dockerImage) + + args := []string{ + "run", + "-t", + "--rm", // Automatically remove the container when it exits + } + if network != "" { + args = append(args, []string{ + "--net", + network, + }...) + } + if exposePort { + args = append(args, []string{"-p", fmt.Sprintf("%d:%d", 27017, 27017)}...) + } + args = append(args, []string{ + "--name", + dbs.containerName, + fmt.Sprintf("mongo:%s", dbs.version), + }...) + + if dbs.rsName != "" { + args = append(args, []string{ + "mongod", + "--replSet", + dbs.rsName, + }...) + } + log.Printf("DB start up arguments are: %v", args) + return exec.Command("docker", args...) +} + +// Returns the host name of the Mongo test instance. +// If the test instance runs as a container, it returns the IP address of the container. +// If the test instance runs in the host, returns the host name. +func (dbs *DBServer) GetHostName() string { + if dbs.eType == Docker { + return dbs.hostname + } else { + if hostname, err := os.Hostname(); err != nil { + return hostname + } else { + return "127.0.0.1" + } + } +} + +// GetContainerName returns the name of the container, when running the Mongo UT instance in a container. 
+func (dbs *DBServer) GetContainerName() string { + return dbs.containerName +} + +// GetContainerIpAddr returns the IP address of the test Mongo instance +// The client should connect directly on the docker bridge network (such as when the client is also +// running in a container), then client should connect to port 27017. +func (dbs *DBServer) GetContainerIpAddr() (string, error) { + start := time.Now() + var err error + var stderr bytes.Buffer + for time.Since(start) < 60*time.Second { + if dbs.server.ProcessState != nil { + // The process has exited + log.Printf("Mongo container has exited unexpectedly. Output:\n%s", dbs.output.String()) + return "", fmt.Errorf("Process has exited") + } + stderr.Reset() + args := []string{"inspect", "-f", "'{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}'", dbs.containerName} + cmd := exec.Command("docker", args...) // #nosec + var out bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &stderr + err = cmd.Run() + if err != nil { + // This could be because the container has not started yet. Retry later + log.Printf("Failed to get container IP address. Will retry later. Err: %v", err) + time.Sleep(3 * time.Second) + continue + } + ipAddr := strings.Trim(strings.TrimSpace(out.String()), "'") + dbs.hostname = ipAddr + log.Printf("Mongo IP address is %v", dbs.hostname) + if dbs.network == "" { + return "127.0.0.1", nil + } else { + return dbs.hostname, err + } + } + log.Printf("Unable to get container IP address: %v", err) + return "", fmt.Errorf("Failed to run command. error=%s, stderr=%s\n", err.Error(), stderr.String()) +} + +// Stop the docker container running Mongo. +func (dbs *DBServer) stopContainer() { + args := []string{ + "stop", + dbs.containerName, + } + cmd := exec.Command("docker", args...) + if dbs.debug { + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + } + err := cmd.Run() + if err != nil { + panic(err) + } + // Remove the container and its unnamed volume. 
+ // In some cases the "docker run --rm" option does not remove the container. + args = []string{ + "rm", + "-v", + dbs.containerName, + } + cmd = exec.Command("docker", args...) + if dbs.debug { + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + } + err = cmd.Run() +} + +// Start Mongo DB as process on host. It assumes Mongo is already installed +func (dbs *DBServer) execLocal(port int) *exec.Cmd { + args := []string{ + "--dbpath", dbs.dbpath, + "--bind_ip", "127.0.0.1", + "--port", strconv.Itoa(port), + } + + if dbs.rsName != "" { + args = append(args, []string{ + "--replSet", + dbs.rsName, + }...) + } + return exec.Command("mongod", args...) +} + func (dbs *DBServer) start() { + log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Llongfile) if dbs.server != nil { panic("DBServer already started") } @@ -52,41 +307,99 @@ func (dbs *DBServer) start() { } addr := l.Addr().(*net.TCPAddr) l.Close() - dbs.host = addr.String() - args := []string{ - "--dbpath", dbs.dbpath, - "--bind_ip", "127.0.0.1", - "--port", strconv.Itoa(addr.Port), - "--nssize", "1", - "--noprealloc", - "--smallfiles", - "--nojournal", + if dbs.containerName == "" { + // Generate a name for the container. This will help to inspect the container + // and get the Mongo PID. + u := make([]byte, 8) + // The default number generator is deterministic. + s := rand.NewSource(time.Now().UnixNano()) + r := rand.New(s) + _, err = r.Read(u) + if err != nil { + panic(err) + } + dbs.containerName = fmt.Sprintf("mongo-%s", hex.EncodeToString(u)) } + dbs.tomb = tomb.Tomb{} - dbs.server = exec.Command("mongod", args...) + switch dbs.eType { + case LocalProcess: + dbs.hostPort = addr.String() + dbs.server = dbs.execLocal(addr.Port) + case Docker: + dbs.server = dbs.execContainer(dbs.network, dbs.exposePort) + default: + panic(fmt.Sprintf("unsupported exec type: %d", dbs.eType)) + } dbs.server.Stdout = &dbs.output dbs.server.Stderr = &dbs.output + log.Printf("Starting Mongo instance: %v. Address: %s. 
Network: '%s'", dbs.server.Args, dbs.hostPort, dbs.network) err = dbs.server.Start() if err != nil { - panic(err) + panic("Failed to start Mongo instance: " + err.Error()) + } + log.Printf("Mongo instance started") + go func() { + // Call Wait() so cmd.ProcessState is set after command has completed. + err = dbs.server.Wait() + if err != nil { + log.Printf("Command exited. Output:\n%s", dbs.output.String()) + } + }() + if dbs.eType == Docker { + ipAddr, err2 := dbs.GetContainerIpAddr() + if err2 != nil { + panic(err2) + } + dbs.hostPort = fmt.Sprintf("%s:%d", ipAddr, 27017) } dbs.tomb.Go(dbs.monitor) dbs.Wipe() } +func (dbs DBServer) printMongoDebugInfo() { + fmt.Fprintf(os.Stderr, "[%s] mongod processes running right now:\n", time.Now().String()) + cmd := exec.Command("/bin/sh", "-c", "ps auxw | grep mongod") + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + cmd.Run() + if dbs.eType == Docker { + fmt.Fprintf(os.Stderr, "[%s] mongod containers running right now:\n", time.Now().String()) + cmd := exec.Command("/bin/sh", "-c", "docker ps -a |grep mongo") + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + cmd.Run() + + args := []string{ + "inspect", + dbs.containerName, + } + cmd = exec.Command("docker", args...) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + fmt.Fprintf(os.Stderr, "[%s] Container inspect:\n", time.Now().String()) + cmd.Run() + args = []string{ + "logs", + dbs.containerName, + } + cmd = exec.Command("docker", args...) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + fmt.Fprintf(os.Stderr, "[%s] Container logs:\n", time.Now().String()) + cmd.Run() + } + fmt.Fprintf(os.Stderr, "----------------------------------------\n") +} + func (dbs *DBServer) monitor() error { dbs.server.Process.Wait() if dbs.tomb.Alive() { // Present some debugging information. 
fmt.Fprintf(os.Stderr, "---- mongod process died unexpectedly:\n") fmt.Fprintf(os.Stderr, "%s", dbs.output.Bytes()) - fmt.Fprintf(os.Stderr, "---- mongod processes running right now:\n") - cmd := exec.Command("/bin/sh", "-c", "ps auxw | grep mongod") - cmd.Stdout = os.Stderr - cmd.Stderr = os.Stderr - cmd.Run() - fmt.Fprintf(os.Stderr, "----------------------------------------\n") + dbs.printMongoDebugInfo() panic("mongod process died unexpectedly") } @@ -111,6 +424,10 @@ func (dbs *DBServer) Stop() { } if dbs.server != nil { dbs.tomb.Kill(nil) + if dbs.eType == Docker { + // Invoke 'docker stop' + dbs.stopContainer() + } dbs.server.Process.Signal(os.Interrupt) select { case <-dbs.tomb.Dead(): @@ -125,21 +442,53 @@ func (dbs *DBServer) Stop() { // must be closed after the test is done with it. // // The first Session obtained from a DBServer will start it. -func (dbs *DBServer) Session() *mgo.Session { +func (dbs *DBServer) SessionWithTimeout(timeout time.Duration) *mgo.Session { if dbs.server == nil { dbs.start() } if dbs.session == nil { mgo.ResetStats() var err error - dbs.session, err = mgo.Dial(dbs.host + "/test") + log.Printf("Dialing mongod located at '%s'. timeout: %v", dbs.hostPort, timeout) + // If The MongoDB Driver is Configured with ReplicaSet - (Then configure Replica Set first!) + // Directly Connect First - and then configure Replica Set! + if dbs.rsName != "" { + dbs.session, err = mgo.DialWithTimeout(dbs.hostPort+"/test?connect=direct", timeout) + if err != nil { + log.Printf("Unable to dial mongod located at '%s'. Timeout=%v, Error: %s", dbs.hostPort, timeout, err.Error()) + log.Printf("%s", dbs.output.Bytes()) + dbs.printMongoDebugInfo() + panic(err) + } + err = dbs.Initiate() + if err != nil { + log.Printf("Unable to Configure Replica Set '%s'. 
Timeout=%v, Error: %s", dbs.rsName, timeout, err.Error()) + log.Printf("%s", dbs.output.Bytes()) + dbs.printMongoDebugInfo() + panic(err) + } + dbs.session.Close() // Create a new One Below without Direct=true... + + } + dbs.session, err = mgo.DialWithTimeout(dbs.hostPort+"/test", timeout) if err != nil { + log.Printf("Unable to dial mongod located at '%s'. Timeout=%v, Error: %s", dbs.hostPort, timeout, err.Error()) + log.Printf("%s", dbs.output.Bytes()) + dbs.printMongoDebugInfo() panic(err) } } return dbs.session.Copy() } +// Session returns a new session to the server. The returned session +// must be closed after the test is done with it. +// +// The first Session obtained from a DBServer will start it. +func (dbs *DBServer) Session() *mgo.Session { + return dbs.SessionWithTimeout(10 * time.Second) +} + // checkSessions ensures all mgo sessions opened were properly closed. // For slightly faster tests, it may be disabled setting the // environmnet variable CHECK_SESSIONS to 0. @@ -156,7 +505,9 @@ func (dbs *DBServer) checkSessions() { } time.Sleep(100 * time.Millisecond) } - panic("There are mgo sessions still alive.") + stats := mgo.GetStats() + panic(fmt.Sprintf("There are mgo sessions still alive. SocketCount=%d InUseCount=%d", + stats.SocketsAlive, stats.SocketsInUse)) } // Wipe drops all created databases and their data. @@ -169,6 +520,7 @@ func (dbs *DBServer) checkSessions() { // there is a session leak. 
func (dbs *DBServer) Wipe() { if dbs.server == nil || dbs.session == nil { + log.Printf("Skip Wipe()") return } dbs.checkSessions() @@ -187,6 +539,7 @@ func (dbs *DBServer) Wipe() { switch name { case "admin", "local", "config": default: + log.Printf("Drop database '%s'", name) err = session.DB(name).DropDatabase() if err != nil { panic(err) @@ -194,3 +547,176 @@ func (dbs *DBServer) Wipe() { } } } + +// Code From https://github.com/juju/replicaset To Start MongoDB in replica Set Configuration +// Config is the document stored in mongodb that defines the servers in the +// replica set +type Config struct { + Name string `bson:"_id"` + Version int `bson:"version"` + ProtocolVersion int `bson:"protocolVersion,omitempty"` + Members []Member `bson:"members"` +} + +// Member holds configuration information for a replica set member. +// +// See http://docs.mongodb.org/manual/reference/replica-configuration/ +// for more details +type Member struct { + // Id is a unique id for a member in a set. + Id int `bson:"_id"` + + // Address holds the network address of the member, + // in the form hostname:port. + Address string `bson:"host"` +} + +func (dbs *DBServer) Initiate() error { + monotonicSession := dbs.session.Clone() + defer monotonicSession.Close() + monotonicSession.SetMode(mgo.Monotonic, true) + protocolVersion := 1 + var err error + // We don't know mongod's ability to use a correct IPv6 addr format + // until the server is started, but we need to know before we can start + // it. Try the older, incorrect format, if the correct format fails. + cfg := []Config{ + { + Name: dbs.rsName, + Version: 1, + ProtocolVersion: protocolVersion, + Members: []Member{{ + Id: 0, + Address: dbs.hostPort, + }}, + }, + } + + // Attempt replSetInitiate, with potential retries. 
+ for i := 0; i < 5; i++ { + monotonicSession.Refresh() + if err = doAttemptInitiate(monotonicSession, cfg); err != nil { + time.Sleep(100 * time.Millisecond) + continue + } + break + } + + // Wait for replSetInitiate to complete. Even if err != nil, + // it may be that replSetInitiate is still in progress, so + // attempt CurrentStatus. + for i := 0; i < 10; i++ { + monotonicSession.Refresh() + var status *Status + status, err = getCurrentStatus(monotonicSession) + if err != nil { + log.Printf("Initiate: fetching replication status failed: %v", err) + } + if err != nil || len(status.Members) == 0 { + time.Sleep(500 * time.Millisecond) + continue + } + break + } + return err +} + +// CurrentStatus returns the status of the replica set for the given session. +func getCurrentStatus(session *mgo.Session) (*Status, error) { + status := &Status{} + err := session.Run("replSetGetStatus", status) + if err != nil { + return nil, fmt.Errorf("cannot get replica set status: %v", err) + } + + for index, member := range status.Members { + status.Members[index].Address = member.Address + } + return status, nil +} + +// Status holds data about the status of members of the replica set returned +// from replSetGetStatus +// +// See http://docs.mongodb.org/manual/reference/command/replSetGetStatus/#dbcmd.replSetGetStatus +type Status struct { + Name string `bson:"set"` + Members []MemberStatus `bson:"members"` +} + +// Status holds the status of a replica set member returned from +// replSetGetStatus. +type MemberStatus struct { + // Id holds the replica set id of the member that the status is describing. + Id int `bson:"_id"` + + // Address holds address of the member that the status is describing. + Address string `bson:"name"` + + // Self holds whether this is the status for the member that + // the session is connected to. + Self bool `bson:"self"` + + // ErrMsg holds the most recent error or status message received + // from the member. 
+ ErrMsg string `bson:"errmsg"` + + // Healthy reports whether the member is up. It is true for the + // member that the request was made to. + Healthy bool `bson:"health"` + + // State describes the current state of the member. + State MemberState `bson:"state"` +} + +// doAttemptInitiate will attempt to initiate a mongodb replicaset with each of +// the given configs, returning as soon as one config is successful. +func doAttemptInitiate(monotonicSession *mgo.Session, cfg []Config) error { + var err error + for _, c := range cfg { + if err = monotonicSession.Run(bson.D{{"replSetInitiate", c}}, nil); err != nil { + log.Printf("Unsuccessful attempt to initiate replicaset: %v", err) + continue + } + return nil + } + return err +} + +type MemberState int + +const ( + StartupState = iota + PrimaryState + SecondaryState + RecoveringState + FatalState + Startup2State + UnknownState + ArbiterState + DownState + RollbackState + ShunnedState +) + +var memberStateStrings = []string{ + StartupState: "STARTUP", + PrimaryState: "PRIMARY", + SecondaryState: "SECONDARY", + RecoveringState: "RECOVERING", + FatalState: "FATAL", + Startup2State: "STARTUP2", + UnknownState: "UNKNOWN", + ArbiterState: "ARBITER", + DownState: "DOWN", + RollbackState: "ROLLBACK", + ShunnedState: "SHUNNED", +} + +// String returns a string describing the state. 
+func (state MemberState) String() string { + if state < 0 || int(state) >= len(memberStateStrings) { + return "INVALID_MEMBER_STATE" + } + return memberStateStrings[state] +} diff --git a/dbtest/dbserver_test.go b/dbtest/dbserver_test.go index 79812fde3..67d4b186f 100644 --- a/dbtest/dbserver_test.go +++ b/dbtest/dbserver_test.go @@ -32,6 +32,23 @@ func (s *S) TearDownTest(c *C) { os.Setenv("CHECK_SESSIONS", s.oldCheckSessions) } +func (s *S) TestRunAsDocker(c *C) { + var server dbtest.DBServer + server.SetPath(c.MkDir()) + server.SetVersion("3.4") + server.SetExecType(dbtest.Docker) + defer server.Stop() + + session := server.Session() + err := session.DB("mydb").C("mycoll").Insert(M{"a": 1}) + buildInfo, err := session.BuildInfo() + c.Assert(err, IsNil) + c.Assert(buildInfo.VersionAtLeast(3, 4), Equals, true) + + session.Close() + c.Assert(err, IsNil) +} + func (s *S) TestWipeData(c *C) { var server dbtest.DBServer server.SetPath(c.MkDir()) diff --git a/session.go b/session.go index 3dccf364e..0220debc9 100644 --- a/session.go +++ b/session.go @@ -1011,7 +1011,7 @@ type indexSpec struct { DefaultLanguage string "default_language,omitempty" LanguageOverride string "language_override,omitempty" TextIndexVersion int "textIndexVersion,omitempty" - + PartialFilterExpression bson.M "partialFilterExpression,omitempty" Collation *Collation "collation,omitempty" } @@ -1021,6 +1021,7 @@ type Index struct { DropDups bool // Drop documents with the same index key as a previously indexed one Background bool // Build index in background and return immediately Sparse bool // Only index documents containing the Key fields + PartialFilterExpression bson.M //If specified, the index only references documents that match the filter expression // If ExpireAfter is defined the server will periodically delete // documents with indexed time.Time older than the provided delta. 
@@ -1284,6 +1285,7 @@ func (c *Collection) EnsureIndex(index Index) error { DropDups: index.DropDups, Background: index.Background, Sparse: index.Sparse, + PartialFilterExpression: index.PartialFilterExpression, Bits: index.Bits, Min: index.Minf, Max: index.Maxf, @@ -1494,20 +1496,21 @@ func (c *Collection) Indexes() (indexes []Index, err error) { func indexFromSpec(spec indexSpec) Index { index := Index{ - Name: spec.Name, - Key: simpleIndexKey(spec.Key), - Unique: spec.Unique, - DropDups: spec.DropDups, - Background: spec.Background, - Sparse: spec.Sparse, - Minf: spec.Min, - Maxf: spec.Max, - Bits: spec.Bits, - BucketSize: spec.BucketSize, - DefaultLanguage: spec.DefaultLanguage, - LanguageOverride: spec.LanguageOverride, - ExpireAfter: time.Duration(spec.ExpireAfter) * time.Second, - Collation: spec.Collation, + Name: spec.Name, + Key: simpleIndexKey(spec.Key), + Unique: spec.Unique, + DropDups: spec.DropDups, + Background: spec.Background, + Sparse: spec.Sparse, + Minf: spec.Min, + Maxf: spec.Max, + Bits: spec.Bits, + BucketSize: spec.BucketSize, + DefaultLanguage: spec.DefaultLanguage, + LanguageOverride: spec.LanguageOverride, + ExpireAfter: time.Duration(spec.ExpireAfter) * time.Second, + Collation: spec.Collation, + PartialFilterExpression: spec.PartialFilterExpression, } if float64(int(spec.Min)) == spec.Min && float64(int(spec.Max)) == spec.Max { index.Min = int(spec.Min) @@ -2172,6 +2175,8 @@ type Pipe struct { pipeline interface{} allowDisk bool batchSize int + maxTimeMS int64 + collation *Collation } type pipeCmd struct { @@ -2180,6 +2185,8 @@ type pipeCmd struct { Cursor *pipeCmdCursor ",omitempty" Explain bool ",omitempty" AllowDisk bool "allowDiskUse,omitempty" + MaxTimeMS int64 `bson:"maxTimeMS,omitempty"` + Collation *Collation `bson:"collation,omitempty"` } type pipeCmdCursor struct { @@ -2233,6 +2240,10 @@ func (p *Pipe) Iter() *Iter { Pipeline: p.pipeline, AllowDisk: p.allowDisk, Cursor: &pipeCmdCursor{p.batchSize}, + Collation: p.collation, + 
} + if p.maxTimeMS > 0 { + cmd.MaxTimeMS = p.maxTimeMS } err := c.Database.Run(cmd, &result) if e, ok := err.(*QueryError); ok && e.Message == `unrecognized field "cursor` { @@ -2371,6 +2382,30 @@ func (p *Pipe) Batch(n int) *Pipe { return p } +// SetMaxTime sets the maximum amount of time to allow the query to run. +// +func (p *Pipe) SetMaxTime(d time.Duration) *Pipe { + p.maxTimeMS = int64(d / time.Millisecond) + return p +} + + +// Collation allows to specify language-specific rules for string comparison, +// such as rules for lettercase and accent marks. +// When specifying collation, the locale field is mandatory; all other collation +// fields are optional +// +// Relevant documentation: +// +// https://docs.mongodb.com/manual/reference/collation/ +// +func (p *Pipe) Collation(collation *Collation) *Pipe { + if collation != nil { + p.collation = collation + } + return p +} + // mgo.v3: Use a single user-visible error type. type LastError struct { @@ -2858,6 +2893,38 @@ func (q *Query) Sort(fields ...string) *Query { return q } +// Collation allows to specify language-specific rules for string comparison, +// such as rules for lettercase and accent marks. 
+// When specifying collation, the locale field is mandatory; all other collation +// fields are optional +// +// For example, to perform a case and diacritic insensitive query: +// +// var res []bson.M +// collation := &mgo.Collation{Locale: "en", Strength: 1} +// err = db.C("mycoll").Find(bson.M{"a": "a"}).Collation(collation).All(&res) +// if err != nil { +// return err +// } +// +// This query will match following documents: +// +// {"a": "a"} +// {"a": "A"} +// {"a": "รข"} +// +// Relevant documentation: +// +// https://docs.mongodb.com/manual/reference/collation/ +// +func (q *Query) Collation(collation *Collation) *Query { + q.m.Lock() + q.op.options.Collation = collation + q.op.hasOptions = true + q.m.Unlock() + return q +} + // Explain returns a number of details about how the MongoDB server would // execute the requested query, such as the number of objects examined, // the number of times the read lock was yielded to allow writes to go in, @@ -3155,6 +3222,7 @@ func prepareFindOp(socket *mongoSocket, op *queryOp, limit int32) bool { Sort: op.options.OrderBy, Skip: op.skip, Limit: limit, + Collation: op.options.Collation, MaxTimeMS: op.options.MaxTimeMS, MaxScan: op.options.MaxScan, Hint: op.options.Hint, @@ -3222,6 +3290,7 @@ type findCmd struct { OplogReplay bool `bson:"oplogReplay,omitempty"` NoCursorTimeout bool `bson:"noCursorTimeout,omitempty"` AllowPartialResults bool `bson:"allowPartialResults,omitempty"` + Collation *Collation `bson:"collation,omitempty"` } // getMoreCmd holds the command used for requesting more query results on MongoDB 3.2+. @@ -4613,6 +4682,7 @@ func (c *Collection) writeOp(op interface{}, ordered bool) (lerr *LastError, err if socket.ServerInfo().MaxWireVersion >= 2 { // Servers with a more recent write protocol benefit from write commands. 
+ inputOp := op if op, ok := op.(*insertOp); ok && len(op.documents) > 1000 { var lerr LastError @@ -4641,6 +4711,54 @@ func (c *Collection) writeOp(op interface{}, ordered bool) (lerr *LastError, err return &lerr, lerr.ecases[0].Err } return &lerr, nil + } else if updateOps, ok := inputOp.(bulkUpdateOp); ok && len(updateOps) > 1000 { + var lerr LastError + // Maximum batch size is 1000. Must split out in separate operations for compatibility. + all := updateOps + for i := 0; i < len(all); i += 1000 { + l := i + 1000 + if l > len(all) { + l = len(all) + } + updateOps = all[i:l] + oplerr, err := c.writeOpCommand(socket, safeOp, updateOps, ordered, bypassValidation) + lerr.N += oplerr.N + lerr.modified += oplerr.modified + if err != nil { + for ei := range oplerr.ecases { + oplerr.ecases[ei].Index += i + } + lerr.ecases = append(lerr.ecases, oplerr.ecases...) + } + } + if len(lerr.ecases) != 0 { + return &lerr, lerr.ecases[0].Err + } + return &lerr, nil + } else if deleteOps, ok := inputOp.(bulkDeleteOp); ok && len(deleteOps) > 1000 { + var lerr LastError + // Maximum batch size is 1000. Must split out in separate operations for compatibility. + all := deleteOps + for i := 0; i < len(all); i += 1000 { + l := i + 1000 + if l > len(all) { + l = len(all) + } + deleteOps = all[i:l] + oplerr, err := c.writeOpCommand(socket, safeOp, deleteOps, ordered, bypassValidation) + lerr.N += oplerr.N + lerr.modified += oplerr.modified + if err != nil { + for ei := range oplerr.ecases { + oplerr.ecases[ei].Index += i + } + lerr.ecases = append(lerr.ecases, oplerr.ecases...) 
+ } + } + if len(lerr.ecases) != 0 { + return &lerr, lerr.ecases[0].Err + } + return &lerr, nil } return c.writeOpCommand(socket, safeOp, op, ordered, bypassValidation) } else if updateOps, ok := op.(bulkUpdateOp); ok { diff --git a/session_test.go b/session_test.go index 492f21078..9c34a59b9 100644 --- a/session_test.go +++ b/session_test.go @@ -2975,6 +2975,21 @@ var indexTests = []struct { "key": M{"cn": 1}, "ns": "mydb.mycoll", }, +}, { + mgo.Index{ + Key: []string{"p"}, + Unique: true, + PartialFilterExpression: bson.M{ + "p": bson.M{"$exists": true}, + }, + }, + M{ + "name": "p_1", + "key": M{"p": 1}, + "ns": "mydb.mycoll", + "unique": true, + "partialFilterExpression": M{"p" : M{"$exists" : true}}, + }, }} func (s *S) TestEnsureIndex(c *C) { diff --git a/socket.go b/socket.go index 8891dd5d7..13ca361f0 100644 --- a/socket.go +++ b/socket.go @@ -91,6 +91,7 @@ type queryWrapper struct { MaxScan int "$maxScan,omitempty" MaxTimeMS int "$maxTimeMS,omitempty" Comment string "$comment,omitempty" + Collation *Collation `bson:"$collation,omitempty"` } func (op *queryOp) finalQuery(socket *mongoSocket) interface{} {