Skip to content

Commit b8c2fbc

Browse files
committed
New cols selector for query_topic
1 parent 622f373 commit b8c2fbc

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
44
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
55

6-
## [0.8.5]
6+
## [0.8.6]
7+
8+
- New `cols` column selector option for`query_<topic>`
9+
10+
## [0.8.5] 2026-03-24
711

812
### Added
913

src/rembus/db.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ def init_db(router, schema):
232232
db.execute(sql)
233233
# Create the query, delete and upsert rpc topics
234234
router.handler[f"upsert_{tname}"] = partial(
235-
rpc_upsert, router, tname)
235+
rpc_upsert, router, tname
236+
)
236237
router.handler[f"query_{tname}"] = partial(query, router, tname)
237238
router.handler[f"delete_{tname}"] = partial(delete, router, tname)
238239

@@ -250,20 +251,17 @@ def rpc_add_ts(table, obj):
250251
for el in obj:
251252
el[col_name] = ts
252253
elif isinstance(obj, pl.DataFrame):
253-
obj = obj.with_columns(
254-
pl.lit(ts).alias(col_name)
255-
)
254+
obj = obj.with_columns(pl.lit(ts).alias(col_name))
256255

257256
return obj
258257

259258

260-
def rpc_upsert(router, tname, obj, ctx=None, node=None):
259+
def rpc_upsert(router, tname, obj, options={}, ctx=None, node=None):
261260
"""Insert/update obj values."""
262261
table = router.tables[tname]
263262
col_names = columns(table) + list(table.extras.values())
264263
indexes = list(table.keys)
265264
con = router.db
266-
267265
obj = rpc_add_ts(table, obj)
268266
batch_df = None
269267
if isinstance(obj, dict):
@@ -281,7 +279,7 @@ def rpc_upsert(router, tname, obj, ctx=None, node=None):
281279
raise RuntimeError(f"upsert failed: invalid record type {type(obj)}")
282280

283281
if table.keys:
284-
execute_upsert_df(con, table, col_names, indexes, batch_df)
282+
execute_upsert_df(con, table, col_names, indexes, batch_df, options)
285283
else:
286284
con.register("batch_view", batch_df)
287285
con.execute(f"INSERT INTO {tname} SELECT * FROM batch_view")
@@ -314,13 +312,15 @@ def query(router, table, obj=None, ctx=None, node=None):
314312
if obj is None:
315313
sql = f"SELECT * FROM {table}"
316314
else:
317-
allowed = ("where", "when")
315+
allowed = ("cols", "where", "when")
318316
bad = [k for k in obj.keys() if k not in allowed]
319317
if bad:
320318
raise ValueError(f"invalid keys: {', '.join(bad)}")
321319
where_cond = ""
320+
322321
if "where" in obj:
323322
where_cond = " WHERE " + obj["where"]
323+
324324
at = ""
325325
if "when" in obj:
326326
if isinstance(obj["when"], (int, float)):
@@ -329,7 +329,16 @@ def query(router, table, obj=None, ctx=None, node=None):
329329
else:
330330
ts = obj["when"]
331331
at = f" AT (TIMESTAMP => CAST('{ts}' AS TIMESTAMP))"
332-
sql = f"SELECT * FROM {table} {at} {where_cond}"
332+
333+
if "cols" in obj:
334+
cols = obj["cols"]
335+
if isinstance(cols, list):
336+
cols_str = ", ".join(cols)
337+
else:
338+
cols_str = str(cols)
339+
sql = f"SELECT {cols_str} FROM {table} {at} {where_cond}"
340+
else:
341+
sql = f"SELECT * FROM {table} {at} {where_cond}"
333342

334343
logger.debug("db query: %s", sql)
335344
return router.db.execute(sql).pl()
@@ -576,7 +585,7 @@ def handle_default(msg, table, col_names, records, tname):
576585
records.append(vals + extra_vals)
577586

578587

579-
def execute_upsert_df(con, table, col_names, indexes, df):
588+
def execute_upsert_df(con, table, col_names, indexes, df, options={}):
580589
tname = table.table
581590
if df.is_empty():
582591
return
@@ -590,7 +599,15 @@ def execute_upsert_df(con, table, col_names, indexes, df):
590599
col_list = ", ".join(col_names)
591600
val_list = ", ".join(f"df_view.{c}" for c in col_names)
592601

593-
update_cols = [c for c in col_names if c not in indexes]
602+
if "nulls" in options and options["nulls"]:
603+
update_cols = [c for c in col_names if c not in indexes]
604+
else:
605+
update_cols = [
606+
c
607+
for c in col_names
608+
if c not in indexes and df[c].drop_nulls().len() > 0
609+
]
610+
594611
update_list = ", ".join(f"{c} = df_view.{c}" for c in update_cols)
595612

596613
sql = f"""
@@ -600,7 +617,6 @@ def execute_upsert_df(con, table, col_names, indexes, df):
600617
WHEN MATCHED THEN UPDATE SET {update_list}
601618
WHEN NOT MATCHED THEN INSERT ({col_list}) VALUES ({val_list})
602619
"""
603-
604620
con.execute(sql)
605621
con.unregister("df_view")
606622

0 commit comments

Comments
 (0)