Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ ax.set_legend(True)
canvas.show(backend="plotext")
```

## Examples

Runnable example scripts live in `examples/`:

``` bash
python examples/plotly_backend_basic.py
python examples/plotly_backend_parity.py
```

### Layers

``` python
Expand Down
23 changes: 23 additions & 0 deletions examples/plotly_backend_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np

from maxplotlib import Canvas


def main() -> None:
x = np.linspace(0, 2 * np.pi, 200)

canvas = Canvas(width="12cm", ratio=0.5)
canvas.add_line(x, np.sin(x), color="royalblue", label="sin(x)")
canvas.scatter(x[::12], np.sin(x[::12]), color="tomato", label="samples")
canvas.axhline(0, color="black", linestyle="dotted")
canvas.set_title("Plotly backend (basic)")
canvas.set_xlabel("x")
canvas.set_ylabel("y")
canvas.set_grid(True)
canvas.set_legend(True)

canvas.savefig("plotly_basic.html", backend="plotly")


if __name__ == "__main__":
main()
47 changes: 47 additions & 0 deletions examples/plotly_backend_parity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import matplotlib.patches as mpatches
import numpy as np

from maxplotlib import Canvas


def main() -> None:
x = np.linspace(0.5, 10, 60)
y = np.sqrt(x)

canvas = Canvas(width="12cm", ratio=0.55)

canvas.add_line(x, y, color="steelblue", label="sqrt(x)")
canvas.errorbar(
x[::10],
y[::10],
yerr=0.15,
color="tomato",
marker="o",
label="samples ± err",
)
canvas.fill_between(x, y - 0.1, y + 0.1, color="steelblue", alpha=0.2, label="band")
canvas.vlines([2, 5, 8], ymin=0, ymax=3.5, color="gray", linestyle="dashed")
canvas.text(7.2, 2.8, "note", color="purple")
canvas.annotate(
"peak-ish", xy=(9.5, np.sqrt(9.5)), xytext=(6.0, 3.1), color="purple"
)

canvas.add_patch(
mpatches.Rectangle((1.2, 0.0), 2.5, 1.2, fill=True),
facecolor="rgba(255,0,0,0.1)",
edgecolor="crimson",
alpha=0.3,
)

canvas.set_title("Plotly backend (parity features)")
canvas.set_xlabel("x")
canvas.set_ylabel("y")
canvas.set_xscale("log")
canvas.set_grid(True)
canvas.set_legend(True)

canvas.savefig("plotly_parity.html", backend="plotly")


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"pint",
"plotly",
"plotext",
"tikzfigure[vis]>=0.2.1",
"tikzfigure[vis]>=0.3.0",
]
[project.optional-dependencies]
test = [
Expand Down
230 changes: 208 additions & 22 deletions src/maxplotlib/canvas/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from plotly.subplots import make_subplots
from tikzfigure import TikzFigure

Expand Down Expand Up @@ -579,6 +580,41 @@ def text(
"""Add a text label at (x, y) on a subplot."""
self._get_or_create_subplot(row, col).text(x, y, s, layer=layer, **kwargs)

def imshow(
self,
data,
layer=0,
row: int | None = None,
col: int | None = None,
**kwargs,
):
"""Add an image/matrix plot to a subplot."""
self._get_or_create_subplot(row, col).add_imshow(data, layer=layer, **kwargs)

def add_patch(
self,
patch,
layer=0,
row: int | None = None,
col: int | None = None,
**kwargs,
):
"""Add a Matplotlib patch to a subplot."""
self._get_or_create_subplot(row, col).add_patch(patch, layer=layer, **kwargs)

def colorbar(
self,
label: str = "",
layer=0,
row: int | None = None,
col: int | None = None,
**kwargs,
):
"""Add a colorbar to the most recent imshow() on a subplot (matplotlib backend)."""
self._get_or_create_subplot(row, col).add_colorbar(
label=label, layer=layer, **kwargs
)

# ------------------------------------------------------------------
# Multi-subplot helpers
# ------------------------------------------------------------------
Expand Down Expand Up @@ -773,6 +809,34 @@ def savefig(
figure.savefig(full_filepath)
if verbose:
print(f"Saved {full_filepath}")
elif backend == "plotly":
if layer_by_layer:
layers = []
for layer in self.layers:
layers.append(layer)
full_filepath = f"{filename_no_extension}_{layers}{extension}"
fig = self.plot(
backend="plotly",
savefig=False,
layers=layers,
)
self._save_plotly(fig, full_filepath)
if verbose:
print(f"Saved {full_filepath}")
else:
if layers is None:
layers = self.layers
full_filepath = filename
else:
full_filepath = f"{filename_no_extension}_{layers}{extension}"
fig = self.plot(
backend="plotly",
savefig=False,
layers=layers,
)
self._save_plotly(fig, full_filepath)
if verbose:
print(f"Saved {full_filepath}")

def plot(
self,
Expand All @@ -797,6 +861,7 @@ def plot(
elif backend == "plotly":
return self.plot_plotly(
savefig=savefig,
layers=layers,
usetex=resolved_usetex,
verbose=verbose,
)
Expand Down Expand Up @@ -832,7 +897,11 @@ def show(
# self._matplotlib_fig.show()
elif backend == "plotly":
resolved_usetex = self._usetex if usetex is None else usetex
self.plot_plotly(savefig=False, usetex=resolved_usetex)
fig = self.plot_plotly(
savefig=False, layers=layers, usetex=resolved_usetex, verbose=verbose
)
fig.show()
return fig
elif backend == "plotext":
figure = self.plot_plotext(
savefig=False,
Expand Down Expand Up @@ -1034,6 +1103,7 @@ def plot_plotly(
self,
show=True,
savefig=None,
layers: list | None = None,
usetex: bool | None = None,
verbose: bool = False,
):
Expand Down Expand Up @@ -1063,38 +1133,134 @@ def plot_plotly(
ratio=self._ratio,
)
# print(self._width, fig_width, fig_height)
# Create subplots
# Create subplot titles in row-major order (Plotly expects rows*cols entries)
subplot_titles = [""] * (self.nrows * self.ncols)
for (row, col), sp in self._subplot_dict.items():
index = row * self.ncols + col
subplot_titles[index] = sp._title or f"({row}, {col})"

fig = make_subplots(
rows=self.nrows,
cols=self.ncols,
subplot_titles=[
sp._title or f"({row}, {col})"
for (row, col), sp in self._subplot_dict.items()
],
subplot_titles=subplot_titles,
)

# Plot each subplot and propagate axis labels/scale
axis_index = 1
for (row, col), line_plot in self._subplot_dict.items():
traces = line_plot.plot_plotly()
traces, shapes, annotations = line_plot.plot_plotly(layers=layers)
for trace in traces:
fig.add_trace(trace, row=row + 1, col=col + 1)

# Axis label keys are "xaxis", "xaxis2", "xaxis3", ...
xkey = "xaxis" if axis_index == 1 else f"xaxis{axis_index}"
ykey = "yaxis" if axis_index == 1 else f"yaxis{axis_index}"
layout_patch = {}
if line_plot._xlabel:
layout_patch[xkey] = {"title": {"text": line_plot._xlabel}}
if line_plot._ylabel:
layout_patch[ykey] = {"title": {"text": line_plot._ylabel}}
# Axis indices are row-major: (row*ncols + col + 1)
axis_index = row * self.ncols + col + 1
xref = "x" if axis_index == 1 else f"x{axis_index}"
yref = "y" if axis_index == 1 else f"y{axis_index}"

for shape in shapes:
shape = dict(shape)
if shape.get("xref") not in {"paper"}:
shape["xref"] = xref
if shape.get("yref") not in {"paper"}:
shape["yref"] = yref
fig.add_shape(shape)

for annotation in annotations:
annotation = dict(annotation)
annotation.setdefault("xref", xref)
annotation.setdefault("yref", yref)
fig.add_annotation(annotation)

# Apply per-axis config in a row/col-safe way
xaxis_kwargs = dict(
title_text=line_plot._xlabel or None,
showgrid=bool(line_plot._grid),
row=row + 1,
col=col + 1,
)
if line_plot._xaxis_scale == "log":
layout_patch.setdefault(xkey, {})["type"] = "log"
xaxis_kwargs["type"] = "log"
fig.update_xaxes(**xaxis_kwargs)

yaxis_kwargs = dict(
title_text=line_plot._ylabel or None,
showgrid=bool(line_plot._grid),
row=row + 1,
col=col + 1,
)
if line_plot._yaxis_scale == "log":
layout_patch.setdefault(ykey, {})["type"] = "log"
if layout_patch:
fig.update_layout(**layout_patch)
axis_index += 1
yaxis_kwargs["type"] = "log"
fig.update_yaxes(**yaxis_kwargs)

# Axis limits
if line_plot._xmin is not None or line_plot._xmax is not None:
x_range = [line_plot._xmin, line_plot._xmax]
if x_range[0] is not None:
x_range[0] = line_plot._transform_scalar_x(x_range[0])
if x_range[1] is not None:
x_range[1] = line_plot._transform_scalar_x(x_range[1])
if (
line_plot._xaxis_scale == "log"
and x_range[0] is not None
and x_range[1] is not None
and x_range[0] > 0
and x_range[1] > 0
):
x_range = [np.log10(x_range[0]), np.log10(x_range[1])]
fig.update_xaxes(
range=x_range,
row=row + 1,
col=col + 1,
)
if line_plot._ymin is not None or line_plot._ymax is not None:
y_range = [line_plot._ymin, line_plot._ymax]
if y_range[0] is not None:
y_range[0] = line_plot._transform_scalar_y(y_range[0])
if y_range[1] is not None:
y_range[1] = line_plot._transform_scalar_y(y_range[1])
if (
line_plot._yaxis_scale == "log"
and y_range[0] is not None
and y_range[1] is not None
and y_range[0] > 0
and y_range[1] > 0
):
y_range = [np.log10(y_range[0]), np.log10(y_range[1])]
fig.update_yaxes(
range=y_range,
row=row + 1,
col=col + 1,
)

# Custom ticks (positions + optional labels)
if line_plot._xticks is not None:
tickvals = [line_plot._transform_scalar_x(v) for v in line_plot._xticks]
fig.update_xaxes(
tickmode="array",
tickvals=tickvals,
ticktext=line_plot._xticklabels,
row=row + 1,
col=col + 1,
)
if line_plot._yticks is not None:
tickvals = [line_plot._transform_scalar_y(v) for v in line_plot._yticks]
fig.update_yaxes(
tickmode="array",
tickvals=tickvals,
ticktext=line_plot._yticklabels,
row=row + 1,
col=col + 1,
)

# Aspect ratio
if line_plot._aspect == "equal":
fig.update_yaxes(scaleanchor=xref, row=row + 1, col=col + 1)
elif isinstance(line_plot._aspect, (int, float)):
fig.update_yaxes(
scaleanchor=xref,
scaleratio=float(line_plot._aspect),
row=row + 1,
col=col + 1,
)

# Update layout settings
fig.update_layout(
Expand All @@ -1105,10 +1271,30 @@ def plot_plotly(
fig.update_layout(title=dict(text=self._suptitle, x=0.5))

if savefig:
fig.write_image(savefig)
try:
fig.write_image(savefig)
except Exception as exc:
raise RuntimeError(
"Plotly image export failed. If you are exporting to PNG/PDF/SVG, "
"install kaleido (e.g., `pip install -U kaleido`)."
) from exc

return fig

def _save_plotly(self, fig, filename: str) -> None:
_, extension = os.path.splitext(filename)
extension = extension.lower()
if extension in {".html", ".htm"}:
fig.write_html(filename)
return
try:
fig.write_image(filename)
except Exception as exc:
raise RuntimeError(
"Plotly image export failed. For PNG/PDF/SVG export, install kaleido "
"(e.g., `pip install -U kaleido`), or export to HTML instead."
) from exc

# Property getters

@property
Expand Down
Loading
Loading