Move ConfigParser to the inputparser module and introduce config spec classes for structured config handling

This commit is contained in:
Pim Nelissen
2026-02-25 14:25:32 +01:00
parent 58de830a39
commit a74ea765d7
3 changed files with 374 additions and 119 deletions

View File

@ -0,0 +1,298 @@
import logging
import os
from typing import Any, Dict, List, Union
import yaml
from pg_rad.exceptions.exceptions import MissingConfigKeyError, DimensionError
from pg_rad.configs import defaults
from .specs import (
MetadataSpec,
RuntimeSpec,
SimulationOptionsSpec,
SegmentSpec,
PathSpec,
ProceduralPathSpec,
CSVPathSpec,
SourceSpec,
AbsolutePointSourceSpec,
RelativePointSourceSpec,
SimulationSpec,
)
# Module-level logger, named after this module; used for unknown-key warnings.
logger = logging.getLogger(__name__)
class ConfigParser:
    """Parse a YAML simulation config into a structured ``SimulationSpec``.

    The config source may be a path to a YAML file or a raw YAML string.
    Unrecognised keys in a section are reported with a warning and otherwise
    ignored. Missing required keys raise ``MissingConfigKeyError`` (except
    where a method's docstring states that a plain ``KeyError`` is raised).
    """

    # Keys accepted at the top level of the config; anything else triggers a
    # warning in ``parse``.
    _ALLOWED_ROOT_KEYS = {
        "name",
        "speed",
        "acquisition_time",
        "path",
        "sources",
        "options",
    }

    def __init__(self, config_source: str):
        """Load the config from ``config_source``.

        Args:
            config_source: Path to a YAML file, or a YAML document as a
                string.

        Raises:
            ValueError: If the loaded YAML document is not a mapping.
        """
        self.config = self._load_yaml(config_source)

    def parse(self) -> SimulationSpec:
        """Parse every config section and bundle them in a ``SimulationSpec``.

        Returns:
            The assembled simulation specification.

        Raises:
            MissingConfigKeyError: If a required key is absent.
            ValueError: If the path or a source is malformed.
        """
        self._warn_unknown_keys(
            section="global",
            provided=set(self.config.keys()),
            allowed=self._ALLOWED_ROOT_KEYS,
        )

        metadata = self._parse_metadata()
        runtime = self._parse_runtime()
        options = self._parse_options()
        path = self._parse_path()
        sources = self._parse_point_sources()

        return SimulationSpec(
            metadata=metadata,
            runtime=runtime,
            options=options,
            path=path,
            point_sources=sources,
        )

    def _load_yaml(self, config_source: str) -> Dict[str, Any]:
        """Load YAML from a file if ``config_source`` is one, else from text.

        Args:
            config_source: Path to a YAML file, or a YAML document as a
                string.

        Returns:
            The top-level mapping of the YAML document.

        Raises:
            ValueError: If the document is not a mapping. This is also what
                happens when a non-existent file path is given, because the
                path itself is then parsed as a YAML scalar.
        """
        if os.path.isfile(config_source):
            with open(config_source, encoding="utf-8") as f:
                config = yaml.safe_load(f)
        else:
            config = yaml.safe_load(config_source)

        # Fail here with a clear message instead of letting a later
        # ``.keys()`` call blow up on a str/None/list.
        if not isinstance(config, dict):
            raise ValueError(
                "Config must be a YAML mapping, got "
                f"{type(config).__name__}. If a file path was given, check "
                "that the file exists."
            )
        return config

    def _parse_metadata(self) -> MetadataSpec:
        """Build the metadata spec from the root-level ``name`` key.

        Raises:
            MissingConfigKeyError: If ``name`` is absent.
        """
        try:
            return MetadataSpec(
                name=self.config["name"]
            )
        except KeyError as e:
            raise MissingConfigKeyError("global", {"name"}) from e

    def _parse_runtime(self) -> RuntimeSpec:
        """Build the runtime spec from ``speed`` and ``acquisition_time``.

        Both values are converted to ``float``.

        Raises:
            MissingConfigKeyError: If either key is absent.
        """
        required = {"speed", "acquisition_time"}
        missing = required - self.config.keys()
        if missing:
            raise MissingConfigKeyError("global", missing)

        return RuntimeSpec(
            speed=float(self.config["speed"]),
            acquisition_time=float(self.config["acquisition_time"]),
        )

    def _parse_options(self) -> SimulationOptionsSpec:
        """Build the options spec from the optional ``options`` section.

        ``air_density`` falls back to ``defaults.DEFAULT_AIR_DENSITY`` and
        ``seed`` to ``None``. Unknown keys only produce a warning.
        """
        # ``options:`` with no value loads as None, so ``.get(key, {})`` is
        # not enough to guarantee a mapping.
        options = self.config.get("options") or {}
        allowed = {"air_density", "seed"}

        self._warn_unknown_keys(
            section="options",
            provided=set(options.keys()),
            allowed=allowed,
        )

        return SimulationOptionsSpec(
            air_density=float(options.get(
                "air_density",
                defaults.DEFAULT_AIR_DENSITY
            )),
            seed=options.get("seed"),
        )

    def _parse_path(self) -> PathSpec:
        """Build a CSV or procedural path spec from the ``path`` section.

        A path containing ``file`` is treated as a CSV path; otherwise a path
        containing ``segments`` is treated as procedural.

        Raises:
            MissingConfigKeyError: If ``path`` is absent, or a procedural
                path lacks ``length``.
            KeyError: If a CSV path lacks ``east_col_name`` or
                ``north_col_name``.
            ValueError: If ``path`` is not a mapping or matches neither form.
        """
        allowed_csv = {"file", "east_col_name", "north_col_name", "z"}
        allowed_proc = {"segments", "length", "z"}

        if "path" not in self.config:
            raise MissingConfigKeyError("global", {"path"})

        path = self.config["path"]
        # An empty ``path:`` key (None) or a scalar cannot be inspected for
        # sub-keys; report it as an invalid path rather than a TypeError.
        if not isinstance(path, dict):
            raise ValueError("Invalid path configuration.")

        if "file" in path:
            self._warn_unknown_keys(
                section="path (csv)",
                provided=set(path.keys()),
                allowed=allowed_csv,
            )
            # NOTE(review): the procedural branch defaults ``z`` to
            # defaults.DEFAULT_PATH_HEIGHT while this one uses 0 - confirm
            # the difference is intentional.
            return CSVPathSpec(
                file=path["file"],
                east_col_name=path["east_col_name"],
                north_col_name=path["north_col_name"],
                z=path.get("z", 0),
            )

        if "segments" in path:
            self._warn_unknown_keys(
                section="path (procedural)",
                provided=set(path.keys()),
                allowed=allowed_proc,
            )
            return self._parse_procedural_path(path)

        raise ValueError("Invalid path configuration.")

    def _parse_procedural_path(
        self,
        path: Dict[str, Any]
    ) -> ProceduralPathSpec:
        """Build a procedural path spec from ``segments`` and ``length``.

        Args:
            path: The ``path`` section; must contain ``segments``.

        Raises:
            MissingConfigKeyError: If ``length`` is absent or null.
            ValueError: If ``segments`` is not a list, ``length`` is neither
                a number nor a list, the number of lengths does not fit the
                number of segments, or a segment is malformed.
        """
        raw_segments = path["segments"]
        raw_length = path.get("length")

        if raw_length is None:
            raise MissingConfigKeyError("path", {"length"})

        # A bare string would otherwise be iterated character by character.
        if not isinstance(raw_segments, list):
            raise ValueError("Invalid segment entry format.")

        if isinstance(raw_length, (int, float)):
            raw_length = [float(raw_length)]
        elif not isinstance(raw_length, list):
            raise ValueError(
                "Path length must either be a single number or a list with "
                "number of elements equal to the number of segments."
            )

        segments = self._process_segment_angles(raw_segments)
        lengths = self._process_segment_lengths(raw_length, len(segments))
        resolved_segments = self._combine_segments_lengths(segments, lengths)

        return ProceduralPathSpec(
            segments=resolved_segments,
            z=path.get("z", defaults.DEFAULT_PATH_HEIGHT),
        )

    def _process_segment_angles(
        self,
        raw_segments: List[Union[str, dict]]
    ) -> List[Dict[str, Any]]:
        """Normalise segment entries to ``{"type": ..., "angle": ...}`` dicts.

        A plain string is a segment type without an angle; a single-item
        mapping is ``{type: angle}``.

        Raises:
            ValueError: If a mapping does not have exactly one item, or an
                entry is neither a string nor a mapping.
        """
        normalized = []

        for segment in raw_segments:
            if isinstance(segment, str):
                normalized.append({"type": segment, "angle": None})
            elif isinstance(segment, dict):
                if len(segment) != 1:
                    raise ValueError("Invalid segment definition.")
                seg_type, angle = next(iter(segment.items()))
                normalized.append({"type": seg_type, "angle": angle})
            else:
                raise ValueError("Invalid segment entry format.")

        return normalized

    def _process_segment_lengths(
        self,
        raw_length_list: List[Union[int, float]],
        num_segments: int
    ) -> List[Union[int, float, None]]:
        """Return one length entry per segment.

        If there are as many lengths as segments the list is returned
        unchanged. If there is a single length it is kept as the first entry
        and the remaining entries are ``None``.

        Raises:
            ValueError: If the number of lengths is neither 1 nor
                ``num_segments``.
        """
        num_lengths = len(raw_length_list)

        if num_lengths == num_segments:
            return raw_length_list

        if num_lengths == 1:
            return raw_length_list + [None] * (num_segments - 1)

        raise ValueError(
            "Path length must either be a single number or a list with "
            "number of elements equal to the number of segments."
        )

    def _combine_segments_lengths(
        self,
        segments: List[Dict[str, Any]],
        lengths: List[float],
    ) -> List[SegmentSpec]:
        """Pair normalised segments with lengths into ``SegmentSpec`` objects.

        Raises:
            ValueError: If an angle is given for a segment that is not a
                turn.
        """
        resolved = []

        for seg, length in zip(segments, lengths):
            seg_type = seg["type"]
            angle = seg["angle"]

            if angle is not None and not self._is_turn(seg_type):
                # Built without nested same-type quotes so the module still
                # compiles on Python versions older than 3.12.
                raise ValueError(
                    f"A {seg_type} segment does not support an angle."
                )

            resolved.append(
                SegmentSpec(
                    type=seg_type,
                    length=length,
                    angle=angle,
                )
            )

        return resolved

    @staticmethod
    def _is_turn(segment_type: str) -> bool:
        """Return True if ``segment_type`` is ``turn_left`` or ``turn_right``."""
        return segment_type in {"turn_left", "turn_right"}

    def _parse_point_sources(self) -> List[SourceSpec]:
        """Build point source specs from the optional ``sources`` section.

        A list-valued ``position`` is an absolute ``[x, y, z]`` position; a
        mapping-valued ``position`` is relative to the path.

        Returns:
            One spec per configured source; empty if there are none.

        Raises:
            MissingConfigKeyError: If a source lacks ``activity_MBq``,
                ``isotope`` or ``position``.
            DimensionError: If an absolute position does not have exactly
                three elements.
            KeyError: If a relative position lacks ``along_path``,
                ``dist_from_path`` or ``side``.
            ValueError: If ``position`` is neither a list nor a mapping.
        """
        # ``sources:`` with no value loads as None; treat it as "no sources".
        source_dict = self.config.get("sources") or {}
        required = {"activity_MBq", "isotope", "position"}
        specs: List[SourceSpec] = []

        for name, params in source_dict.items():
            # A source declared without a body loads as None: every required
            # key is missing.
            if not isinstance(params, dict):
                raise MissingConfigKeyError(name, set(required))

            missing = required - params.keys()
            if missing:
                raise MissingConfigKeyError(name, missing)

            position = params["position"]

            if isinstance(position, list):
                if len(position) != 3:
                    raise DimensionError(
                        "Absolute position must be [x, y, z]."
                    )

                specs.append(
                    AbsolutePointSourceSpec(
                        name=name,
                        activity_MBq=float(params["activity_MBq"]),
                        isotope=params["isotope"],
                        x=float(position[0]),
                        y=float(position[1]),
                        z=float(position[2]),
                    )
                )

            elif isinstance(position, dict):
                specs.append(
                    RelativePointSourceSpec(
                        name=name,
                        activity_MBq=float(params["activity_MBq"]),
                        isotope=params["isotope"],
                        along_path=float(position["along_path"]),
                        dist_from_path=float(position["dist_from_path"]),
                        side=position["side"],
                        z=position.get("z", defaults.DEFAULT_SOURCE_HEIGHT)
                    )
                )

            else:
                raise ValueError(
                    f"Invalid position format for source '{name}'."
                )

        return specs

    def _warn_unknown_keys(self, section: str, provided: set, allowed: set):
        """Log a warning listing keys in ``provided`` that are not allowed.

        Args:
            section: Section name used in the warning message.
            provided: Keys found in the config section.
            allowed: Keys the parser recognises for that section.
        """
        unknown = provided - allowed
        if unknown:
            logger.warning(
                f"Unknown keys in '{section}' section: {unknown}"
            )

View File

@ -0,0 +1,76 @@
from abc import ABC
from dataclasses import dataclass
@dataclass
class MetadataSpec:
    """Descriptive metadata of a simulation config."""

    # Name of the simulation.
    name: str
@dataclass
class RuntimeSpec:
    """Runtime parameters of a simulation.

    NOTE(review): units are not stated in this module - confirm against the
    config documentation.
    """

    # Speed value from the config.
    speed: float
    # Acquisition time value from the config.
    acquisition_time: float
@dataclass
class SimulationOptionsSpec:
    """Optional simulation settings; every field has a default."""

    # Air density; defaults to 1.243.
    # NOTE(review): presumably meant to match the package-wide default air
    # density constant - confirm the two values agree.
    air_density: float = 1.243
    # Optional seed; None when no seed is configured.
    seed: int | None = None
@dataclass
class SegmentSpec:
    """A single segment of a procedurally defined path."""

    # Segment type identifier (e.g. "turn_left" or "turn_right" for turns).
    type: str
    # Length of the segment; None when no individual length was given for it
    # (a single path length is assigned to the first segment only).
    length: float | None
    # Angle for the segment, or None when no angle was specified.
    angle: float | None
@dataclass
class PathSpec(ABC):
    """Common base class for path specifications.

    Declares no fields and no abstract methods; it only gives the concrete
    path specs a shared type.
    """

    pass
@dataclass
class ProceduralPathSpec(PathSpec):
    """Path described as an ordered list of segments."""

    # Segments making up the path, in order.
    segments: list[SegmentSpec]
    # z value of the path.
    z: int | float
@dataclass
class CSVPathSpec(PathSpec):
    """Path read from a file with east and north coordinate columns."""

    # File to read the path from (presumably a CSV file path - confirm).
    file: str
    # Name of the column holding the east coordinate.
    east_col_name: str
    # Name of the column holding the north coordinate.
    north_col_name: str
    # z value of the path.
    z: int | float
@dataclass
class SourceSpec(ABC):
    """Common base class for source specifications.

    Holds the fields shared by all sources; declares no abstract methods.
    """

    # Source activity in MBq (per the field name).
    activity_MBq: float
    # Isotope identifier of the source.
    isotope: str
    # Name of the source.
    name: str
@dataclass
class AbsolutePointSourceSpec(SourceSpec):
    """Point source positioned by absolute x, y, z coordinates."""

    # x coordinate of the source.
    x: float
    # y coordinate of the source.
    y: float
    # z coordinate of the source.
    z: float
@dataclass
class RelativePointSourceSpec(SourceSpec):
    """Point source positioned relative to the path."""

    # Position along the path.
    along_path: float
    # Distance of the source from the path.
    dist_from_path: float
    # Side of the path the source is on.
    # NOTE(review): accepted values are not defined in this module - confirm.
    side: str
    # z value of the source.
    z: float
@dataclass
class SimulationSpec:
    """Complete simulation specification, grouping all section specs."""

    # Descriptive metadata (simulation name).
    metadata: MetadataSpec
    # Runtime parameters (speed, acquisition time).
    runtime: RuntimeSpec
    # Optional simulation settings.
    options: SimulationOptionsSpec
    # Path specification (any PathSpec subclass).
    path: PathSpec
    # Point sources; may be empty.
    point_sources: list[SourceSpec]

View File

@ -1,119 +0,0 @@
import logging
from typing import Any, Dict, List, Tuple, Union
import yaml
from pg_rad.exceptions.exceptions import MissingNestedKeyError
# Keys that every entry under `sources` must define; used to detect missing
# keys nested within a source.
REQUIRED_SOURCE_KEYS = {'activity_MBq', 'isotope', 'position'}
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ConfigParser:
def __init__(self, config_path: str):
try:
with open(config_path) as f:
self.config = yaml.safe_load(f)
logger.debug("YAML config file loaded correctly.")
self.path_type: str = None
self.path: Dict = {}
self.sources: Dict[str, Dict[str, Any]] = {}
except FileNotFoundError as e:
logger.critical(f"Config file not found: {e}")
raise
except yaml.YAMLError as e:
logger.critical(f"Error parsing YAML file: {e}")
raise
def parse(self) -> None:
try:
self._parse_required()
logger.debug("Global keys parsed.")
self._parse_path()
logger.debug("Path parsed.")
self._parse_point_sources()
logger.debug("Point sources parsed.")
except (MissingNestedKeyError, KeyError) as e:
logger.critical(e)
raise
def _parse_required(self) -> None:
logger.debug("Attempting to parse the required global keys.")
self.name = self.config['name']
self.speed = self.config['speed']
self.acquisition_time = self.config['acquisition_time']
self.ds = self.speed * self.acquisition_time
def _parse_path(self) -> None:
logger.debug("Attempting to parse the path.")
path = self.config['path']
if path.get('file'):
logger.debug("Experimental CSV path detected in config file.")
self.path_type = "csv"
self.path = path
elif path.get('segments'):
logger.debug("Procedural path detected in config file.")
self.path_type = "procedural"
length = path.get('length')
segments, angles = self._parse_segment_list(
path.get('segments')
)
if (
isinstance(length, list)
and len(length) != len(segments)
):
raise ValueError(
"the path.length subkey must either be a single number "
"(e.g. length: 100) representing the total length of the "
"path, or a list corresponding to each segment in "
"path.segments."
)
self.path.update({
'segments': segments,
'length': length,
'angles': angles
})
def _parse_point_sources(self) -> None:
"""Parse point sources configuration."""
logger.debug("Attempting to parse the point sources.")
source_dict = self.config.get('sources')
if source_dict:
for source, params in source_dict.items():
diff = REQUIRED_SOURCE_KEYS - set(params)
if diff:
raise MissingNestedKeyError(source, diff)
self.sources[source] = params
else:
logger.info("No point sources provided.")
@staticmethod
def _parse_segment_list(
input_segments: List[Union[str, dict]]
) -> Tuple[List[str], List[float | None]]:
"""
The segments list is a list of segments and possible angles.
"""
segments = []
angles = []
for s in input_segments:
if isinstance(s, str):
segments.append(s)
angles.append(None)
elif isinstance(s, dict):
seg_name, seg_angle = list(s.items())[0]
segments.append(seg_name)
angles.append(seg_angle)
return segments, angles