|
22 | 22 | */ |
23 | 23 |
|
24 | 24 | #include "lang/eval_internal.h" |
| 25 | +#include "core/pool.h" |
| 26 | +#include "mem/heap.h" |
| 27 | +#include "mem/cow.h" |
25 | 28 |
|
26 | 29 | /* ══════════════════════════════════════════ |
27 | | - * Arithmetic builtins |
| 30 | + * Typed vector arithmetic — rayforce1 pattern |
| 31 | + * |
| 32 | + * Each operation dispatches on MTYPE2(left_type, right_type) once, |
| 33 | + * then runs a tight typed-pointer loop that the compiler vectorizes. |
| 34 | + * Output buffer reuses input when rc==1 and types match. |
28 | 35 | * ══════════════════════════════════════════ */ |
29 | 36 |
|
/* Pack a (left, right) pair of type codes into one int key.
 * Each code is biased by 128 and the left code is weighted by 256, so any
 * pair of codes in [-128, 127] maps to a distinct value in [0, 65535]. */
#define MTYPE2(lt, rt) ((256 * (128 + (int)(lt))) + (128 + (int)(rt)))
| 38 | + |
/* Element-wise loop kernels for the three operand shapes:
 *   LOOP_A_V  scalar OP vector
 *   LOOP_V_A  vector OP scalar
 *   LOOP_V_V  vector OP vector
 * `dst` receives `cnt` results; OP is a two-argument function-like macro
 * applied to each (left, right) pair.
 *
 * Every expansion is one complete `for` statement INCLUDING its trailing
 * semicolon, so a call site must not put another `;` between the macro and
 * a following `else`. */
#define LOOP_A_V(scalar, src, dst, cnt, OP)                 \
    for (int64_t _k = 0; _k < (cnt); ++_k)                  \
        (dst)[_k] = OP((scalar), (src)[_k]);

#define LOOP_V_A(src, scalar, dst, cnt, OP)                 \
    for (int64_t _k = 0; _k < (cnt); ++_k)                  \
        (dst)[_k] = OP((src)[_k], (scalar));

#define LOOP_V_V(lsrc, rsrc, dst, cnt, OP)                  \
    for (int64_t _k = 0; _k < (cnt); ++_k)                  \
        (dst)[_k] = OP((lsrc)[_k], (rsrc)[_k]);
| 52 | + |
/* Per-element expression macros plugged into the loop kernels.
 * Arguments may be evaluated more than once (MIN2/MAX2), so pass
 * side-effect-free expressions only. */

/* Integer arithmetic is carried out on the unsigned counterpart and cast
 * back, so overflow wraps instead of being signed-overflow UB. */
#define OP_ADD_I64(x, y)  ((int64_t)((uint64_t)(x) + (uint64_t)(y)))
#define OP_SUB_I64(x, y)  ((int64_t)((uint64_t)(x) - (uint64_t)(y)))
#define OP_MUL_I64(x, y)  ((int64_t)((uint64_t)(x) * (uint64_t)(y)))

#define OP_ADD_I32(x, y)  ((int32_t)((uint32_t)(x) + (uint32_t)(y)))
#define OP_SUB_I32(x, y)  ((int32_t)((uint32_t)(x) - (uint32_t)(y)))
#define OP_MUL_I32(x, y)  ((int32_t)((uint32_t)(x) * (uint32_t)(y)))

/* Floating-point arithmetic needs no special treatment. */
#define OP_ADD_F64(x, y)  ((x) + (y))
#define OP_SUB_F64(x, y)  ((x) - (y))
#define OP_MUL_F64(x, y)  ((x) * (y))

/* Comparisons yield a uint8_t that is 1 when the relation holds, else 0. */
#define OP_EQ_I64(x, y)   ((uint8_t)((x) == (y)))
#define OP_NE_I64(x, y)   ((uint8_t)((x) != (y)))
#define OP_LT_I64(x, y)   ((uint8_t)((x) <  (y)))
#define OP_LE_I64(x, y)   ((uint8_t)((x) <= (y)))
#define OP_GT_I64(x, y)   ((uint8_t)((x) >  (y)))
#define OP_GE_I64(x, y)   ((uint8_t)((x) >= (y)))

#define OP_EQ_F64(x, y)   ((uint8_t)((x) == (y)))
#define OP_NE_F64(x, y)   ((uint8_t)((x) != (y)))
#define OP_LT_F64(x, y)   ((uint8_t)((x) <  (y)))
#define OP_LE_F64(x, y)   ((uint8_t)((x) <= (y)))
#define OP_GT_F64(x, y)   ((uint8_t)((x) >  (y)))
#define OP_GE_F64(x, y)   ((uint8_t)((x) >= (y)))

/* Two-operand minimum / maximum.  When the comparison is false (which
 * includes any NaN operand for the F64 forms) the second operand wins. */
#define OP_MIN2_I64(x, y) ((x) < (y) ? (x) : (y))
#define OP_MAX2_I64(x, y) ((x) > (y) ? (x) : (y))
#define OP_MIN2_F64(x, y) ((x) < (y) ? (x) : (y))
#define OP_MAX2_F64(x, y) ((x) > (y) ? (x) : (y))
| 80 | + |
| 81 | +/* Context for parallel typed dispatch */ |
| 82 | +typedef struct { |
| 83 | + ray_t* left; |
| 84 | + ray_t* right; |
| 85 | + ray_t* out; |
| 86 | + uint16_t opcode; |
| 87 | +} binop_vec_ctx_t; |
| 88 | + |
/* Pick the loop kernel that matches the operand shapes and run it.
 *   T           element C type (not used by the expansion; kept so call
 *               sites document the width they operate on)
 *   lptr, rptr  element pointers used when the left / right operand is a vector
 *   lsv, rsv    scalar values used when the left / right operand is an atom
 *   optr        destination pointer
 *   n           element count
 *   xv, yv      non-zero when the left / right operand is a vector
 *   OPNAME      two-argument OP_* expression macro
 * When neither flag is set the scalar-left / vector-right kernel runs, so
 * callers must guarantee that at least one operand is a vector.
 *
 * The flags are parenthesised so expression arguments keep their meaning,
 * and each branch is braced and terminated explicitly so the chain parses
 * whether or not the LOOP_* expansions carry their own semicolon. */
#define TYPED_LOOP(T, lptr, rptr, lsv, rsv, optr, n, xv, yv, OPNAME)   \
    do {                                                               \
        if ((xv) && (yv)) {                                            \
            LOOP_V_V(lptr, rptr, optr, n, OPNAME);                     \
        } else if (xv) {                                               \
            LOOP_V_A(lptr, rsv, optr, n, OPNAME);                      \
        } else {                                                       \
            LOOP_A_V(lsv, rptr, optr, n, OPNAME);                      \
        }                                                              \
    } while (0)
| 97 | + |
| 98 | +/* Parallel worker: typed dispatch per chunk */ |
| 99 | +static void binop_vec_worker(void* ctx_, uint32_t wid, int64_t start, int64_t end) { |
| 100 | + (void)wid; |
| 101 | + binop_vec_ctx_t* c = (binop_vec_ctx_t*)ctx_; |
| 102 | + int64_t n = end - start; |
| 103 | + ray_t* x = c->left; |
| 104 | + ray_t* y = c->right; |
| 105 | + ray_t* out = c->out; |
| 106 | + int8_t ot = out->type; |
| 107 | + bool xv = ray_is_vec(x), yv = ray_is_vec(y); |
| 108 | + uint16_t opc = c->opcode; |
| 109 | + |
| 110 | + /* For arithmetic: output type matches input type. |
| 111 | + * For comparisons: inputs are I64/I32/F64, output is BOOL (U8). */ |
| 112 | + bool is_cmp = (opc >= OP_EQ && opc <= OP_GE); |
| 113 | + |
| 114 | + /* Resolve data pointers once */ |
| 115 | + if (is_cmp) { |
| 116 | + /* Comparison: read inputs at their type, write BOOL output */ |
| 117 | + uint8_t* restrict od = (uint8_t*)ray_data(out) + start; |
| 118 | + /* Determine input type from the vector operand */ |
| 119 | + int8_t it = xv ? x->type : yv ? y->type : RAY_I64; |
| 120 | + if (it == RAY_I64 || it == RAY_TIMESTAMP) { |
| 121 | + int64_t* ld = xv ? (int64_t*)ray_data(x) + start : NULL; |
| 122 | + int64_t* rd = yv ? (int64_t*)ray_data(y) + start : NULL; |
| 123 | + int64_t lsv = xv ? 0 : x->i64; |
| 124 | + int64_t rsv = yv ? 0 : y->i64; |
| 125 | + switch (opc) { |
| 126 | + case OP_EQ: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_EQ_I64); break; |
| 127 | + case OP_NE: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_NE_I64); break; |
| 128 | + case OP_LT: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_LT_I64); break; |
| 129 | + case OP_LE: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_LE_I64); break; |
| 130 | + case OP_GT: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_GT_I64); break; |
| 131 | + case OP_GE: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_GE_I64); break; |
| 132 | + default: break; |
| 133 | + } |
| 134 | + } else if (it == RAY_I32 || it == RAY_DATE || it == RAY_TIME) { |
| 135 | + int32_t* ld = xv ? (int32_t*)ray_data(x) + start : NULL; |
| 136 | + int32_t* rd = yv ? (int32_t*)ray_data(y) + start : NULL; |
| 137 | + int32_t lsv = xv ? 0 : (int32_t)x->i32; |
| 138 | + int32_t rsv = yv ? 0 : (int32_t)y->i32; |
| 139 | + switch (opc) { |
| 140 | + case OP_EQ: TYPED_LOOP(int32_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_EQ_I64); break; |
| 141 | + case OP_NE: TYPED_LOOP(int32_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_NE_I64); break; |
| 142 | + case OP_LT: TYPED_LOOP(int32_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_LT_I64); break; |
| 143 | + case OP_LE: TYPED_LOOP(int32_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_LE_I64); break; |
| 144 | + case OP_GT: TYPED_LOOP(int32_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_GT_I64); break; |
| 145 | + case OP_GE: TYPED_LOOP(int32_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_GE_I64); break; |
| 146 | + default: break; |
| 147 | + } |
| 148 | + } else if (it == RAY_F64) { |
| 149 | + double* ld = xv ? (double*)ray_data(x) + start : NULL; |
| 150 | + double* rd = yv ? (double*)ray_data(y) + start : NULL; |
| 151 | + double lsv = xv ? 0 : x->f64; |
| 152 | + double rsv = yv ? 0 : y->f64; |
| 153 | + switch (opc) { |
| 154 | + case OP_EQ: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_EQ_F64); break; |
| 155 | + case OP_NE: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_NE_F64); break; |
| 156 | + case OP_LT: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_LT_F64); break; |
| 157 | + case OP_LE: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_LE_F64); break; |
| 158 | + case OP_GT: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_GT_F64); break; |
| 159 | + case OP_GE: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_GE_F64); break; |
| 160 | + default: break; |
| 161 | + } |
| 162 | + } |
| 163 | + } else if (ot == RAY_I64 || ot == RAY_TIMESTAMP) { |
| 164 | + int64_t* restrict od = (int64_t*)ray_data(out) + start; |
| 165 | + int64_t* ld = xv ? (int64_t*)ray_data(x) + start : NULL; |
| 166 | + int64_t* rd = yv ? (int64_t*)ray_data(y) + start : NULL; |
| 167 | + int64_t lsv = xv ? 0 : x->i64, rsv = yv ? 0 : y->i64; |
| 168 | + switch (opc) { |
| 169 | + case OP_ADD: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_ADD_I64); break; |
| 170 | + case OP_SUB: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_SUB_I64); break; |
| 171 | + case OP_MUL: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_MUL_I64); break; |
| 172 | + case OP_MIN2: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_MIN2_I64); break; |
| 173 | + case OP_MAX2: TYPED_LOOP(int64_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_MAX2_I64); break; |
| 174 | + default: break; |
| 175 | + } |
| 176 | + } else if (ot == RAY_I32 || ot == RAY_DATE || ot == RAY_TIME) { |
| 177 | + int32_t* restrict od = (int32_t*)ray_data(out) + start; |
| 178 | + int32_t* ld = xv ? (int32_t*)ray_data(x) + start : NULL; |
| 179 | + int32_t* rd = yv ? (int32_t*)ray_data(y) + start : NULL; |
| 180 | + int32_t lsv = xv ? 0 : (int32_t)x->i32, rsv = yv ? 0 : (int32_t)y->i32; |
| 181 | + switch (opc) { |
| 182 | + case OP_ADD: TYPED_LOOP(int32_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_ADD_I32); break; |
| 183 | + case OP_SUB: TYPED_LOOP(int32_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_SUB_I32); break; |
| 184 | + case OP_MUL: TYPED_LOOP(int32_t,ld,rd,lsv,rsv,od,n,xv,yv,OP_MUL_I32); break; |
| 185 | + default: break; |
| 186 | + } |
| 187 | + } else if (ot == RAY_F64) { |
| 188 | + double* restrict od = (double*)ray_data(out) + start; |
| 189 | + double* ld = xv ? (double*)ray_data(x) + start : NULL; |
| 190 | + double* rd = yv ? (double*)ray_data(y) + start : NULL; |
| 191 | + double lsv = xv ? 0 : x->f64, rsv = yv ? 0 : y->f64; |
| 192 | + switch (opc) { |
| 193 | + case OP_ADD: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_ADD_F64); break; |
| 194 | + case OP_SUB: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_SUB_F64); break; |
| 195 | + case OP_MUL: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_MUL_F64); break; |
| 196 | + case OP_MIN2: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_MIN2_F64); break; |
| 197 | + case OP_MAX2: TYPED_LOOP(double,ld,rd,lsv,rsv,od,n,xv,yv,OP_MAX2_F64); break; |
| 198 | + default: break; |
| 199 | + } |
| 200 | + } |
| 201 | + #undef TYPED_LOOP |
| 202 | +} |
| 203 | + |
| 204 | +/* Infer output type for arithmetic on two operands */ |
| 205 | +static int8_t infer_arith_type(ray_t* x, ray_t* y) { |
| 206 | + int8_t xt = ray_is_atom(x) ? -(x->type) : x->type; |
| 207 | + int8_t yt = ray_is_atom(y) ? -(y->type) : y->type; |
| 208 | + if (xt == RAY_F64 || yt == RAY_F64) return RAY_F64; |
| 209 | + if (xt == RAY_I64 || yt == RAY_I64) return RAY_I64; |
| 210 | + if (xt == RAY_TIMESTAMP || yt == RAY_TIMESTAMP) return RAY_TIMESTAMP; |
| 211 | + if (xt == RAY_I32 || yt == RAY_I32) return RAY_I32; |
| 212 | + if (xt == RAY_DATE || yt == RAY_DATE) return RAY_DATE; |
| 213 | + if (xt == RAY_TIME || yt == RAY_TIME) return RAY_TIME; |
| 214 | + if (xt == RAY_I16 || yt == RAY_I16) return RAY_I16; |
| 215 | + return RAY_I64; |
| 216 | +} |
| 217 | + |
| 218 | +/* Fast vector binary op: typed dispatch + rc==1 reuse + parallel. |
| 219 | + * Returns NULL when the fast path doesn't apply (caller falls through). |
| 220 | + * opcode is the DAG opcode (OP_ADD, OP_SUB, OP_MUL, OP_EQ, OP_LT, etc.) */ |
| 221 | +ray_t* binop_vec(ray_t* x, ray_t* y, uint16_t opcode) { |
| 222 | + bool xv = ray_is_vec(x), yv = ray_is_vec(y); |
| 223 | + bool xa = ray_is_atom(x), ya = ray_is_atom(y); |
| 224 | + if (!(xv || xa) || !(yv || ya) || (!xv && !yv)) return NULL; |
| 225 | + |
| 226 | + /* Skip when nulls present — generic path handles null propagation */ |
| 227 | + if (xv && (x->attrs & RAY_ATTR_HAS_NULLS)) return NULL; |
| 228 | + if (yv && (y->attrs & RAY_ATTR_HAS_NULLS)) return NULL; |
| 229 | + if (xa && RAY_ATOM_IS_NULL(x)) return NULL; |
| 230 | + if (ya && RAY_ATOM_IS_NULL(y)) return NULL; |
| 231 | + |
| 232 | + int64_t len = xv ? x->len : y->len; |
| 233 | + if (xv && yv && x->len != y->len) return NULL; |
| 234 | + |
| 235 | + /* Determine input type — both operands must match */ |
| 236 | + int8_t xt = xv ? x->type : -(x->type); |
| 237 | + int8_t yt = yv ? y->type : -(y->type); |
| 238 | + /* Skip temporal types — output depends on op (DATE-DATE→I32 etc.) */ |
| 239 | + bool x_temporal = (xt == RAY_DATE || xt == RAY_TIME || xt == RAY_TIMESTAMP); |
| 240 | + bool y_temporal = (yt == RAY_DATE || yt == RAY_TIME || yt == RAY_TIMESTAMP); |
| 241 | + if (x_temporal || y_temporal) return NULL; |
| 242 | + /* Both operands must have the same type */ |
| 243 | + if (xt != yt) return NULL; |
| 244 | + /* Only handle I64, I32, F64 */ |
| 245 | + if (xt != RAY_I64 && xt != RAY_I32 && xt != RAY_F64) return NULL; |
| 246 | + /* Only handle opcodes with typed loop implementations */ |
| 247 | + if (opcode != OP_ADD && opcode != OP_SUB && opcode != OP_MUL && |
| 248 | + opcode != OP_MIN2 && opcode != OP_MAX2 && |
| 249 | + opcode != OP_EQ && opcode != OP_NE && opcode != OP_LT && |
| 250 | + opcode != OP_LE && opcode != OP_GT && opcode != OP_GE) |
| 251 | + return NULL; |
| 252 | + |
| 253 | + /* Output type: BOOL for comparisons, same as input for arithmetic */ |
| 254 | + bool is_cmp = (opcode >= OP_EQ && opcode <= OP_GE); |
| 255 | + int8_t ot = is_cmp ? RAY_BOOL : xt; |
| 256 | + |
| 257 | + /* rc==1 buffer reuse (arithmetic only, not slices — slices alias |
| 258 | + * their parent's buffer so writing into them corrupts the parent) */ |
| 259 | + ray_t* out; |
| 260 | + if (!is_cmp && xv && x->rc == 1 && x->type == ot && |
| 261 | + !(x->attrs & RAY_ATTR_SLICE)) { |
| 262 | + out = x; |
| 263 | + ray_retain(out); |
| 264 | + } else if (!is_cmp && yv && y->rc == 1 && y->type == ot && |
| 265 | + !(y->attrs & RAY_ATTR_SLICE)) { |
| 266 | + out = y; |
| 267 | + ray_retain(out); |
| 268 | + } else { |
| 269 | + out = ray_vec_new(ot, len); |
| 270 | + } |
| 271 | + if (!out || RAY_IS_ERR(out)) return out; |
| 272 | + out->len = len; |
| 273 | + |
| 274 | + binop_vec_ctx_t ctx = { .left = x, .right = y, .out = out, .opcode = opcode }; |
| 275 | + ray_pool_t* pool = ray_pool_get(); |
| 276 | + if (pool && len >= RAY_PARALLEL_THRESHOLD) |
| 277 | + ray_pool_dispatch(pool, binop_vec_worker, &ctx, len); |
| 278 | + else |
| 279 | + binop_vec_worker(&ctx, 0, 0, len); |
| 280 | + |
| 281 | + return out; |
| 282 | +} |
| 283 | + |
30 | 284 | /* Binary arithmetic */ |
31 | 285 | ray_t* ray_add_fn(ray_t* a, ray_t* b) { |
| 286 | + |
32 | 287 | /* Temporal + integer arithmetic (only int types, not float) */ |
33 | 288 | if (is_temporal(a) && is_numeric(b) && b->type != -RAY_F64) { |
34 | 289 | if (RAY_ATOM_IS_NULL(a) || RAY_ATOM_IS_NULL(b)) |
@@ -91,6 +346,7 @@ ray_t* ray_add_fn(ray_t* a, ray_t* b) { |
91 | 346 | } |
92 | 347 |
|
93 | 348 | ray_t* ray_sub_fn(ray_t* a, ray_t* b) { |
| 349 | + |
94 | 350 | /* Temporal - int null propagation (both operands) */ |
95 | 351 | if (is_temporal(a) && is_numeric(b)) { |
96 | 352 | if (RAY_ATOM_IS_NULL(a) || RAY_ATOM_IS_NULL(b)) |
@@ -160,6 +416,7 @@ ray_t* ray_sub_fn(ray_t* a, ray_t* b) { |
160 | 416 | } |
161 | 417 |
|
162 | 418 | ray_t* ray_mul_fn(ray_t* a, ray_t* b) { |
| 419 | + |
163 | 420 | /* int * TIME → TIME, TIME * int → TIME */ |
164 | 421 | if (is_numeric(a) && b->type == -RAY_TIME) { |
165 | 422 | if (RAY_ATOM_IS_NULL(a) || RAY_ATOM_IS_NULL(b)) return ray_typed_null(-RAY_TIME); |
|
0 commit comments