update plotting. add export functionality. update main to work with new plotting and saving/export functionality.

This commit is contained in:
Pim Nelissen
2026-03-31 10:48:54 +02:00
parent 7ed12989f4
commit 5bcf1778ea
3 changed files with 151 additions and 50 deletions

View File

@ -17,6 +17,7 @@ from pg_rad.inputparser.parser import ConfigParser
from pg_rad.landscape.director import LandscapeDirector
from pg_rad.plotting.result_plotter import ResultPlotter
from pg_rad.simulator.engine import SimulationEngine
from pg_rad.utils.export import generate_folder_name, save_results
def main():
@ -39,9 +40,14 @@ def main():
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
)
parser.add_argument(
"--saveplot",
"--showplots",
action="store_true",
help="Save the plot or not."
help="Show the plots immediately."
)
parser.add_argument(
"--save",
action="store_true",
help="Save the outputs"
)
args = parser.parse_args()
@ -49,7 +55,7 @@ def main():
logger = logging.getLogger(__name__)
if args.example:
example_yaml = """
input_config = """
name: Example landscape
speed: 8.33
acquisition_time: 1
@ -72,62 +78,58 @@ def main():
detector: LU_NaI_3inch
"""
elif args.config:
input_config = args.config
cp = ConfigParser(example_yaml).parse()
try:
cp = ConfigParser(input_config).parse()
landscape = LandscapeDirector.build_from_config(cp)
output = SimulationEngine(
landscape=landscape,
runtime_spec=cp.runtime,
sim_spec=cp.options,
sim_spec=cp.options
).simulate()
plotter = ResultPlotter(landscape, output)
plotter.plot()
elif args.config:
try:
cp = ConfigParser(args.config).parse()
landscape = LandscapeDirector.build_from_config(cp)
output = SimulationEngine(
landscape=landscape,
runtime_spec=cp.runtime,
sim_spec=cp.options
).simulate()
plotter = ResultPlotter(landscape, output)
if args.save:
folder_name = generate_folder_name(output)
save_results(output, folder_name)
plotter.save(folder_name)
if args.showplots:
plotter.plot()
except (
MissingConfigKeyError,
KeyError
) as e:
logger.critical(e)
logger.critical(
"The config file is missing required keys or may be an "
"invalid YAML file. Check the log above. Consult the "
"documentation for examples of how to write a config file."
)
sys.exit(1)
except (
OutOfBoundsError,
DimensionError,
InvalidIsotopeError,
InvalidConfigValueError,
NotImplementedError
) as e:
logger.critical(e)
logger.critical(
"One or more items in config are not specified correctly. "
"Please consult this log and fix the problem."
)
sys.exit(1)
except (
FileNotFoundError,
ParserError,
InvalidYAMLError
) as e:
logger.critical(e)
sys.exit(1)
except (
MissingConfigKeyError,
KeyError
) as e:
logger.critical(e)
logger.critical(
"The config file is missing required keys or may be an "
"invalid YAML file. Check the log above. Consult the "
"documentation for examples of how to write a config file."
)
sys.exit(1)
except (
OutOfBoundsError,
DimensionError,
InvalidIsotopeError,
InvalidConfigValueError,
NotImplementedError
) as e:
logger.critical(e)
logger.critical(
"One or more items in config are not specified correctly. "
"Please consult this log and fix the problem."
)
sys.exit(1)
except (
FileNotFoundError,
ParserError,
InvalidYAMLError
) as e:
logger.critical(e)
sys.exit(1)
if __name__ == "__main__":

View File

@ -1,4 +1,5 @@
from importlib.resources import files
from typing import List
import numpy as np
import pandas as pd
@ -23,6 +24,15 @@ class ResultPlotter:
plt.show()
def save(self, path: str, landscape_z: float = 0) -> None:
fig_1 = self._plot_main(landscape_z)
fig_2 = self._plot_detector()
fig_3 = self._plot_metadata()
fig_1.savefig(path+"/main.jpg")
fig_2.savefig(path+"/detector.jpg")
fig_3.savefig(path+"/metadata.jpg")
def _plot_main(self, landscape_z):
fig = plt.figure(figsize=(12, 8))
fig.suptitle(self.landscape.name)
@ -41,6 +51,7 @@ class ResultPlotter:
ax_landscape = fig.add_subplot(gs[1, :])
self._plot_landscape(ax_landscape, landscape_z)
return fig
def _plot_detector(self):
det = self.landscape.detector
@ -61,6 +72,7 @@ class ResultPlotter:
]
self._draw_angular_efficiency_polar(ax_polar, det, energies[0])
return fig
def _plot_metadata(self):
fig, axs = plt.subplots(2, 1, figsize=(10, 6))
@ -68,6 +80,7 @@ class ResultPlotter:
self._draw_table(axs[0])
self._draw_source_table(axs[1])
return fig
def _plot_landscape(self, ax, z):
lp = LandscapeSlicePlotter()
@ -84,7 +97,7 @@ class ResultPlotter:
ax.set_ylabel('CPS [s$^{-1}$]')
def _draw_counts(self, ax):
x = self.count_rate_res.acquisition_points[1:]
x = self.count_rate_res.distance[1:]
y = self.count_rate_res.integrated_counts[1:]
ax.plot(
x, y, color='r', linestyle='--',

View File

@ -0,0 +1,86 @@
from datetime import datetime as dt
import os
import logging
import re
from numpy import array, full_like
from pandas import DataFrame
from pg_rad.simulator.outputs import SimulationOutput
logger = logging.getLogger(__name__)
def generate_folder_name(sim: SimulationOutput) -> str:
formatted_sim_name = re.sub(r"\s+", '_', sim.name.lower())
folder_name = (
formatted_sim_name +
'_result_' +
dt.today().strftime('%Y%m%d_%H%M')
)
return folder_name
def save_results(sim: SimulationOutput, folder_name: str) -> None:
"""Parse all simulation output and save to a folder."""
if not os.path.exists(folder_name):
os.makedirs(folder_name)
else:
logger.warning("Folder already exists. Overwrite?")
ans = input("[type 'n' to cancel overwrite] ")
if ans.lower() == 'n':
return
df = generate_df(sim)
csv_name = generate_csv_name(sim)
df.to_csv(f"{folder_name}/{csv_name}.csv", index=False)
logger.info(f"Simulation output saved to {folder_name}!")
def generate_df(sim: SimulationOutput) -> DataFrame:
"""Parse simulation output to CSV format and the name of CSV."""
br_array = full_like(
sim.count_rate.integrated_counts,
sim.count_rate.mean_bkg_cps
)
result_df = DataFrame(
{
"East": sim.count_rate.x,
"North": sim.count_rate.y,
"ROI_P": sim.count_rate.integrated_counts,
"ROI_BR": br_array,
"Dist": sim.count_rate.distance
}
)
return result_df
def generate_csv_name(sim: SimulationOutput) -> str:
"""Generate CSV name according to Alex' specification"""
num_src = len(sim.sources)
bkg_cps = round(sim.count_rate.mean_bkg_cps)
source_param_strings = [
[
str(round(s.activity))+"MBq",
str(round(s.dist_from_path))+"m",
str(round(s.position[0])),
str(round(s.position[1])),
]
for s in sim.sources
]
if num_src == 1:
src_str = "_".join(source_param_strings[0])
else:
src_str_array = array(
[list(item) for item in zip(*source_param_strings)]
)
src_str = "_".join(src_str_array.flat)
csv_name = f"{num_src}_src_{bkg_cps}_cps_bkg_{src_str}"
return csv_name