Skip to content

Commit

Permalink
Merge pull request #9 from briefercloud/render-filter-values
Browse files Browse the repository at this point in the history
feat: Introduce ability to use variables in filters
  • Loading branch information
vieiralucas authored Sep 11, 2024
2 parents 2cd17d7 + d569585 commit 8b7d0a9
Show file tree
Hide file tree
Showing 7 changed files with 412 additions and 175 deletions.
208 changes: 133 additions & 75 deletions apps/api/src/python/visualizations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,29 @@ function getCode(
const code = `import json
import altair as alt
import pandas as pd
from jinja2 import Template
axisTitlePadding = 10
def _briefer_render_filter_value(filter):
try:
if isinstance(filter["value"], list):
value = list(map(lambda x: Template(x).render(**globals()), filter["value"]))
else:
value = Template(filter["value"]).render(**globals())
return value
except Exception as e:
filter["renderError"] = {
"type": "error",
"ename": e.__class__.__name__,
"evalue": str(e),
"traceback": []
}
print(json.dumps({ "type": "filter-result", "filter": filter }))
return None
def _briefer_convert_to_utc_safe(datetime_series, comparison_value):
# Localize timezone-naive datetimes to UTC
if datetime_series.dt.tz is None:
Expand Down Expand Up @@ -525,68 +545,79 @@ def _briefer_create_visualization(
df.loc[:, x_axis] = pd.to_datetime(df[x_axis])
for filter in filtering:
column_name = filter['column']['name']
operator = filter['operator']
value = filter['value']
if pd.api.types.is_numeric_dtype(df[column_name]):
if operator == 'eq':
df = df[df[column_name] == value]
elif operator == 'ne':
df = df[df[column_name] != value]
elif operator == 'lt':
df = df[df[column_name] < value]
elif operator == 'lte':
df = df[df[column_name] <= value]
elif operator == 'gt':
df = df[df[column_name] > value]
elif operator == 'gte':
df = df[df[column_name] >= value]
elif operator == 'isNull':
df = df[df[column_name].isnull()]
elif operator == 'isNotNull':
df = df[df[column_name].notnull()]
elif pd.api.types.is_string_dtype(df[column_name]) or pd.api.types.is_categorical_dtype(df[column_name]) or pd.api.types.is_object_dtype(df[column_name]):
if operator == 'eq':
df = df[df[column_name] == value]
elif operator == 'ne':
df = df[df[column_name] != value]
elif operator == 'contains':
df = df[df[column_name].str.contains(value)]
elif operator == 'notContains':
df = df[~df[column_name].str.contains(value)]
elif operator == 'startsWith':
df = df[df[column_name].str.startswith(value)]
elif operator == 'endsWith':
df = df[df[column_name].str.endswith(value)]
elif operator == 'in':
df = df[df[column_name].isin(value)]
elif operator == 'notIn':
df = df[~df[column_name].isin(value)]
elif operator == 'isNull':
df = df[df[column_name].isnull()]
elif operator == 'isNotNull':
df = df[df[column_name].notnull()]
elif pd.api.types.is_datetime64_any_dtype(df[column_name]):
# Convert both DataFrame column and value to UTC safely
df_column_utc, value_utc = _briefer_convert_to_utc_safe(df[column_name], pd.to_datetime(value))
# Perform comparisons using the safely converted UTC values
if operator == 'eq':
df = df[df_column_utc == value_utc]
elif operator == 'ne':
df = df[df_column_utc != value_utc]
elif operator == 'before':
df = df[df_column_utc < value_utc]
elif operator == 'beforeOrEq':
df = df[df_column_utc <= value_utc]
elif operator == 'after':
df = df[df_column_utc > value_utc]
elif operator == 'afterOrEq':
df = df[df_column_utc >= value_utc]
elif operator == 'isNull':
df = df[df[column_name].isnull()]
elif operator == 'isNotNull':
df = df[df[column_name].notnull()]
column_name = filter['column']['name']
operator = filter['operator']
value = _briefer_render_filter_value(filter)
# if the value is None, rendering failed, skip this filter
if value == None:
continue
if filter["value"] != value:
filter["renderedValue"] = value
print(json.dumps({"type": "filter-result", "filter": filter}))
if pd.api.types.is_numeric_dtype(df[column_name]):
value = pd.to_numeric(value, errors='coerce')
if operator == 'eq':
df = df[df[column_name] == value]
elif operator == 'ne':
df = df[df[column_name] != value]
elif operator == 'lt':
df = df[df[column_name] < value]
elif operator == 'lte':
df = df[df[column_name] <= value]
elif operator == 'gt':
df = df[df[column_name] > value]
elif operator == 'gte':
df = df[df[column_name] >= value]
elif operator == 'isNull':
df = df[df[column_name].isnull()]
elif operator == 'isNotNull':
df = df[df[column_name].notnull()]
elif pd.api.types.is_string_dtype(df[column_name]) or pd.api.types.is_categorical_dtype(df[column_name]) or pd.api.types.is_object_dtype(df[column_name]):
if operator == 'eq':
df = df[df[column_name] == value]
elif operator == 'ne':
df = df[df[column_name] != value]
elif operator == 'contains':
df = df[df[column_name].str.contains(value)]
elif operator == 'notContains':
df = df[~df[column_name].str.contains(value)]
elif operator == 'startsWith':
df = df[df[column_name].str.startswith(value)]
elif operator == 'endsWith':
df = df[df[column_name].str.endswith(value)]
elif operator == 'in':
df = df[df[column_name].isin(value)]
elif operator == 'notIn':
df = df[~df[column_name].isin(value)]
elif operator == 'isNull':
df = df[df[column_name].isnull()]
elif operator == 'isNotNull':
df = df[df[column_name].notnull()]
elif pd.api.types.is_datetime64_any_dtype(df[column_name]):
# Convert both DataFrame column and value to UTC safely
df_column_utc, value_utc = _briefer_convert_to_utc_safe(df[column_name], pd.to_datetime(value))
# Perform comparisons using the safely converted UTC values
if operator == 'eq':
df = df[df_column_utc == value_utc]
elif operator == 'ne':
df = df[df_column_utc != value_utc]
elif operator == 'before':
df = df[df_column_utc < value_utc]
elif operator == 'beforeOrEq':
df = df[df_column_utc <= value_utc]
elif operator == 'after':
df = df[df_column_utc > value_utc]
elif operator == 'afterOrEq':
df = df[df_column_utc >= value_utc]
elif operator == 'isNull':
df = df[df[column_name].isnull()]
elif operator == 'isNotNull':
df = df[df[column_name].notnull()]
def _briefer_create_histogram(
df,
Expand Down Expand Up @@ -731,7 +762,7 @@ def _briefer_create_visualization(
# if no y or y is not numeric, return
is_y_numeric = series["column"]["type"] == "Q"
if not series['column'] or not is_y_numeric:
print(json.dumps({"success": False, "reason": "invalid-params"}))
print(json.dumps({"type": "result", "success": False, "reason": "invalid-params"}))
return
chart, capped = _briefer_create_chart(
Expand Down Expand Up @@ -773,6 +804,12 @@ def _briefer_create_visualization(
if not ct:
ct = chart_type
color = colors[i % len(colors)]
color_by = serie.get('colorBy', None)
# if df is empty, color_by is not valid, ignore it
if len(df) == 0:
color_by = None
chart, capped = _briefer_create_chart(
df.copy(),
ct,
Expand All @@ -785,8 +822,8 @@ def _briefer_create_visualization(
serie['column']['name'],
serie['column']['type'],
serie['aggregateFunction'],
serie['colorBy']['name'] if serie['colorBy'] else None,
serie['colorBy']['type'] if serie['colorBy'] else None,
color_by["name"] if color_by else None,
color_by["type"] if color_by else None,
number_values_format,
show_data_labels,
len(y_axis['series']) == 1,
Expand All @@ -807,14 +844,14 @@ def _briefer_create_visualization(
vis = alt.layer(*layers, usermeta=usermeta).resolve_scale(y='independent').configure_view(stroke=None).configure_range(category={"scheme": "tableau20"})
# return spec as json
print(json.dumps({"success": True, "spec": vis.to_json(default=str)}, default=str))
print(json.dumps({"type": "result", "success": True, "spec": vis.to_json(default=str)}, default=str))
if not "${dataframe.name}" in globals():
try:
import pandas as pd
${dataframe.name} = pd.read_parquet("/home/jupyteruser/.briefer/query-${
dataframe.id
}.parquet.gzip")
dataframe.id
}.parquet.gzip")
except:
pass
Expand All @@ -835,17 +872,19 @@ if "${dataframe.name}" in globals():
json.loads(${JSON.stringify(JSON.stringify(filtering))})
)
else:
print(json.dumps({"success": False, "reason": "dataframe-not-found"}))`
print(json.dumps({"type": "result", "success": False, "reason": "dataframe-not-found"}))`

return code
}

const CreateVisualizationPythonResult = z.union([
z.object({
type: z.literal('result'),
success: z.literal(true),
spec: jsonString.pipe(JsonObject),
}),
z.object({
type: z.literal('result'),
success: z.literal(false),
reason: z.union([
z.literal('dataframe-not-found'),
Expand All @@ -858,15 +897,23 @@ type CreateVisualizationPythonResult = z.infer<
typeof CreateVisualizationPythonResult
>

/**
 * Schema for a `filter-result` message: a JSON object tagged with
 * `type: 'filter-result'` whose `filter` field must validate against
 * `VisualizationFilter`. The literal `type` tag distinguishes it from
 * messages tagged `type: 'result'`.
 */
const FilterResult = z.object({
type: z.literal('filter-result'),
filter: VisualizationFilter,
})
/** Static type inferred from the `FilterResult` schema. */
type FilterResult = z.infer<typeof FilterResult>

export type CreateVisualizationResult = {
promise: Promise<
| {
success: true
spec: JsonObject
filterResults: Record<string, VisualizationFilter>
}
| {
success: false
reason: 'dataframe-not-found' | 'aborted' | 'invalid-params'
filterResults: Record<string, VisualizationFilter>
}
>
abort: () => Promise<void>
Expand Down Expand Up @@ -938,14 +985,16 @@ export async function createVisualization(
{ storeHistory: false }
)

const filterResults: Record<string, VisualizationFilter> = {}

const promise = execute.then(
async (): CreateVisualizationResult['promise'] => {
let result: CreateVisualizationPythonResult | null = null
const errors: Error[] = []
for (const output of outputs) {
if (output.type === 'error') {
if (output.ename === 'KeyboardInterrupt') {
result = { success: false, reason: 'aborted' }
result = { type: 'result', success: false, reason: 'aborted' }
break
}

Expand All @@ -966,10 +1015,19 @@ export async function createVisualization(
for (const line of output.text.split('\n')) {
try {
const asJson = JSON.parse(line)
const parsed = CreateVisualizationPythonResult.safeParse(asJson)
const parsed =
CreateVisualizationPythonResult.or(FilterResult).safeParse(
asJson
)
if (parsed.success) {
result = parsed.data
break
if (parsed.data.type === 'result') {
if (parsed.data.success) {
result = parsed.data
break
}
} else {
filterResults[parsed.data.filter.id] = parsed.data.filter
}
}
} catch (err) {
errors.push(err as Error)
Expand All @@ -993,10 +1051,10 @@ export async function createVisualization(
}

if (!result.success) {
return result
return { ...result, filterResults }
}

return { success: true, spec: result.spec }
return { success: true, spec: result.spec, filterResults }
}
)

Expand Down
3 changes: 3 additions & 0 deletions apps/api/src/yjs/v2/executors/blocks/visualization.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ describe('VisualizationExecutor', () => {
promise: Promise.resolve({
success: true,
spec,
filterResults: {},
}),
abort: jest.fn(),
})
Expand Down Expand Up @@ -147,6 +148,7 @@ describe('VisualizationExecutor', () => {
promise: Promise.resolve({
success: false,
reason: 'dataframe-not-found',
filterResults: {},
}),
abort: jest.fn(),
})
Expand All @@ -163,6 +165,7 @@ describe('VisualizationExecutor', () => {
promise: Promise.resolve({
success: false,
reason: 'aborted',
filterResults: {},
}),
abort: jest.fn(),
})
Expand Down
12 changes: 12 additions & 0 deletions apps/api/src/yjs/v2/executors/blocks/visualization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,18 @@ export class VisualizationExecutor implements IVisualizationExecutor {
}

const result = await promise
if (Object.keys(result.filterResults).length > 0) {
const nextFilters = filters.map((f) => {
const next = result.filterResults[f.id]
if (next) {
return next
}

return f
})
block.setAttribute('filters', nextFilters)
}

if (!result.success) {
if (result.reason !== 'aborted') {
block.setAttribute('error', result.reason)
Expand Down
Loading

0 comments on commit 8b7d0a9

Please sign in to comment.