class ConfigParser:
    """Translate a YAML simulation config into a ``SimulationSpec``.

    The config is loaded once at construction time. ``parse()`` then
    validates each section and builds the corresponding spec objects.
    Unknown keys only produce a warning; missing required keys raise
    ``MissingConfigKeyError``.
    """

    _ALLOWED_ROOT_KEYS = {
        "name",
        "speed",
        "acquisition_time",
        "path",
        "sources",
        "options",
    }

    # Shared by every check on the shape of ``path.length``.
    _LENGTH_ERROR = (
        "Path length must either be a single number or a list with "
        "number of elements equal to the number of segments."
    )

    def __init__(self, config_source: str):
        """Load the configuration.

        Args:
            config_source: Path to a YAML file, or a YAML document given
                directly as a string.

        Raises:
            ValueError: If the loaded document is not a mapping.
        """
        self.config = self._load_yaml(config_source)

    def parse(self) -> SimulationSpec:
        """Validate all sections and return the assembled specification.

        Returns:
            A ``SimulationSpec`` holding metadata, runtime, options, path
            and point-source specs.

        Raises:
            MissingConfigKeyError: If a required key is absent.
            DimensionError: If an absolute source position is not 3-D.
            ValueError: If the path or a source position is malformed.
        """
        self._warn_unknown_keys(
            section="global",
            provided=set(self.config.keys()),
            allowed=self._ALLOWED_ROOT_KEYS,
        )

        # Keyword arguments are evaluated top to bottom, so sections are
        # still validated in the order metadata -> sources.
        return SimulationSpec(
            metadata=self._parse_metadata(),
            runtime=self._parse_runtime(),
            options=self._parse_options(),
            path=self._parse_path(),
            point_sources=self._parse_point_sources(),
        )

    # ----------------------------------------------------------

    def _load_yaml(self, config_source: str) -> Dict[str, Any]:
        """Load YAML from a file path or from an inline YAML string.

        Args:
            config_source: Existing file path, or YAML text.

        Returns:
            The top-level mapping of the YAML document.

        Raises:
            ValueError: If the document is not a mapping. This is also
                what happens when a non-existent file path is passed,
                because the path is then parsed as a YAML scalar.
        """
        if os.path.exists(config_source):
            with open(config_source, encoding="utf-8") as f:
                loaded = yaml.safe_load(f)
        else:
            loaded = yaml.safe_load(config_source)

        # Fail here with a clear message instead of an AttributeError on
        # ``.keys()`` somewhere in parse().
        if not isinstance(loaded, dict):
            raise ValueError(
                "Config must be a YAML mapping at the top level, got "
                f"{type(loaded).__name__}. If a file path was given, "
                "check that the file exists and is not empty."
            )
        return loaded

    # ----------------------------------------------------------

    def _parse_metadata(self) -> MetadataSpec:
        """Build the metadata spec from the required ``name`` key."""
        try:
            return MetadataSpec(
                name=self.config["name"]
            )
        except KeyError as e:
            raise MissingConfigKeyError("global", {"name"}) from e

    # ----------------------------------------------------------

    def _parse_runtime(self) -> RuntimeSpec:
        """Build the runtime spec from ``speed`` and ``acquisition_time``.

        Raises:
            MissingConfigKeyError: If either key is absent.
        """
        required = {"speed", "acquisition_time"}
        missing = required - self.config.keys()
        if missing:
            raise MissingConfigKeyError("global", missing)

        return RuntimeSpec(
            speed=float(self.config["speed"]),
            acquisition_time=float(self.config["acquisition_time"]),
        )

    # ----------------------------------------------------------

    def _parse_options(self) -> SimulationOptionsSpec:
        """Build the options spec, falling back to package defaults."""
        # ``options:`` with no body is loaded as None; treat it as empty.
        options = self.config.get("options") or {}

        allowed = {"air_density", "seed"}
        self._warn_unknown_keys(
            section="options",
            provided=set(options.keys()),
            allowed=allowed,
        )

        return SimulationOptionsSpec(
            air_density=float(options.get(
                "air_density",
                defaults.DEFAULT_AIR_DENSITY
            )),
            seed=options.get("seed"),
        )

    # ----------------------------------------------------------

    def _parse_path(self) -> PathSpec:
        """Build a CSV or procedural path spec from the ``path`` section.

        A ``file`` key selects the CSV variant, otherwise a ``segments``
        key selects the procedural variant.

        Raises:
            MissingConfigKeyError: If ``path`` or a required sub-key is
                absent.
            ValueError: If the section matches neither variant.
        """
        allowed_csv = {"file", "east_col_name", "north_col_name", "z"}
        allowed_proc = {"segments", "length", "z"}

        if "path" not in self.config:
            raise MissingConfigKeyError("global", {"path"})

        path = self.config["path"]
        if not isinstance(path, dict):
            raise ValueError("Invalid path configuration.")

        if "file" in path:
            self._warn_unknown_keys(
                section="path (csv)",
                provided=set(path.keys()),
                allowed=allowed_csv,
            )

            # Report missing column names the same way as every other
            # required key, rather than leaking a bare KeyError.
            missing = {"east_col_name", "north_col_name"} - path.keys()
            if missing:
                raise MissingConfigKeyError("path", missing)

            return CSVPathSpec(
                file=path["file"],
                east_col_name=path["east_col_name"],
                north_col_name=path["north_col_name"],
                # NOTE(review): the procedural branch defaults to
                # defaults.DEFAULT_PATH_HEIGHT; confirm whether 0 is
                # intentional here.
                z=path.get("z", 0),
            )

        if "segments" in path:
            self._warn_unknown_keys(
                section="path (procedural)",
                provided=set(path.keys()),
                allowed=allowed_proc,
            )
            return self._parse_procedural_path(path)

        raise ValueError("Invalid path configuration.")

    # ----------------------------------------------------------

    def _parse_procedural_path(
        self,
        path: Dict[str, Any]
    ) -> ProceduralPathSpec:
        """Resolve ``segments`` and ``length`` into segment specs.

        Args:
            path: The ``path`` section; must contain ``segments``.

        Raises:
            MissingConfigKeyError: If ``length`` is absent.
            ValueError: If segments or lengths are malformed.
        """
        raw_segments = path["segments"]
        raw_length = path.get("length")

        if raw_length is None:
            raise MissingConfigKeyError("path", {"length"})

        if not isinstance(raw_segments, (list, tuple)):
            raise ValueError("Invalid segment entry format.")

        if isinstance(raw_length, (int, float)):
            raw_length = [float(raw_length)]
        elif isinstance(raw_length, (list, tuple)):
            raw_length = list(raw_length)
        else:
            # e.g. a quoted string, whose len() would count characters.
            raise ValueError(self._LENGTH_ERROR)

        segments = self._process_segment_angles(raw_segments)
        lengths = self._process_segment_lengths(raw_length, len(segments))
        resolved_segments = self._combine_segments_lengths(segments, lengths)

        return ProceduralPathSpec(
            segments=resolved_segments,
            z=path.get("z", defaults.DEFAULT_PATH_HEIGHT),
        )

    def _process_segment_angles(
        self,
        raw_segments: List[Union[str, dict]]
    ) -> List[Dict[str, Any]]:
        """Normalise segment entries to ``{"type": ..., "angle": ...}``.

        A plain string is a segment without an angle. A single-item
        mapping ``{type: angle}`` is a segment with an angle.

        Raises:
            ValueError: If an entry is neither a string nor a
                single-item mapping.
        """
        normalized = []

        for segment in raw_segments:
            if isinstance(segment, str):
                normalized.append({"type": segment, "angle": None})
            elif isinstance(segment, dict):
                if len(segment) != 1:
                    raise ValueError("Invalid segment definition.")
                seg_type, angle = next(iter(segment.items()))
                normalized.append({"type": seg_type, "angle": angle})
            else:
                raise ValueError("Invalid segment entry format.")

        return normalized

    def _process_segment_lengths(
        self,
        raw_length_list: List[Union[int, float]],
        num_segments: int
    ) -> List[Union[float, None]]:
        """Return one length slot per segment.

        Args:
            raw_length_list: Either one length per segment or a single
                value.
            num_segments: Number of segments in the path.

        Returns:
            A new list of floats. When a single value is given for
            several segments it fills the first slot and the remaining
            slots are ``None``; presumably a later stage distributes the
            total (TODO confirm against the path builder).

        Raises:
            ValueError: If the number of lengths is neither 1 nor
                ``num_segments``.
        """
        def as_float(value):
            # Keep explicit nulls untouched, as before.
            return None if value is None else float(value)

        num_lengths = len(raw_length_list)

        if num_lengths == num_segments:
            return [as_float(value) for value in raw_length_list]
        if num_lengths == 1:
            padding = [None] * (num_segments - 1)
            return [as_float(raw_length_list[0])] + padding
        raise ValueError(self._LENGTH_ERROR)

    def _combine_segments_lengths(
        self,
        segments: List[Dict[str, Any]],
        lengths: List[float],
    ) -> List[SegmentSpec]:
        """Pair normalised segments with lengths into ``SegmentSpec``s.

        Raises:
            ValueError: If an angle is given for a segment type that is
                not a turn.
        """
        resolved = []

        for seg, length in zip(segments, lengths):
            seg_type = seg["type"]
            angle = seg["angle"]

            if angle is not None and not self._is_turn(seg_type):
                # A local name keeps the f-string valid before Python
                # 3.12, where reusing the enclosing quote inside a
                # replacement field is a syntax error.
                raise ValueError(
                    f"A {seg_type} segment does not support an angle."
                )

            resolved.append(
                SegmentSpec(
                    type=seg_type,
                    length=length,
                    angle=angle,
                )
            )

        return resolved

    @staticmethod
    def _is_turn(segment_type: str) -> bool:
        """Return True for the segment types that accept an angle."""
        return segment_type in {"turn_left", "turn_right"}

    def _parse_point_sources(self) -> List[SourceSpec]:
        """Build one spec per entry in the optional ``sources`` section.

        A list position ``[x, y, z]`` is absolute. A mapping with
        ``along_path``, ``dist_from_path`` and ``side`` is relative to
        the path.

        Raises:
            MissingConfigKeyError: If a source or a relative position is
                missing a required key.
            DimensionError: If an absolute position does not have three
                components.
            ValueError: If the position is neither a list nor a mapping.
        """
        # ``sources:`` with no body is loaded as None; treat it as empty.
        source_dict = self.config.get("sources") or {}
        specs: List[SourceSpec] = []

        for name, params in source_dict.items():

            required = {"activity_MBq", "isotope", "position"}
            # ``set(params or {})`` also covers a source declared
            # without any body.
            missing = required - set(params or {})
            if missing:
                raise MissingConfigKeyError(name, missing)

            position = params["position"]

            if isinstance(position, list):
                if len(position) != 3:
                    raise DimensionError(
                        "Absolute position must be [x, y, z]."
                    )

                specs.append(
                    AbsolutePointSourceSpec(
                        name=name,
                        activity_MBq=float(params["activity_MBq"]),
                        isotope=params["isotope"],
                        x=float(position[0]),
                        y=float(position[1]),
                        z=float(position[2]),
                    )
                )

            elif isinstance(position, dict):
                required_pos = {"along_path", "dist_from_path", "side"}
                missing_pos = required_pos - position.keys()
                if missing_pos:
                    raise MissingConfigKeyError(
                        f"{name}.position", missing_pos
                    )

                specs.append(
                    RelativePointSourceSpec(
                        name=name,
                        activity_MBq=float(params["activity_MBq"]),
                        isotope=params["isotope"],
                        along_path=float(position["along_path"]),
                        dist_from_path=float(position["dist_from_path"]),
                        side=position["side"],
                        # Coerced like the absolute z above.
                        z=float(position.get(
                            "z",
                            defaults.DEFAULT_SOURCE_HEIGHT
                        )),
                    )
                )

            else:
                raise ValueError(
                    f"Invalid position format for source '{name}'."
                )

        return specs

    def _warn_unknown_keys(
        self, section: str, provided: set, allowed: set
    ) -> None:
        """Log a warning listing keys in *provided* not in *allowed*."""
        unknown = provided - allowed
        if unknown:
            # Lazy %-formatting; the rendered text is unchanged.
            logger.warning(
                "Unknown keys in '%s' section: %s", section, unknown
            )
class ConfigParser:
    """Legacy YAML config reader that stores results on the instance.

    After ``parse()`` the instance exposes ``name``, ``speed``,
    ``acquisition_time``, ``ds``, ``path_type``, ``path`` and
    ``sources``.
    """

    def __init__(self, config_path: str):
        """Load the YAML file at *config_path*.

        Raises:
            FileNotFoundError: If the file does not exist (logged first).
            yaml.YAMLError: If the file is not valid YAML (logged first).
        """
        try:
            with open(config_path) as handle:
                loaded = yaml.safe_load(handle)
        except FileNotFoundError as e:
            logger.critical(f"Config file not found: {e}")
            raise
        except yaml.YAMLError as e:
            logger.critical(f"Error parsing YAML file: {e}")
            raise

        self.config = loaded
        logger.debug("YAML config file loaded correctly.")
        self.path_type: str = None
        self.path: Dict = {}
        self.sources: Dict[str, Dict[str, Any]] = {}

    def parse(self) -> None:
        """Run all parsing stages in order, logging after each one.

        Raises:
            MissingNestedKeyError: If a source lacks a required key.
            KeyError: If a required global key is absent.
        """
        stages = (
            (self._parse_required, "Global keys parsed."),
            (self._parse_path, "Path parsed."),
            (self._parse_point_sources, "Point sources parsed."),
        )
        try:
            for run_stage, done_message in stages:
                run_stage()
                logger.debug(done_message)
        except (MissingNestedKeyError, KeyError) as e:
            logger.critical(e)
            raise

    def _parse_required(self) -> None:
        """Copy the required global keys onto the instance.

        Also derives ``ds`` as ``speed * acquisition_time``.
        """
        logger.debug("Attempting to parse the required global keys.")
        cfg = self.config
        self.name = cfg['name']
        self.speed = cfg['speed']
        self.acquisition_time = cfg['acquisition_time']
        self.ds = self.speed * self.acquisition_time

    def _parse_path(self) -> None:
        """Detect the path variant and store it in ``self.path``.

        A truthy ``file`` key selects the CSV variant, otherwise a truthy
        ``segments`` key selects the procedural variant. If neither is
        present nothing is stored.

        Raises:
            ValueError: If ``length`` is a list whose size differs from
                the number of segments.
        """
        logger.debug("Attempting to parse the path.")
        path_cfg = self.config['path']

        if path_cfg.get('file'):
            logger.debug("Experimental CSV path detected in config file.")
            self.path_type = "csv"
            self.path = path_cfg
            return

        if not path_cfg.get('segments'):
            return

        logger.debug("Procedural path detected in config file.")
        self.path_type = "procedural"

        path_length = path_cfg.get('length')
        seg_names, seg_angles = self._parse_segment_list(
            path_cfg.get('segments')
        )

        if isinstance(path_length, list):
            if len(path_length) != len(seg_names):
                raise ValueError(
                    "the path.length subkey must either be a single number "
                    "(e.g. length: 100) representing the total length of the "
                    "path, or a list corresponding to each segment in "
                    "path.segments."
                )

        self.path.update({
            'segments': seg_names,
            'length': path_length,
            'angles': seg_angles
        })

    def _parse_point_sources(self) -> None:
        """Validate and store every entry of the ``sources`` section."""
        logger.debug("Attempting to parse the point sources.")
        configured = self.config.get('sources')

        if not configured:
            logger.info("No point sources provided.")
            return

        for source_name, source_params in configured.items():
            absent = REQUIRED_SOURCE_KEYS - set(source_params)
            if absent:
                raise MissingNestedKeyError(source_name, absent)
            self.sources[source_name] = source_params

    @staticmethod
    def _parse_segment_list(
        input_segments: List[Union[str, dict]]
    ) -> Tuple[List[str], List[float | None]]:
        """Split segment entries into parallel name and angle lists.

        A string entry is a segment without an angle. For a mapping
        entry the first ``{name: angle}`` item is used. Entries of any
        other type are skipped.
        """
        names = []
        turn_angles = []

        for entry in input_segments:
            if isinstance(entry, str):
                seg_name, seg_angle = entry, None
            elif isinstance(entry, dict):
                seg_name, seg_angle = tuple(entry.items())[0]
            else:
                continue
            names.append(seg_name)
            turn_angles.append(seg_angle)

        return names, turn_angles