Skip to content

Commit 3cec099

Browse files
committed
fixes for geom_bar fill
1 parent 298ade8 commit 3cec099

File tree

8 files changed

+27
-12
lines changed

8 files changed

+27
-12
lines changed

plots/geom_bar_fill.png

1.13 KB
Loading

src/python_ggplot/gg/drawing.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,9 @@ def draw_sub_df(
916916
for i in range(len(styles) - 1):
917917
style = merge_user_style(styles[i], fg)
918918
poly_line = init_poly_line_from_points(
919-
view, [line_points[i].point(), line_points[i + 1].point()], deepcopy(style)
919+
view,
920+
[line_points[i].point(), line_points[i + 1].point()],
921+
deepcopy(style),
920922
)
921923
view.add_obj(poly_line)
922924
elif geom_type == GeomType.RASTER:

src/python_ggplot/gg/geom/utils.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -324,14 +324,19 @@ def _filled_count_geom_map(
324324
from python_ggplot.gg.styles.utils import apply_style
325325

326326
grouped = df.groupby(filled_stat_geom.map_discrete_columns, sort=False) # type: ignore
327+
sorted_keys = sorted(grouped.groups.keys(), reverse=True) # type: ignore
327328
col = pd.Series(dtype=float) # For stacking
328329

329330
all_classes = pd.Series(df[filled_stat_geom.get_x_col()].unique()) # type: ignore
330331
if len(filled_stat_geom.continuous_scales) > 0:
331332
raise GGException("continuous_scales > 0")
332333

333-
for keys, sub_df in grouped: # type: ignore
334-
apply_style(style, sub_df, filled_stat_geom.discrete_scales, [(keys[0], VString(i)) for i in grouped.groups]) # type: ignore
334+
for keys in sorted_keys: # type: ignore
335+
sub_df = grouped.get_group(keys) # type: ignore
336+
key_values = list(product(filled_stat_geom.map_discrete_columns, [keys])) # type: ignore
337+
current_style = apply_style(
338+
deepcopy(style), sub_df, filled_stat_geom.discrete_scales, key_values
339+
) # type: ignore
335340

336341
weight_scale = filled_scales.get_weight_scale(
337342
filled_stat_geom.geom, optional=True
@@ -356,7 +361,7 @@ def _filled_count_geom_map(
356361
filled_geom.gg_data.yield_data[keys] = apply_cont_scale_if_any( # type: ignore
357362
yield_df,
358363
filled_stat_geom.continuous_scales,
359-
style,
364+
current_style,
360365
filled_stat_geom.geom.geom_type,
361366
to_clone=True,
362367
)
@@ -527,6 +532,7 @@ def _filled_smooth_geom_map(
527532
style: "GGStyle",
528533
) -> "FilledGeom":
529534
from python_ggplot.gg.styles.utils import apply_style
535+
530536
grouped = df.groupby(filled_stat_geom.map_discrete_columns, sort=True) # type: ignore
531537
sorted_keys = sorted(grouped.groups.keys(), reverse=True) # type: ignore
532538
col = pd.Series(dtype=float) # type: ignore
@@ -536,7 +542,7 @@ def _filled_smooth_geom_map(
536542
key_values = list(product(filled_stat_geom.map_discrete_columns, [keys])) # type: ignore
537543
current_style = apply_style(
538544
deepcopy(style), sub_df, filled_stat_geom.discrete_scales, key_values
539-
) # type: ignore
545+
) # type: ignore
540546

541547
yield_df = sub_df.copy() # type: ignore
542548

src/python_ggplot/gg/scales/base.py

-1
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,6 @@ def get_name(self) -> str:
709709
name = scale.gg_data.get_name()
710710
if name:
711711
return name
712-
print(self)
713712
raise GGException("No name found")
714713

715714

src/python_ggplot/gg/scales/collect_and_fill.py

-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
series_is_str,
2929
series_value_type,
3030
)
31-
from python_ggplot.gg.geom.base import post_process_scales
3231
from python_ggplot.gg.scales import (
3332
AlphaScale,
3433
GGScale,
@@ -1090,5 +1089,4 @@ def fill_field(field_name: str, arg: List[GGScale]) -> None:
10901089
if plot.facet is not None:
10911090
add_facets(filled_scales_result, plot)
10921091

1093-
post_process_scales(filled_scales_result, plot)
10941092
return filled_scales_result

src/python_ggplot/graphics/views.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ def get_coords(self) -> Dict[Any, Any]:
130130
def gather_coords(self) -> Dict[Any, Any]:
131131
return self.get_coords()
132132

133-
134-
def find_go_by_go_by_filter(self, filter_: Callable[[GraphicsObject], bool], recursive: bool = True):
133+
def find_go_by_go_by_filter(
134+
self, filter_: Callable[[GraphicsObject], bool], recursive: bool = True
135+
):
135136
result: List[GraphicsObject] = []
136137
for object in self.objects:
137138
if filter_(object):

src/python_ggplot/public_interface/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
VLinearData,
6565
)
6666
from python_ggplot.gg.drawing import create_gobj_from_geom
67-
from python_ggplot.gg.geom.base import FilledGeom, Geom, GeomType
67+
from python_ggplot.gg.geom.base import FilledGeom, Geom, GeomType, post_process_scales
6868
from python_ggplot.gg.scales import FillColorScaleValue, ScaleValue
6969
from python_ggplot.gg.scales.base import (
7070
ColorScale,
@@ -1869,6 +1869,7 @@ def ggcreate(plot: GgPlot, width: float = 640.0, height: float = 480.0) -> PlotV
18691869
raise GGException("Please use at least one `geom`!")
18701870

18711871
filled_scales: FilledScales = _collect_scales(plot)
1872+
post_process_scales(filled_scales, plot)
18721873
theme = build_theme(filled_scales, plot)
18731874

18741875
coord_input = CoordsInput()

tests/test_plots.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ def test_geom_bar():
3434
ggdraw_plot(res, plots_path / "geom_bar.png")
3535

3636

37+
def test_geom_bar_fill():
38+
mpg = pd.read_csv(data_path / "mpg.csv")
39+
plot = ggplot(mpg, aes("class", fill="drv")) + geom_bar()
40+
res = ggcreate(plot)
41+
ggdraw_plot(res, plots_path / "geom_bar_fill.png")
42+
# print(res.view.find_go_by_go_name("geom_bar_rect"))
43+
44+
3745
@pytest.mark.xfail(reason="fix")
3846
def test_geom_point_and_text():
3947
"""
@@ -85,7 +93,7 @@ def test_geom_histogram_fill():
8593

8694

8795
@pytest.mark.xfail(reason="")
88-
def test_geom_bar_fill():
96+
def test_geom_bar_fill_y_only():
8997
# Fill the bars
9098
# Fourth plot here https://ggplot2.tidyverse.org/reference/geom_bar.html
9199
mpg = pd.read_csv(data_path / "mpg.csv")

0 commit comments

Comments
 (0)