diff --git a/src/pg_rad/plotting/landscape_plotter.py b/src/pg_rad/plotting/landscape_plotter.py index 8542243..4a990f3 100644 --- a/src/pg_rad/plotting/landscape_plotter.py +++ b/src/pg_rad/plotting/landscape_plotter.py @@ -1,8 +1,11 @@ import logging from matplotlib import pyplot as plt +from matplotlib.axes import Axes from matplotlib.patches import Circle +from numpy import median + from pg_rad.landscape.landscape import Landscape logger = logging.getLogger(__name__) @@ -14,7 +17,8 @@ class LandscapeSlicePlotter: landscape: Landscape, z: int = 0, show: bool = True, - save: bool = False + save: bool = False, + ax: Axes | None = None ): """Plot a top-down slice of the landscape at a height z. @@ -27,7 +31,9 @@ class LandscapeSlicePlotter: """ self.z = z - fig, ax = plt.subplots() + + if not ax: + fig, ax = plt.subplots() self._draw_base(ax, landscape) self._draw_path(ax, landscape) @@ -35,19 +41,30 @@ class LandscapeSlicePlotter: ax.set_aspect("equal") - if save: + if save and not ax: landscape_name = landscape.name.lower().replace(' ', '_') filename = f"{landscape_name}_z{self.z}.png" plt.savefig(filename) logger.info("Plot saved to file: "+filename) - if show: + if show and not ax: plt.show() - def _draw_base(self, ax, landscape): + return ax + + def _draw_base(self, ax, landscape: Landscape): width, height = landscape.size[:2] - ax.set_xlim(right=width) - ax.set_ylim(top=height) + + ax.set_xlim(right=max(width, .5*height)) + + # if the road is very flat, we center it vertically (looks better) + if median(landscape.path.y_list) == 0: + h = max(height, .5*width) + ax.set_ylim(bottom=-h//2, + top=h//2) + else: + ax.set_ylim(top=max(height, .5*width)) + ax.set_xlabel("X [m]") ax.set_ylabel("Y [m]") ax.set_title(f"Landscape (top-down, z = {self.z})") diff --git a/src/pg_rad/plotting/result_plotter.py b/src/pg_rad/plotting/result_plotter.py new file mode 100644 index 0000000..51ba366 --- /dev/null +++ b/src/pg_rad/plotting/result_plotter.py @@ -0,0 +1,100 @@ +from matplotlib import pyplot as plt +from matplotlib.gridspec import GridSpec + +from .landscape_plotter import LandscapeSlicePlotter +from pg_rad.simulator.outputs import SimulationOutput +from pg_rad.landscape.landscape import Landscape + + +class ResultPlotter: + def __init__(self, landscape: Landscape, output: SimulationOutput): + self.landscape = landscape + self.count_rate_res = output.count_rate + self.source_res = output.sources + + def plot(self, landscape_z: float = 0): + fig = plt.figure(figsize=(12, 10), constrained_layout=True) + fig.suptitle(self.landscape.name) + gs = GridSpec( + 3, + 2, + width_ratios=[0.5, 0.5], + height_ratios=[0.7, 0.15, 0.15], + hspace=0.2) + + ax1 = fig.add_subplot(gs[0, 0]) + self._draw_count_rate(ax1) + + ax2 = fig.add_subplot(gs[0, 1]) + self._plot_landscape(ax2, landscape_z) + + ax3 = fig.add_subplot(gs[1, :]) + self._draw_table(ax3) + + ax4 = fig.add_subplot(gs[2, :]) + self._draw_source_table(ax4) + + plt.tight_layout() + plt.show() + + def _plot_landscape(self, ax, z): + lp = LandscapeSlicePlotter() + ax = lp.plot(landscape=self.landscape, z=z, ax=ax, show=False) + return ax + + def _draw_count_rate(self, ax): + x = self.count_rate_res.arc_length + y = self.count_rate_res.count_rate + ax.plot(x, y, label='Count rate', color='r') + ax.set_title('Count rate') + ax.set_xlabel('Arc length s [m]') + ax.set_ylabel('Counts') + ax.legend() + + def _draw_table(self, ax): + ax.set_axis_off() + ax.set_title('Simulation parameters') + cols = ('Parameter', 'Value') + data = [ + ["Air density (kg/m^3)", round(self.landscape.air_density, 3)], + ["Total path length (m)", round(self.landscape.path.length, 3)] + ] + + ax.table( + cellText=data, + colLabels=cols, + loc='center' + ) + + return ax + + def _draw_source_table(self, ax): + ax.set_axis_off() + ax.set_title('Point sources') + + cols = ( + 'Name', + 'Isotope', + 'Activity (MBq)', + 'Position (m)', + 'Dist. to path (m)' + ) + + # this formats position to tuple + data = [ + [ + s.name, + s.isotope, + s.activity, + "("+", ".join(f"{val:.2f}" for val in s.position)+")", + round(s.dist_from_path, 2) + ] + for s in self.source_res + ] + ax.table( + cellText=data, + colLabels=cols, + loc='center' + ) + + return ax