diff --git a/Source/MLX/DType.swift b/Source/MLX/DType.swift index 8912d3ca..3aa0534c 100644 --- a/Source/MLX/DType.swift +++ b/Source/MLX/DType.swift @@ -147,9 +147,9 @@ public enum DType: Hashable, Sendable, CaseIterable { #else case .float16: 0x1p-10 // 2^-10 #endif - case .float32: Double(Float.ulpOfOne) + case .float32, .complex64: Double(Float.ulpOfOne) case .bfloat16: 0x1p-7 // 2^-7 (7 mantissa bits) - case .complex64, .float64: Double.ulpOfOne + case .float64: Double.ulpOfOne default: fatalError("\(dtype) is not a floating point type") } } @@ -166,8 +166,8 @@ public enum DType: Hashable, Sendable, CaseIterable { case .float16: 0x1.FFCp15 // 65504 #endif case .bfloat16: Double(Float(bitPattern: 0x7F7F_0000)) // bf16 = high 16 bits of f32 - case .float32: Double(Float.greatestFiniteMagnitude) - case .complex64, .float64: Double.greatestFiniteMagnitude + case .float32, .complex64: Double(Float.greatestFiniteMagnitude) + case .float64: Double.greatestFiniteMagnitude default: fatalError("\(dtype) is not a floating point type") } } @@ -180,9 +180,9 @@ public enum DType: Hashable, Sendable, CaseIterable { #else case .float16: 0x1p-14 // 2^-14 #endif - case .float32: Double(Float.leastNormalMagnitude) + case .float32, .complex64: Double(Float.leastNormalMagnitude) case .bfloat16: 0x1p-126 // 2^-126 (same exponent range as f32) - case .complex64, .float64: Double.leastNormalMagnitude + case .float64: Double.leastNormalMagnitude default: fatalError("\(dtype) is not a floating point type") } } @@ -195,9 +195,9 @@ public enum DType: Hashable, Sendable, CaseIterable { #else case .float16: 0x1p-24 // 2^-24 #endif - case .float32: Double(Float.leastNonzeroMagnitude) + case .float32, .complex64: Double(Float.leastNonzeroMagnitude) case .bfloat16: 0x1p-133 // 2^-133 = 2^-126 ยท 2^-7 - case .complex64, .float64: Double.leastNonzeroMagnitude + case .float64: Double.leastNonzeroMagnitude default: fatalError("\(dtype) is not a floating point type") } }