update plotting

This commit is contained in:
Pim Nelissen
2026-02-25 14:21:03 +01:00
parent 5615914c7e
commit 9944c06466
2 changed files with 124 additions and 7 deletions

View File

@ -1,8 +1,11 @@
import logging import logging
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.patches import Circle from matplotlib.patches import Circle
from numpy import median
from pg_rad.landscape.landscape import Landscape from pg_rad.landscape.landscape import Landscape
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -14,7 +17,8 @@ class LandscapeSlicePlotter:
landscape: Landscape, landscape: Landscape,
z: int = 0, z: int = 0,
show: bool = True, show: bool = True,
save: bool = False save: bool = False,
ax: Axes | None = None
): ):
"""Plot a top-down slice of the landscape at a height z. """Plot a top-down slice of the landscape at a height z.
@ -27,7 +31,9 @@ class LandscapeSlicePlotter:
""" """
self.z = z self.z = z
fig, ax = plt.subplots()
if not ax:
fig, ax = plt.subplots()
self._draw_base(ax, landscape) self._draw_base(ax, landscape)
self._draw_path(ax, landscape) self._draw_path(ax, landscape)
@ -35,19 +41,30 @@ class LandscapeSlicePlotter:
ax.set_aspect("equal") ax.set_aspect("equal")
if save: if save and not ax:
landscape_name = landscape.name.lower().replace(' ', '_') landscape_name = landscape.name.lower().replace(' ', '_')
filename = f"{landscape_name}_z{self.z}.png" filename = f"{landscape_name}_z{self.z}.png"
plt.savefig(filename) plt.savefig(filename)
logger.info("Plot saved to file: "+filename) logger.info("Plot saved to file: "+filename)
if show: if show and not ax:
plt.show() plt.show()
def _draw_base(self, ax, landscape): return ax
def _draw_base(self, ax, landscape: Landscape):
width, height = landscape.size[:2] width, height = landscape.size[:2]
ax.set_xlim(right=width)
ax.set_ylim(top=height) ax.set_xlim(right=max(width, .5*height))
# if the road is very flat, we center it vertically (looks better)
if median(landscape.path.y_list) == 0:
h = max(height, .5*width)
ax.set_ylim(bottom=-h//2,
top=h//2)
else:
ax.set_ylim(top=max(height, .5*width))
ax.set_xlabel("X [m]") ax.set_xlabel("X [m]")
ax.set_ylabel("Y [m]") ax.set_ylabel("Y [m]")
ax.set_title(f"Landscape (top-down, z = {self.z})") ax.set_title(f"Landscape (top-down, z = {self.z})")

View File

@ -0,0 +1,100 @@
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
from .landscape_plotter import LandscapeSlicePlotter
from pg_rad.simulator.outputs import SimulationOutput
from pg_rad.landscape.landscape import Landscape
class ResultPlotter:
def __init__(self, landscape: Landscape, output: SimulationOutput):
self.landscape = landscape
self.count_rate_res = output.count_rate
self.source_res = output.sources
def plot(self, landscape_z: float = 0):
fig = plt.figure(figsize=(12, 10), constrained_layout=True)
fig.suptitle(self.landscape.name)
gs = GridSpec(
3,
2,
width_ratios=[0.5, 0.5],
height_ratios=[0.7, 0.15, 0.15],
hspace=0.2)
ax1 = fig.add_subplot(gs[0, 0])
self._draw_count_rate(ax1)
ax2 = fig.add_subplot(gs[0, 1])
self._plot_landscape(ax2, landscape_z)
ax3 = fig.add_subplot(gs[1, :])
self._draw_table(ax3)
ax4 = fig.add_subplot(gs[2, :])
self._draw_source_table(ax4)
plt.tight_layout()
plt.show()
def _plot_landscape(self, ax, z):
lp = LandscapeSlicePlotter()
ax = lp.plot(landscape=self.landscape, z=z, ax=ax, show=False)
return ax
def _draw_count_rate(self, ax):
x = self.count_rate_res.arc_length
y = self.count_rate_res.count_rate
ax.plot(x, y, label='Count rate', color='r')
ax.set_title('Count rate')
ax.set_xlabel('Arc length s [m]')
ax.set_ylabel('Counts')
ax.legend()
def _draw_table(self, ax):
ax.set_axis_off()
ax.set_title('Simulation parameters')
cols = ('Parameter', 'Value')
data = [
["Air density (kg/m^3)", round(self.landscape.air_density, 3)],
["Total path length (m)", round(self.landscape.path.length, 3)]
]
ax.table(
cellText=data,
colLabels=cols,
loc='center'
)
return ax
def _draw_source_table(self, ax):
ax.set_axis_off()
ax.set_title('Point sources')
cols = (
'Name',
'Isotope',
'Activity (MBq)',
'Position (m)',
'Dist. to path (m)'
)
# this formats position to tuple
data = [
[
s.name,
s.isotope,
s.activity,
"("+", ".join(f"{val:.2f}" for val in s.position)+")",
round(s.dist_from_path, 2)
]
for s in self.source_res
]
ax.table(
cellText=data,
colLabels=cols,
loc='center'
)
return ax