// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // License. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // Package deep provides a deep equality check for use in tests. package deep // import "github.com/influxdata/influxdb/pkg/deep" import ( "fmt" "math" "reflect" ) // Equal is a copy of reflect.DeepEqual except that it treats NaN == NaN as true. 
// Equal reports whether a1 and a2 are deeply equal. It follows the rules of
// reflect.DeepEqual except that NaN floating-point values (including the
// real and imaginary parts of complex values) compare equal to each other.
//
// An untyped nil argument is equal only to another untyped nil, and values
// whose dynamic types differ are never equal.
func Equal(a1, a2 interface{}) bool {
	if a1 == nil || a2 == nil {
		return a1 == a2
	}
	v1 := reflect.ValueOf(a1)
	v2 := reflect.ValueOf(a2)
	if v1.Type() != v2.Type() {
		return false
	}
	return deepValueEqual(v1, v2, make(map[visit]bool), 0)
}

// deepValueEqual tests for deep equality using reflected values. The visited
// map tracks comparisons that have already been started, which allows short
// circuiting on recursive types. depth is the current recursion depth and is
// only carried along as a debugging aid.
//
// Two invalid (zero) Values are equal; an invalid Value never equals a valid
// one. Values of different types are never equal.
func deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
	if !v1.IsValid() || !v2.IsValid() {
		return v1.IsValid() == v2.IsValid()
	}
	if v1.Type() != v2.Type() {
		return false
	}

	// if depth > 10 { panic("deepValueEqual") }	// for debugging

	// hard reports whether a kind is a composite whose comparison recurses
	// and therefore needs cycle tracking.
	hard := func(k reflect.Kind) bool {
		switch k {
		case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
			return true
		}
		return false
	}

	if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
		addr1 := v1.UnsafeAddr()
		addr2 := v2.UnsafeAddr()
		if addr1 > addr2 {
			// Canonicalize order to reduce number of entries in visited.
			addr1, addr2 = addr2, addr1
		}

		// Short circuit if references are identical ...
		if addr1 == addr2 {
			return true
		}

		// ... or already seen.
		typ := v1.Type()
		v := visit{addr1, addr2, typ}
		if visited[v] {
			return true
		}

		// Remember for later.
		visited[v] = true
	}

	switch v1.Kind() {
	case reflect.Array:
		for i := 0; i < v1.Len(); i++ {
			if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Slice:
		// A nil slice and an empty non-nil slice are not equal.
		if v1.IsNil() != v2.IsNil() {
			return false
		}
		if v1.Len() != v2.Len() {
			return false
		}
		// Same backing array and same length: nothing more to compare.
		if v1.Pointer() == v2.Pointer() {
			return true
		}
		for i := 0; i < v1.Len(); i++ {
			if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Interface:
		if v1.IsNil() || v2.IsNil() {
			return v1.IsNil() == v2.IsNil()
		}
		return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
	case reflect.Ptr:
		// Elem of a nil pointer is the invalid Value, which the validity
		// check at the top of the function handles.
		return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
	case reflect.Struct:
		for i, n := 0, v1.NumField(); i < n; i++ {
			if !deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Map:
		// A nil map and an empty non-nil map are not equal.
		if v1.IsNil() != v2.IsNil() {
			return false
		}
		if v1.Len() != v2.Len() {
			return false
		}
		if v1.Pointer() == v2.Pointer() {
			return true
		}
		for _, k := range v1.MapKeys() {
			// A key missing from v2 yields the invalid Value, which the
			// recursive call reports as unequal to v1's valid entry.
			if !deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Func:
		if v1.IsNil() && v2.IsNil() {
			return true
		}
		// Can't do better than this:
		return false
	case reflect.String:
		return v1.String() == v2.String()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v1.Int() == v2.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return v1.Uint() == v2.Uint()
	case reflect.Float32, reflect.Float64:
		// Special handling for floats so that NaN == NaN is true.
		return floatEqualNaN(v1.Float(), v2.Float())
	case reflect.Complex64, reflect.Complex128:
		// Apply the same NaN rule to each component of a complex value.
		c1, c2 := v1.Complex(), v2.Complex()
		return floatEqualNaN(real(c1), real(c2)) && floatEqualNaN(imag(c1), imag(c2))
	case reflect.Chan, reflect.UnsafePointer:
		// Channels and unsafe pointers are equal only when they are the
		// same reference, matching the == operator.
		return v1.Pointer() == v2.Pointer()
	case reflect.Bool:
		return v1.Bool() == v2.Bool()
	default:
		// Every valid reflect.Kind is handled above; reaching this point
		// means a new kind was added to the language.
		panic(fmt.Sprintf("cannot compare type: %s", v1.Kind().String()))
	}
}

// floatEqualNaN reports whether f1 and f2 are equal, treating two NaN values
// as equal to each other.
func floatEqualNaN(f1, f2 float64) bool {
	if math.IsNaN(f1) && math.IsNaN(f2) {
		return true
	}
	return f1 == f2
}

// During deepValueEqual, must keep track of checks that are
// in progress. The comparison algorithm assumes that all
// checks in progress are true when it reencounters them.
// Visited comparisons are stored in a map indexed by visit.
type visit struct {
	a1  uintptr      // lower of the two addresses being compared
	a2  uintptr      // higher of the two addresses being compared
	typ reflect.Type // type shared by both values
}