mirror of
https://github.com/pim-n/pg-rad
synced 2026-03-23 21:58:12 +01:00
move ConfigParser to inputparser module and introduce config specs classes for structured config handling
This commit is contained in:
298
src/pg_rad/inputparser/parser.py
Normal file
298
src/pg_rad/inputparser/parser.py
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from pg_rad.exceptions.exceptions import MissingConfigKeyError, DimensionError
|
||||||
|
from pg_rad.configs import defaults
|
||||||
|
|
||||||
|
from .specs import (
|
||||||
|
MetadataSpec,
|
||||||
|
RuntimeSpec,
|
||||||
|
SimulationOptionsSpec,
|
||||||
|
SegmentSpec,
|
||||||
|
PathSpec,
|
||||||
|
ProceduralPathSpec,
|
||||||
|
CSVPathSpec,
|
||||||
|
SourceSpec,
|
||||||
|
AbsolutePointSourceSpec,
|
||||||
|
RelativePointSourceSpec,
|
||||||
|
SimulationSpec,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigParser:
|
||||||
|
_ALLOWED_ROOT_KEYS = {
|
||||||
|
"name",
|
||||||
|
"speed",
|
||||||
|
"acquisition_time",
|
||||||
|
"path",
|
||||||
|
"sources",
|
||||||
|
"options",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, config_source: str):
|
||||||
|
self.config = self._load_yaml(config_source)
|
||||||
|
|
||||||
|
def parse(self) -> SimulationSpec:
|
||||||
|
self._warn_unknown_keys(
|
||||||
|
section="global",
|
||||||
|
provided=set(self.config.keys()),
|
||||||
|
allowed=self._ALLOWED_ROOT_KEYS,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = self._parse_metadata()
|
||||||
|
runtime = self._parse_runtime()
|
||||||
|
options = self._parse_options()
|
||||||
|
path = self._parse_path()
|
||||||
|
sources = self._parse_point_sources()
|
||||||
|
|
||||||
|
return SimulationSpec(
|
||||||
|
metadata=metadata,
|
||||||
|
runtime=runtime,
|
||||||
|
options=options,
|
||||||
|
path=path,
|
||||||
|
point_sources=sources,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----------------------------------------------------------
|
||||||
|
|
||||||
|
def _load_yaml(self, config_source: str) -> Dict[str, Any]:
|
||||||
|
if os.path.exists(config_source):
|
||||||
|
with open(config_source) as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
|
return yaml.safe_load(config_source)
|
||||||
|
|
||||||
|
# ----------------------------------------------------------
|
||||||
|
|
||||||
|
def _parse_metadata(self) -> MetadataSpec:
|
||||||
|
try:
|
||||||
|
return MetadataSpec(
|
||||||
|
name=self.config["name"]
|
||||||
|
)
|
||||||
|
except KeyError as e:
|
||||||
|
raise MissingConfigKeyError("global", {"name"}) from e
|
||||||
|
|
||||||
|
# ----------------------------------------------------------
|
||||||
|
|
||||||
|
def _parse_runtime(self) -> RuntimeSpec:
|
||||||
|
required = {"speed", "acquisition_time"}
|
||||||
|
missing = required - self.config.keys()
|
||||||
|
if missing:
|
||||||
|
raise MissingConfigKeyError("global", missing)
|
||||||
|
|
||||||
|
return RuntimeSpec(
|
||||||
|
speed=float(self.config["speed"]),
|
||||||
|
acquisition_time=float(self.config["acquisition_time"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----------------------------------------------------------
|
||||||
|
|
||||||
|
def _parse_options(self) -> SimulationOptionsSpec:
|
||||||
|
options = self.config.get("options", {})
|
||||||
|
|
||||||
|
allowed = {"air_density", "seed"}
|
||||||
|
self._warn_unknown_keys(
|
||||||
|
section="options",
|
||||||
|
provided=set(options.keys()),
|
||||||
|
allowed=allowed,
|
||||||
|
)
|
||||||
|
|
||||||
|
return SimulationOptionsSpec(
|
||||||
|
air_density=float(options.get(
|
||||||
|
"air_density",
|
||||||
|
defaults.DEFAULT_AIR_DENSITY
|
||||||
|
)),
|
||||||
|
seed=options.get("seed"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----------------------------------------------------------
|
||||||
|
|
||||||
|
def _parse_path(self) -> PathSpec:
|
||||||
|
allowed_csv = {"file", "east_col_name", "north_col_name", "z"}
|
||||||
|
allowed_proc = {"segments", "length", "z"}
|
||||||
|
|
||||||
|
if "path" not in self.config:
|
||||||
|
raise MissingConfigKeyError("global", {"path"})
|
||||||
|
|
||||||
|
path = self.config["path"]
|
||||||
|
|
||||||
|
if "file" in path:
|
||||||
|
self._warn_unknown_keys(
|
||||||
|
section="path (csv)",
|
||||||
|
provided=set(path.keys()),
|
||||||
|
allowed=allowed_csv,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CSVPathSpec(
|
||||||
|
file=path["file"],
|
||||||
|
east_col_name=path["east_col_name"],
|
||||||
|
north_col_name=path["north_col_name"],
|
||||||
|
z=path.get("z", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
if "segments" in path:
|
||||||
|
self._warn_unknown_keys(
|
||||||
|
section="path (procedural)",
|
||||||
|
provided=set(path.keys()),
|
||||||
|
allowed=allowed_proc,
|
||||||
|
)
|
||||||
|
return self._parse_procedural_path(path)
|
||||||
|
|
||||||
|
raise ValueError("Invalid path configuration.")
|
||||||
|
|
||||||
|
# ----------------------------------------------------------
|
||||||
|
|
||||||
|
def _parse_procedural_path(
|
||||||
|
self,
|
||||||
|
path: Dict[str, Any]
|
||||||
|
) -> ProceduralPathSpec:
|
||||||
|
raw_segments = path["segments"]
|
||||||
|
raw_length = path.get("length")
|
||||||
|
|
||||||
|
if raw_length is None:
|
||||||
|
raise MissingConfigKeyError("path", {"length"})
|
||||||
|
|
||||||
|
if isinstance(raw_length, int | float):
|
||||||
|
raw_length = [float(raw_length)]
|
||||||
|
|
||||||
|
segments = self._process_segment_angles(raw_segments)
|
||||||
|
lengths = self._process_segment_lengths(raw_length, len(segments))
|
||||||
|
resolved_segments = self._combine_segments_lengths(segments, lengths)
|
||||||
|
|
||||||
|
return ProceduralPathSpec(
|
||||||
|
segments=resolved_segments,
|
||||||
|
z=path.get("z", defaults.DEFAULT_PATH_HEIGHT),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_segment_angles(
|
||||||
|
self,
|
||||||
|
raw_segments: List[Union[str, dict]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
|
||||||
|
normalized = []
|
||||||
|
|
||||||
|
for segment in raw_segments:
|
||||||
|
|
||||||
|
if isinstance(segment, str):
|
||||||
|
normalized.append({"type": segment, "angle": None})
|
||||||
|
|
||||||
|
elif isinstance(segment, dict):
|
||||||
|
if len(segment) != 1:
|
||||||
|
raise ValueError("Invalid segment definition.")
|
||||||
|
seg_type, angle = list(segment.items())[0]
|
||||||
|
normalized.append({"type": seg_type, "angle": angle})
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid segment entry format.")
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def _process_segment_lengths(
|
||||||
|
self,
|
||||||
|
raw_length_list: List[Union[int, float]],
|
||||||
|
num_segments: int
|
||||||
|
) -> List[float]:
|
||||||
|
num_lengths = len(raw_length_list)
|
||||||
|
|
||||||
|
if num_lengths == num_segments:
|
||||||
|
return raw_length_list
|
||||||
|
elif num_lengths == 1:
|
||||||
|
length_list = raw_length_list + [None] * (num_segments - 1)
|
||||||
|
return length_list
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Path length must either be a single number or a list with "
|
||||||
|
"number of elements equal to the number of segments."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _combine_segments_lengths(
|
||||||
|
self,
|
||||||
|
segments: List[Dict[str, Any]],
|
||||||
|
lengths: List[float],
|
||||||
|
) -> List[SegmentSpec]:
|
||||||
|
|
||||||
|
resolved = []
|
||||||
|
|
||||||
|
for seg, length in zip(segments, lengths):
|
||||||
|
angle = seg["angle"]
|
||||||
|
|
||||||
|
if angle is not None and not self._is_turn(seg["type"]):
|
||||||
|
raise ValueError(
|
||||||
|
f"A {seg["type"]} segment does not support an angle."
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved.append(
|
||||||
|
SegmentSpec(
|
||||||
|
type=seg["type"],
|
||||||
|
length=length,
|
||||||
|
angle=angle,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_turn(segment_type: str) -> bool:
|
||||||
|
return segment_type in {"turn_left", "turn_right"}
|
||||||
|
|
||||||
|
def _parse_point_sources(self) -> List[SourceSpec]:
|
||||||
|
source_dict = self.config.get("sources", {})
|
||||||
|
specs: List[SourceSpec] = []
|
||||||
|
|
||||||
|
for name, params in source_dict.items():
|
||||||
|
|
||||||
|
required = {"activity_MBq", "isotope", "position"}
|
||||||
|
missing = required - params.keys()
|
||||||
|
if missing:
|
||||||
|
raise MissingConfigKeyError(name, missing)
|
||||||
|
|
||||||
|
position = params["position"]
|
||||||
|
|
||||||
|
if isinstance(position, list):
|
||||||
|
if len(position) != 3:
|
||||||
|
raise DimensionError(
|
||||||
|
"Absolute position must be [x, y, z]."
|
||||||
|
)
|
||||||
|
|
||||||
|
specs.append(
|
||||||
|
AbsolutePointSourceSpec(
|
||||||
|
name=name,
|
||||||
|
activity_MBq=float(params["activity_MBq"]),
|
||||||
|
isotope=params["isotope"],
|
||||||
|
x=float(position[0]),
|
||||||
|
y=float(position[1]),
|
||||||
|
z=float(position[2]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(position, dict):
|
||||||
|
specs.append(
|
||||||
|
RelativePointSourceSpec(
|
||||||
|
name=name,
|
||||||
|
activity_MBq=float(params["activity_MBq"]),
|
||||||
|
isotope=params["isotope"],
|
||||||
|
along_path=float(position["along_path"]),
|
||||||
|
dist_from_path=float(position["dist_from_path"]),
|
||||||
|
side=position["side"],
|
||||||
|
z=position.get("z", defaults.DEFAULT_SOURCE_HEIGHT)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid position format for source '{name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
return specs
|
||||||
|
|
||||||
|
def _warn_unknown_keys(self, section: str, provided: set, allowed: set):
|
||||||
|
unknown = provided - allowed
|
||||||
|
if unknown:
|
||||||
|
logger.warning(
|
||||||
|
f"Unknown keys in '{section}' section: {unknown}"
|
||||||
|
)
|
||||||
76
src/pg_rad/inputparser/specs.py
Normal file
76
src/pg_rad/inputparser/specs.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from abc import ABC
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetadataSpec:
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RuntimeSpec:
|
||||||
|
speed: float
|
||||||
|
acquisition_time: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SimulationOptionsSpec:
|
||||||
|
air_density: float = 1.243
|
||||||
|
seed: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SegmentSpec:
|
||||||
|
type: str
|
||||||
|
length: float
|
||||||
|
angle: float | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PathSpec(ABC):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProceduralPathSpec(PathSpec):
|
||||||
|
segments: list[SegmentSpec]
|
||||||
|
z: int | float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CSVPathSpec(PathSpec):
|
||||||
|
file: str
|
||||||
|
east_col_name: str
|
||||||
|
north_col_name: str
|
||||||
|
z: int | float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SourceSpec(ABC):
|
||||||
|
activity_MBq: float
|
||||||
|
isotope: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AbsolutePointSourceSpec(SourceSpec):
|
||||||
|
x: float
|
||||||
|
y: float
|
||||||
|
z: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RelativePointSourceSpec(SourceSpec):
|
||||||
|
along_path: float
|
||||||
|
dist_from_path: float
|
||||||
|
side: str
|
||||||
|
z: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SimulationSpec:
|
||||||
|
metadata: MetadataSpec
|
||||||
|
runtime: RuntimeSpec
|
||||||
|
options: SimulationOptionsSpec
|
||||||
|
path: PathSpec
|
||||||
|
point_sources: list[SourceSpec]
|
||||||
@ -1,119 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from pg_rad.exceptions.exceptions import MissingNestedKeyError
|
|
||||||
|
|
||||||
|
|
||||||
# used to check required keys that are nested within source
|
|
||||||
REQUIRED_SOURCE_KEYS = {'activity_MBq', 'isotope', 'position'}
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigParser:
|
|
||||||
def __init__(self, config_path: str):
|
|
||||||
try:
|
|
||||||
with open(config_path) as f:
|
|
||||||
self.config = yaml.safe_load(f)
|
|
||||||
logger.debug("YAML config file loaded correctly.")
|
|
||||||
self.path_type: str = None
|
|
||||||
self.path: Dict = {}
|
|
||||||
self.sources: Dict[str, Dict[str, Any]] = {}
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
logger.critical(f"Config file not found: {e}")
|
|
||||||
raise
|
|
||||||
except yaml.YAMLError as e:
|
|
||||||
logger.critical(f"Error parsing YAML file: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def parse(self) -> None:
|
|
||||||
try:
|
|
||||||
self._parse_required()
|
|
||||||
logger.debug("Global keys parsed.")
|
|
||||||
self._parse_path()
|
|
||||||
logger.debug("Path parsed.")
|
|
||||||
self._parse_point_sources()
|
|
||||||
logger.debug("Point sources parsed.")
|
|
||||||
except (MissingNestedKeyError, KeyError) as e:
|
|
||||||
logger.critical(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _parse_required(self) -> None:
|
|
||||||
logger.debug("Attempting to parse the required global keys.")
|
|
||||||
self.name = self.config['name']
|
|
||||||
self.speed = self.config['speed']
|
|
||||||
self.acquisition_time = self.config['acquisition_time']
|
|
||||||
self.ds = self.speed * self.acquisition_time
|
|
||||||
|
|
||||||
def _parse_path(self) -> None:
|
|
||||||
logger.debug("Attempting to parse the path.")
|
|
||||||
path = self.config['path']
|
|
||||||
|
|
||||||
if path.get('file'):
|
|
||||||
logger.debug("Experimental CSV path detected in config file.")
|
|
||||||
self.path_type = "csv"
|
|
||||||
self.path = path
|
|
||||||
elif path.get('segments'):
|
|
||||||
logger.debug("Procedural path detected in config file.")
|
|
||||||
self.path_type = "procedural"
|
|
||||||
|
|
||||||
length = path.get('length')
|
|
||||||
segments, angles = self._parse_segment_list(
|
|
||||||
path.get('segments')
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
isinstance(length, list)
|
|
||||||
and len(length) != len(segments)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"the path.length subkey must either be a single number "
|
|
||||||
"(e.g. length: 100) representing the total length of the "
|
|
||||||
"path, or a list corresponding to each segment in "
|
|
||||||
"path.segments."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.path.update({
|
|
||||||
'segments': segments,
|
|
||||||
'length': length,
|
|
||||||
'angles': angles
|
|
||||||
})
|
|
||||||
|
|
||||||
def _parse_point_sources(self) -> None:
|
|
||||||
"""Parse point sources configuration."""
|
|
||||||
logger.debug("Attempting to parse the point sources.")
|
|
||||||
source_dict = self.config.get('sources')
|
|
||||||
if source_dict:
|
|
||||||
for source, params in source_dict.items():
|
|
||||||
diff = REQUIRED_SOURCE_KEYS - set(params)
|
|
||||||
if diff:
|
|
||||||
raise MissingNestedKeyError(source, diff)
|
|
||||||
|
|
||||||
self.sources[source] = params
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.info("No point sources provided.")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_segment_list(
|
|
||||||
input_segments: List[Union[str, dict]]
|
|
||||||
) -> Tuple[List[str], List[float | None]]:
|
|
||||||
"""
|
|
||||||
The segments list is a list of segments and possible angles.
|
|
||||||
"""
|
|
||||||
segments = []
|
|
||||||
angles = []
|
|
||||||
|
|
||||||
for s in input_segments:
|
|
||||||
if isinstance(s, str):
|
|
||||||
segments.append(s)
|
|
||||||
angles.append(None)
|
|
||||||
elif isinstance(s, dict):
|
|
||||||
seg_name, seg_angle = list(s.items())[0]
|
|
||||||
segments.append(seg_name)
|
|
||||||
angles.append(seg_angle)
|
|
||||||
|
|
||||||
return segments, angles
|
|
||||||
Reference in New Issue
Block a user