diff --git a/src/pg_rad/objects/objects.py b/src/pg_rad/objects/objects.py index 16f03b9..a610d9e 100644 --- a/src/pg_rad/objects/objects.py +++ b/src/pg_rad/objects/objects.py @@ -1,36 +1,42 @@ -import math from typing import Self +import numpy as np + class BaseObject: def __init__( self, - x: float, - y: float, - z: float, + pos: tuple[float, float, float], name: str = "Unnamed object", color: str = 'grey'): """ A generic object. Args: - x (float): X coordinate. - y (float): Y coordinate. - z (float): Z coordinate. + pos (tuple[float, float, float]): Position vector (x,y,z). name (str, optional): Name for the object. Defaults to "Unnamed object". color (str, optional): Matplotlib compatible color string. Defaults to "red". """ - self.x = x - self.y = y - self.z = z + if len(pos) != 3: + raise ValueError("Position must be tuple of length 3 (x,y,z).") + self.pos = pos self.name = name self.color = color - def distance_to(self, other: Self) -> float: - return math.dist( - (self.x, self.y, self.z), - (other.x, other.y, other.z), - ) + def distance_to(self, other: Self | tuple) -> float: + if isinstance(other, tuple) and len(other) == 3: + r = np.linalg.norm( + np.subtract(self.pos, other) + ) + else: + try: + r = np.linalg.norm( + np.subtract(self.pos, other.pos) + ) + except AttributeError as e: + raise e("other must be an object in the world \ + or a position tuple (x,y,z).") + return r diff --git a/src/pg_rad/objects/sources.py b/src/pg_rad/objects/sources.py index 1bc3a2d..ee47509 100644 --- a/src/pg_rad/objects/sources.py +++ b/src/pg_rad/objects/sources.py @@ -11,9 +11,7 @@ class PointSource(BaseObject): def __init__( self, - x: float, - y: float, - z: float, + pos: tuple, activity: int, isotope: Isotope, name: str | None = None, @@ -21,9 +19,7 @@ class PointSource(BaseObject): """A point source. Args: - x (float): X coordinate. - y (float): Y coordinate. - z (float): Z coordinate. + pos (tuple): a position vector of length 3 (x,y,z). activity (int): Activity A in MBq. isotope (Isotope): The isotope. name (str | None, optional): Can give the source a unique name. @@ -40,11 +36,10 @@ class PointSource(BaseObject): if name is None: name = f"Source {self.id}" - super().__init__(x, y, z, name, color) + super().__init__(pos, name, color) self.activity = activity self.isotope = isotope - self.color = color logger.debug(f"Source created: {self.name}")