@@ -1054,4 +1054,91 @@ array turbo_encode_v(const array& values, StreamOrDevice s_) {
10541054 return array (buf.data (), out_shape, uint8);
10551055}
10561056
1057+
1058+ // ── TurboQuant Decode ─────────────────────────────────────────────────────────
1059+ // Batch-decode packed uint8 compressed history back to float32 tensors.
1060+ // Used by KVCacheSimple when routing compressed history through standard SDPA.
1061+ // Supports head_dim=128 (record=68B K / 50B V) and head_dim=256 (2 sub-groups).
1062+
1063+ array turbo_decode_k (const array& packed, StreamOrDevice s_) {
1064+ auto s = to_stream (s_);
1065+
1066+ const int record_bytes = static_cast <int >(packed.shape (-1 ));
1067+ if (record_bytes != TURBO_K_RECORD && record_bytes != TURBO_K_RECORD * 2 ) {
1068+ throw std::invalid_argument (
1069+ " [turbo_decode_k] last dim must be 68 (D=128) or 136 (D=256), got " +
1070+ std::to_string (record_bytes));
1071+ }
1072+ const int n_subgroups = record_bytes / TURBO_K_RECORD ;
1073+ const int head_dim = n_subgroups * ::mlx::core::fast::TURBO_D ;
1074+
1075+ // Materialise packed buffer on CPU
1076+ auto packed_u8 = astype (packed, uint8, s);
1077+ eval (packed_u8);
1078+ const uint8_t * src = packed_u8.data <uint8_t >();
1079+
1080+ const int N = static_cast <int >(packed_u8.size () / record_bytes);
1081+ std::vector<float > buf (static_cast <size_t >(N) * head_dim);
1082+
1083+ for (int i = 0 ; i < N; ++i) {
1084+ for (int g = 0 ; g < n_subgroups; ++g) {
1085+ const uint8_t * sub_src = src + i * record_bytes + g * TURBO_K_RECORD ;
1086+ ::mlx::core::fast::TurboQuantK rec;
1087+ std::memset (&rec, 0 , sizeof (rec));
1088+ std::memcpy (rec.indices , sub_src, 48 );
1089+ std::memcpy (rec.qjl_signs , sub_src + 48 , 16 );
1090+ std::memcpy (&rec.norm_fp16 , sub_src + 64 , 2 );
1091+ std::memcpy (&rec.rnorm_fp16 , sub_src + 66 , 2 );
1092+ ::mlx::core::fast::turbo_dequantize_k (
1093+ rec,
1094+ buf.data() + i * head_dim + g * ::mlx::core::fast::TURBO_D,
1095+ ::mlx::core::fast::TURBO_D);
1096+ }
1097+ }
1098+
1099+ Shape out_shape = packed.shape ();
1100+ out_shape.back () = head_dim;
1101+ // Return float32; Swift caller casts to model dtype (fp16/bf16) as needed
1102+ return array (buf.data (), out_shape, float32);
1103+ }
1104+
1105+ array turbo_decode_v (const array& packed, StreamOrDevice s_) {
1106+ auto s = to_stream (s_);
1107+
1108+ const int record_bytes = static_cast <int >(packed.shape (-1 ));
1109+ if (record_bytes != TURBO_V_RECORD && record_bytes != TURBO_V_RECORD * 2 ) {
1110+ throw std::invalid_argument (
1111+ " [turbo_decode_v] last dim must be 50 (D=128) or 100 (D=256), got " +
1112+ std::to_string (record_bytes));
1113+ }
1114+ const int n_subgroups = record_bytes / TURBO_V_RECORD ;
1115+ const int head_dim = n_subgroups * ::mlx::core::fast::TURBO_D ;
1116+
1117+ auto packed_u8 = astype (packed, uint8, s);
1118+ eval (packed_u8);
1119+ const uint8_t * src = packed_u8.data <uint8_t >();
1120+
1121+ const int N = static_cast <int >(packed_u8.size () / record_bytes);
1122+ std::vector<float > buf (static_cast <size_t >(N) * head_dim);
1123+
1124+ for (int i = 0 ; i < N; ++i) {
1125+ for (int g = 0 ; g < n_subgroups; ++g) {
1126+ const uint8_t * sub_src = src + i * record_bytes + g * TURBO_V_RECORD ;
1127+ ::mlx::core::fast::TurboQuantV rec;
1128+ std::memset (&rec, 0 , sizeof (rec));
1129+ std::memcpy (rec.indices , sub_src, 48 );
1130+ std::memcpy (&rec.norm_fp16 , sub_src + 48 , 2 );
1131+ ::mlx::core::fast::turbo_dequantize_v (
1132+ rec,
1133+ buf.data() + i * head_dim + g * ::mlx::core::fast::TURBO_D,
1134+ ::mlx::core::fast::TURBO_D);
1135+ }
1136+ }
1137+
1138+ Shape out_shape = packed.shape ();
1139+ out_shape.back () = head_dim;
1140+ return array (buf.data (), out_shape, float32);
1141+ }
1142+
10571143} // namespace mlx::core::fast
1144+
0 commit comments