add save and show options to Plotter

This commit is contained in:
Pim Nelissen
2026-02-17 09:41:54 +01:00
parent f49b2a5f8a
commit 3fd2eafb2a

View File

@ -5,17 +5,24 @@ from matplotlib.patches import Circle
from pg_rad.landscape import Landscape from pg_rad.landscape import Landscape
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LandscapeSlicePlotter: class LandscapeSlicePlotter:
def plot(self, landscape: Landscape, z: int = 0): def plot(
self,
landscape: Landscape,
z: int = 0,
show: bool = True,
save: bool = False
):
"""Plot a top-down slice of the landscape at a height z. """Plot a top-down slice of the landscape at a height z.
Args: Args:
landscape (Landscape): the landscape to plot landscape (Landscape): the landscape to plot
z (int, optional): Height at which to plot slice. Defaults to 0. z (int, optional): Height at which to plot slice. Defaults to 0.
show (bool, optional): Show the plot. Defaults to True.
save (bool, optional): Save the plot. Defaults to False.
""" """ """ """
""" """
@ -27,7 +34,15 @@ class LandscapeSlicePlotter:
self._draw_point_sources(ax, landscape) self._draw_point_sources(ax, landscape)
ax.set_aspect("equal") ax.set_aspect("equal")
plt.show()
if save:
name = landscape.name.lower().replace(' ', '_')
plt.savefig(
f"{name}_z{self.z}.png"
)
if show:
plt.show()
def _draw_base(self, ax, landscape): def _draw_base(self, ax, landscape):
width, height = landscape.size[:2] width, height = landscape.size[:2]
@ -38,7 +53,7 @@ class LandscapeSlicePlotter:
ax.set_title(f"Landscape (top-down, z = {self.z})") ax.set_title(f"Landscape (top-down, z = {self.z})")
def _draw_path(self, ax, landscape): def _draw_path(self, ax, landscape):
if landscape.path.z < self.z: if landscape.path.z <= self.z:
ax.plot(landscape.path.x_list, landscape.path.y_list, 'bo-') ax.plot(landscape.path.x_list, landscape.path.y_list, 'bo-')
else: else:
logger.warning( logger.warning(
@ -48,18 +63,19 @@ class LandscapeSlicePlotter:
def _draw_point_sources(self, ax, landscape): def _draw_point_sources(self, ax, landscape):
for s in landscape.point_sources: for s in landscape.point_sources:
if s.z <= self.z: x, y, z = s.pos
if z <= self.z:
dot = Circle( dot = Circle(
(s.x, s.y), (x, y),
radius=5, radius=5,
color=s.color, color=s.color,
zorder=5 zorder=5
) )
ax.text( ax.text(
s.x + 0.06, x + 0.06,
s.y + 0.06, y + 0.06,
s.name, s.name+", z="+str(z),
color=s.color, color=s.color,
fontsize=10, fontsize=10,
ha="left", ha="left",