Skip to content

Commit 50c7359

Browse files
committed
Set canvas with with a string, example '17cm'
1 parent 0367849 commit 50c7359

File tree

10 files changed

+104
-51
lines changed

10 files changed

+104
-51
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ classifiers = [
1616
]
1717
dependencies = [
1818
"matplotlib",
19+
"pint",
1920
"plotly",
2021
"pytest",
2122
"ruff",

src/maxplotlib/backends/matplotlib/utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import matplotlib.colors as mcolors
1010
import matplotlib.pyplot as plt
1111
import numpy as np
12+
import pint
1213
from matplotlib.collections import PatchCollection
1314
from mpl_toolkits.mplot3d import Axes3D
1415
from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection
@@ -54,7 +55,29 @@ def setup_plotstyle(
5455
plt.rcParams["ytick.major.pad"] = 8
5556

5657

57-
def set_size(width, fraction=1, ratio="golden"):
58+
# TODO: Use the other unit package
59+
# Create a UnitRegistry
60+
ureg = pint.UnitRegistry()
61+
62+
63+
def convert_to_inches(length_str):
64+
quantity = ureg(length_str) # Parse the input string
65+
return quantity.to("inch").magnitude # Convert to inches
66+
67+
68+
def _2pt(width, dpi=300):
69+
if isinstance(width, (int, float)):
70+
return width
71+
elif isinstance(width, str):
72+
length_in = convert_to_inches(width)
73+
length_pt = length_in * dpi
74+
# print(f"{length_in = } {length_pt = }")
75+
return length_pt
76+
else:
77+
raise NotImplementedError
78+
79+
80+
def set_size(width, fraction=1, ratio="golden", dpi=300):
5881
"""
5982
Sets figure dimensions to avoid scaling in LaTeX.
6083
"""
@@ -63,7 +86,7 @@ def set_size(width, fraction=1, ratio="golden"):
6386
elif width == "beamer":
6487
width_pt = 307.28987
6588
else:
66-
width_pt = width
89+
width_pt = _2pt(width=width, dpi=dpi)
6790

6891
fig_width_pt = width_pt * fraction
6992
# inches_per_pt = 1 / 72.27

src/maxplotlib/canvas/canvas.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ def __init__(self, **kwargs):
2525
self._caption = kwargs.get("caption", None)
2626
self._description = kwargs.get("description", None)
2727
self._label = kwargs.get("label", None)
28-
28+
self._fontsize = kwargs.get("fontsize", 14)
2929
self._dpi = kwargs.get("dpi", 300)
30-
self._width = kwargs.get("width", 426.79135)
30+
# self._width = kwargs.get("width", 426.79135)
31+
self._width = kwargs.get("width", "17cm")
3132
self._ratio = kwargs.get("ratio", "golden")
3233
self._gridspec_kw = kwargs.get("gridspec_kw", {"wspace": 0.08, "hspace": 0.1})
3334

@@ -158,8 +159,9 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None):
158159
filename (str, optional): Filename to save the figure.
159160
show (bool): Whether to display the plot.
160161
"""
161-
fontsize = 14
162-
tex_fonts = plt_utils.setup_tex_fonts(fontsize=fontsize)
162+
163+
tex_fonts = plt_utils.setup_tex_fonts(fontsize=self.fontsize)
164+
163165
plt_utils.setup_plotstyle(
164166
tex_fonts=tex_fonts,
165167
axes_grid=True,
@@ -172,13 +174,17 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None):
172174
fig_width, fig_height = self._figsize
173175
else:
174176
fig_width, fig_height = plt_utils.set_size(
175-
width=self._width, ratio=self._ratio
177+
width=self._width,
178+
ratio=self._ratio,
179+
dpi=self.dpi,
176180
)
177181

182+
# print(f"{(fig_width / self._dpi, fig_height / self._dpi) = }")
183+
178184
fig, axes = plt.subplots(
179185
self.nrows,
180186
self.ncols,
181-
figsize=(fig_width / self._dpi, fig_height / self._dpi),
187+
figsize=(fig_width, fig_height),
182188
squeeze=False,
183189
dpi=self._dpi,
184190
)
@@ -189,7 +195,7 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None):
189195
# ax.set_title(f"Subplot ({row}, {col})")
190196
ax.grid()
191197
# Set caption, labels, etc., if needed
192-
plt.tight_layout()
198+
# plt.tight_layout()
193199

194200
if show:
195201
plt.show()
@@ -205,9 +211,9 @@ def plot_plotly(self, show=True, savefig=None):
205211
show (bool): Whether to display the plot.
206212
savefig (str, optional): Filename to save the figure if provided.
207213
"""
208-
fontsize = 14
214+
209215
tex_fonts = plt_utils.setup_tex_fonts(
210-
fontsize=fontsize
216+
fontsize=self.fontsize
211217
) # adjust or redefine for Plotly if needed
212218

213219
# Set default width and height if not specified
@@ -217,7 +223,7 @@ def plot_plotly(self, show=True, savefig=None):
217223
fig_width, fig_height = plt_utils.set_size(
218224
width=self._width, ratio=self._ratio
219225
)
220-
print(self._width, fig_width, fig_height)
226+
# print(self._width, fig_width, fig_height)
221227
# Create subplots
222228
fig = make_subplots(
223229
rows=self.nrows,
@@ -256,6 +262,10 @@ def plot_plotly(self, show=True, savefig=None):
256262
def dpi(self):
257263
return self._dpi
258264

265+
@property
266+
def fontsize(self):
267+
return self._fontsize
268+
259269
@property
260270
def nrows(self):
261271
return self._nrows

src/maxplotlib/logo/logo.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,35 @@
11
# import maxplotlib.canvas.canvas as canvas
22
# from maxplotlib.subfigure.tikz_figure import TikzFigure
33
from maxtikzlib.figure import TikzFigure
4+
5+
46
def tikz_logo():
57
tikz = TikzFigure()
68

7-
path_actions = ['draw', 'rounded corners', 'line width=3']
9+
path_actions = ["draw", "rounded corners", "line width=3"]
810

911
# M
10-
nodes = [[0,0],[0,3],[1,2],[2,3],[2,0]]
12+
nodes = [[0, 0], [0, 3], [1, 2], [2, 3], [2, 0]]
1113
for i, node_data in enumerate(nodes):
1214
tikz.add_node(node_data[0], node_data[1], f"M{i}", layer=0)
13-
tikz.add_path([f"M{i}" for i in range(len(nodes))], path_actions=path_actions, layer=1)
14-
15+
tikz.add_path(
16+
[f"M{i}" for i in range(len(nodes))], path_actions=path_actions, layer=1
17+
)
18+
1519
# P
16-
nodes = [[3,0],[3,3],[4,2.5],[4,1.5],[3,1]]
20+
nodes = [[3, 0], [3, 3], [4, 2.5], [4, 1.5], [3, 1]]
1721
for i, node_data in enumerate(nodes):
1822
tikz.add_node(node_data[0], node_data[1], f"P{i}", layer=0)
19-
tikz.add_path([f"P{i}" for i in range(len(nodes))], path_actions=path_actions, layer=1)
20-
23+
tikz.add_path(
24+
[f"P{i}" for i in range(len(nodes))], path_actions=path_actions, layer=1
25+
)
26+
2127
# L
22-
nodes = [[5,3],[5,0],[7,0]]
28+
nodes = [[5, 3], [5, 0], [7, 0]]
2329
for i, node_data in enumerate(nodes):
2430
tikz.add_node(node_data[0], node_data[1], f"L{i}", layer=0)
25-
tikz.add_path([f"L{i}" for i in range(len(nodes))], path_actions=path_actions, layer=1)
31+
tikz.add_path(
32+
[f"L{i}" for i in range(len(nodes))], path_actions=path_actions, layer=1
33+
)
2634

27-
return tikz
35+
return tikz

src/maxplotlib/objects/layer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from abc import ABCMeta, abstractmethod
22

3+
34
class Layer(metaclass=ABCMeta):
45
def __init__(self, label):
56
self.label = label
67
self.items = []
78

9+
810
class Tikzlayer(Layer):
911
def __init__(self, label):
1012
super().__init__(label)
@@ -15,4 +17,4 @@ def generate_tikz(self):
1517
for item in self.items:
1618
tikz_script += item.to_tikz()
1719
tikz_script += f"\\end{{pgfonlayer}}{{{self.label}}}\n"
18-
return tikz_script
20+
return tikz_script

src/maxplotlib/subfigure/line_plot.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import matplotlib.pyplot as plt
12
import numpy as np
23
import plotly.graph_objects as go
34
from mpl_toolkits.axes_grid1 import make_axes_locatable
4-
import matplotlib.pyplot as plt
5+
56
import maxplotlib.subfigure.tikz_figure as tf
67

8+
79
class Node:
810
def __init__(self, x, y, label="", content="", layer=0, **kwargs):
911
self.x = x
@@ -13,6 +15,7 @@ def __init__(self, x, y, label="", content="", layer=0, **kwargs):
1315
self.layer = layer
1416
self.options = kwargs
1517

18+
1619
class Path:
1720
def __init__(
1821
self, nodes, path_actions=[], cycle=False, label="", layer=0, **kwargs
@@ -24,6 +27,7 @@ def __init__(
2427
self.label = label
2528
self.options = kwargs
2629

30+
2731
class LinePlot:
2832
def __init__(self, **kwargs):
2933
"""
@@ -56,7 +60,7 @@ def __init__(self, **kwargs):
5660
# Initialize lists to hold Node and Path objects
5761
self.nodes = []
5862
self.paths = []
59-
#self.layers = {}
63+
# self.layers = {}
6064

6165
# Counter for unnamed nodes
6266
self._node_counter = 0
@@ -77,7 +81,7 @@ def _add(self, obj, layer):
7781
else:
7882
self.layered_line_data[layer] = [obj]
7983

80-
def add_line(self, x_data, y_data, layer=0, plot_type='plot', **kwargs):
84+
def add_line(self, x_data, y_data, layer=0, plot_type="plot", **kwargs):
8185
"""
8286
Add a line to the plot.
8387
@@ -95,17 +99,17 @@ def add_line(self, x_data, y_data, layer=0, plot_type='plot', **kwargs):
9599
"kwargs": kwargs,
96100
}
97101
self._add(ld, layer)
98-
99-
def add_imshow(self, data, layer=0, plot_type='imshow', **kwargs):
102+
103+
def add_imshow(self, data, layer=0, plot_type="imshow", **kwargs):
100104
ld = {
101105
"data": np.array(data),
102106
"layer": layer,
103107
"plot_type": plot_type,
104108
"kwargs": kwargs,
105109
}
106110
self._add(ld, layer)
107-
108-
def add_patch(self, patch, layer=0, plot_type='patch', **kwargs):
111+
112+
def add_patch(self, patch, layer=0, plot_type="patch", **kwargs):
109113
ld = {
110114
"patch": patch,
111115
"layer": layer,
@@ -114,7 +118,7 @@ def add_patch(self, patch, layer=0, plot_type='patch', **kwargs):
114118
}
115119
self._add(ld, layer)
116120

117-
def add_colorbar(self, label="", layer=0, plot_type='colorbar', **kwargs):
121+
def add_colorbar(self, label="", layer=0, plot_type="colorbar", **kwargs):
118122
cb = {
119123
"label": label,
120124
"layer": layer,
@@ -159,9 +163,10 @@ def plot_matplotlib(self, ax, layers=None):
159163
**line["kwargs"],
160164
)
161165
elif line["plot_type"] == "patch":
162-
ax.add_patch(line["patch"],
163-
**line["kwargs"],
164-
)
166+
ax.add_patch(
167+
line["patch"],
168+
**line["kwargs"],
169+
)
165170
elif line["plot_type"] == "colorbar":
166171
divider = make_axes_locatable(ax)
167172
cax = divider.append_axes("right", size="5%", pad=0.05)
@@ -208,7 +213,7 @@ def plot_plotly(self):
208213
traces.append(trace)
209214

210215
return traces
211-
216+
212217
def add_node(self, x, y, label=None, content="", layer=0, **kwargs):
213218
"""
214219
Add a node to the TikZ figure.

src/maxplotlib/subfigure/tikz_figure.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def generate_tikz(self):
3636
tikz_script += f"\\end{{pgfonlayer}}{{{self.label}}}\n"
3737
return tikz_script
3838

39+
3940
class TikzWrapper:
4041
def __init__(self, raw_tikz, label="", content="", layer=0, **kwargs):
4142
self.raw_tikz = raw_tikz
@@ -47,6 +48,7 @@ def __init__(self, raw_tikz, label="", content="", layer=0, **kwargs):
4748
def to_tikz(self):
4849
return self.raw_tikz
4950

51+
5052
class Node:
5153
def __init__(self, x, y, label="", content="", layer=0, **kwargs):
5254
"""
@@ -298,10 +300,12 @@ def generate_tikz(self):
298300
tikz_script = figure_env
299301
tikz_script = self.add_tabs(tikz_script)
300302
return tikz_script
303+
301304
def savefig(self, filepath):
302305
tikz_code = self.generate_tikz()
303-
with open(filepath, 'w') as f:
306+
with open(filepath, "w") as f:
304307
f.write(tikz_code)
308+
305309
def generate_standalone(self):
306310
tikz_code = self.generate_tikz()
307311

src/maxplotlib/tests/test_imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
@pytest.mark.parametrize("x", [0])
55
def import_modules(x):
66
import matplotlib
7+
78
import maxplotlib
89

910

tutorials/tutorial_06.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import numpy as np
2+
import maxplotlib.canvas.canvas as canvas
3+
import matplotlib.pyplot as plt
4+
c = canvas.Canvas(width=2000, ratio=0.5)
5+
sp = c.add_subplot(grid=False, xlabel='x', ylabel='y')
6+
# sp.add_line([0, 1, 2, 3], [0, 1, 4, 9], label="Line 1",layer=1)
7+
data = np.random.random((10,10))
8+
sp.add_imshow(data, extent=[1,10,1,20],layer=1)
9+
#c.plot()
10+
c.savefig(layer_by_layer=True, filename='figures/tutorial_06_figure_01')

tutorials/tutorial_1.ipynb

Lines changed: 4 additions & 15 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)