mirror of
https://github.com/pim-n/pg-rad
synced 2026-03-11 19:58:11 +01:00
Improve PEP8 adherance using flake8 linter
This commit is contained in:
@ -6,6 +6,7 @@ from pg_rad.exceptions import DataLoadError, InvalidCSVError
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_data(filename: str) -> pd.DataFrame:
|
def load_data(filename: str) -> pd.DataFrame:
|
||||||
logger.debug(f"Attempting to load file: {filename}")
|
logger.debug(f"Attempting to load file: {filename}")
|
||||||
|
|
||||||
@ -25,4 +26,4 @@ def load_data(filename: str) -> pd.DataFrame:
|
|||||||
raise DataLoadError("Unexpected error while loading data") from e
|
raise DataLoadError("Unexpected error while loading data") from e
|
||||||
|
|
||||||
logger.debug(f"File loaded: {filename}")
|
logger.debug(f"File loaded: {filename}")
|
||||||
return df
|
return df
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
class ConvergenceError(Exception):
|
class ConvergenceError(Exception):
|
||||||
"""Raised when an algorithm fails to converge."""
|
"""Raised when an algorithm fails to converge."""
|
||||||
|
|
||||||
|
|
||||||
class DataLoadError(Exception):
|
class DataLoadError(Exception):
|
||||||
"""Base class for data loading errors."""
|
"""Base class for data loading errors."""
|
||||||
|
|
||||||
|
|
||||||
class InvalidCSVError(DataLoadError):
|
class InvalidCSVError(DataLoadError):
|
||||||
"""Raised when a file is not a valid CSV."""
|
"""Raised when a file is not a valid CSV."""
|
||||||
|
|||||||
@ -5,7 +5,7 @@ class Isotope:
|
|||||||
name (str): Full name (e.g. Caesium-137).
|
name (str): Full name (e.g. Caesium-137).
|
||||||
E (float): Energy of the primary gamma in keV.
|
E (float): Energy of the primary gamma in keV.
|
||||||
b (float): Branching ratio for the gamma at energy E.
|
b (float): Branching ratio for the gamma at energy E.
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
@ -20,4 +20,4 @@ class Isotope:
|
|||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.E = E
|
self.E = E
|
||||||
self.b = b
|
self.b = b
|
||||||
|
|||||||
@ -9,26 +9,29 @@ from pg_rad.objects import PointSource
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Landscape:
|
class Landscape:
|
||||||
"""A generic Landscape that can contain a Path and sources.
|
"""A generic Landscape that can contain a Path and sources.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
air_density (float, optional): Air density in kg / m^3. Defaults to 1.243.
|
air_density (float, optional): Air density, kg/m^3. Defaults to 1.243.
|
||||||
size (int | tuple[int, int, int], optional): Size of the world. Defaults to 500.
|
size (int | tuple[int, int, int], optional): Size of the world.
|
||||||
scale (str, optional): The scale of the size argument passed. Defaults to 'meters'.
|
Defaults to 500.
|
||||||
"""
|
scale (str, optional): The scale of the size argument passed.
|
||||||
|
Defaults to 'meters'.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
air_density: float = 1.243,
|
air_density: float = 1.243,
|
||||||
size: int | tuple[int, int, int] = 500,
|
size: int | tuple[int, int, int] = 500,
|
||||||
scale = 'meters',
|
scale: str = 'meters'
|
||||||
):
|
):
|
||||||
if isinstance(size, int):
|
if isinstance(size, int):
|
||||||
self.world = np.zeros((size, size, size))
|
self.world = np.zeros((size, size, size))
|
||||||
elif isinstance(size, tuple) and len(size) == 3:
|
elif isinstance(size, tuple) and len(size) == 3:
|
||||||
self.world = np.zeros(size)
|
self.world = np.zeros(size)
|
||||||
else:
|
else:
|
||||||
raise TypeError("size must be an integer or a tuple of 3 integers.")
|
raise TypeError("size must be integer or a tuple of 3 integers.")
|
||||||
|
|
||||||
self.air_density = air_density
|
self.air_density = air_density
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -36,8 +39,8 @@ class Landscape:
|
|||||||
self.path: Path = None
|
self.path: Path = None
|
||||||
self.sources: list[PointSource] = []
|
self.sources: list[PointSource] = []
|
||||||
logger.debug("Landscape initialized.")
|
logger.debug("Landscape initialized.")
|
||||||
|
|
||||||
def plot(self, z = 0):
|
def plot(self, z: float | int = 0):
|
||||||
"""Plot a slice of the world at a height `z`.
|
"""Plot a slice of the world at a height `z`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -45,7 +48,7 @@ class Landscape:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
fig, ax: Matplotlib figure objects.
|
fig, ax: Matplotlib figure objects.
|
||||||
"""
|
"""
|
||||||
x_lim, y_lim, _ = self.world.shape
|
x_lim, y_lim, _ = self.world.shape
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
@ -54,9 +57,9 @@ class Landscape:
|
|||||||
ax.set_xlabel(f"X [{self.scale}]")
|
ax.set_xlabel(f"X [{self.scale}]")
|
||||||
ax.set_ylabel(f"Y [{self.scale}]")
|
ax.set_ylabel(f"Y [{self.scale}]")
|
||||||
|
|
||||||
if not self.path == None:
|
if self.path is not None:
|
||||||
ax.plot(self.path.x_list, self.path.y_list, 'bo-')
|
ax.plot(self.path.x_list, self.path.y_list, 'bo-')
|
||||||
|
|
||||||
for s in self.sources:
|
for s in self.sources:
|
||||||
if np.isclose(s.z, z):
|
if np.isclose(s.z, z):
|
||||||
dot = Circle(
|
dot = Circle(
|
||||||
@ -78,19 +81,19 @@ class Landscape:
|
|||||||
)
|
)
|
||||||
|
|
||||||
ax.add_patch(dot)
|
ax.add_patch(dot)
|
||||||
|
|
||||||
return fig, ax
|
return fig, ax
|
||||||
|
|
||||||
def add_sources(self, *sources: PointSource):
|
def add_sources(self, *sources: PointSource):
|
||||||
"""Add one or more point sources to the world.
|
"""Add one or more point sources to the world.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*sources (pg_rad.sources.PointSource): One or more sources, passed as
|
*sources (pg_rad.sources.PointSource): One or more sources,
|
||||||
Source1, Source2, ...
|
passed as Source1, Source2, ...
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the source is outside the boundaries of the landscape.
|
ValueError: If the source is outside the boundaries of the
|
||||||
"""
|
landscape.
|
||||||
|
"""
|
||||||
max_x, max_y, max_z = self.world.shape[:3]
|
max_x, max_y, max_z = self.world.shape[:3]
|
||||||
|
|
||||||
if any(
|
if any(
|
||||||
@ -108,24 +111,23 @@ class Landscape:
|
|||||||
Set the path in the landscape.
|
Set the path in the landscape.
|
||||||
"""
|
"""
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
def create_landscape_from_path(path: Path, max_z = 500):
|
|
||||||
|
def create_landscape_from_path(path: Path, max_z: float | int = 50):
|
||||||
"""Generate a landscape from a path, using its dimensions to determine
|
"""Generate a landscape from a path, using its dimensions to determine
|
||||||
the size of the landscape.
|
the size of the landscape.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (Path): A Path object describing the trajectory.
|
path (Path): A Path object describing the trajectory.
|
||||||
max_z (int, optional): Height of the world. Defaults to 500 meters.
|
max_z (int, optional): Height of the world. Defaults to 50 meters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
landscape (pg_rad.landscape.Landscape): A landscape with dimensions based on the provided Path.
|
landscape (pg_rad.landscape.Landscape): A landscape with dimensions
|
||||||
"""
|
based on the provided Path.
|
||||||
|
"""
|
||||||
max_x = np.ceil(max(path.x_list))
|
max_x = np.ceil(max(path.x_list))
|
||||||
max_y = np.ceil(max(path.y_list))
|
max_y = np.ceil(max(path.y_list))
|
||||||
|
|
||||||
landscape = Landscape(
|
landscape = Landscape(size=(max_x, max_y, max_z))
|
||||||
size = (max_x, max_y, max_z)
|
|
||||||
)
|
|
||||||
|
|
||||||
landscape.path = path
|
landscape.path = path
|
||||||
return landscape
|
return landscape
|
||||||
|
|||||||
@ -3,18 +3,19 @@ import pathlib
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(log_level: str = "WARNING"):
|
def setup_logger(log_level: str = "WARNING"):
|
||||||
levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
if not log_level in levels:
|
if log_level not in levels:
|
||||||
raise ValueError(f"Log level must be one of {levels}.")
|
raise ValueError(f"Log level must be one of {levels}.")
|
||||||
|
|
||||||
base_dir = pathlib.Path(__file__).resolve().parent
|
base_dir = pathlib.Path(__file__).resolve().parent
|
||||||
config_file = base_dir / "configs" / "logging.yml"
|
config_file = base_dir / "configs" / "logging.yml"
|
||||||
|
|
||||||
with open(config_file) as f:
|
with open(config_file) as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
config["loggers"]["root"]["level"] = log_level
|
config["loggers"]["root"]["level"] = log_level
|
||||||
|
|
||||||
logging.config.dictConfig(config)
|
logging.config.dictConfig(config)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
|
|
||||||
class BaseObject:
|
class BaseObject:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -16,18 +17,20 @@ class BaseObject:
|
|||||||
x (float): X coordinate.
|
x (float): X coordinate.
|
||||||
y (float): Y coordinate.
|
y (float): Y coordinate.
|
||||||
z (float): Z coordinate.
|
z (float): Z coordinate.
|
||||||
name (str, optional): Name for the object. Defaults to "Unnamed object".
|
name (str, optional): Name for the object.
|
||||||
color (str, optional): Matplotlib compatible color string. Defaults to "red".
|
Defaults to "Unnamed object".
|
||||||
|
color (str, optional): Matplotlib compatible color string.
|
||||||
|
Defaults to "red".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.x = x
|
self.x = x
|
||||||
self.y = y
|
self.y = y
|
||||||
self.z = z
|
self.z = z
|
||||||
self.name = name
|
self.name = name
|
||||||
self.color = color
|
self.color = color
|
||||||
|
|
||||||
def distance_to(self, other: Self) -> float:
|
def distance_to(self, other: Self) -> float:
|
||||||
return math.dist(
|
return math.dist(
|
||||||
(self.x, self.y, self.z),
|
(self.x, self.y, self.z),
|
||||||
(other.x, other.y, other.z),
|
(other.x, other.y, other.z),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -5,8 +5,10 @@ from pg_rad.isotopes import Isotope
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PointSource(BaseObject):
|
class PointSource(BaseObject):
|
||||||
_id_counter = 1
|
_id_counter = 1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
x: float,
|
x: float,
|
||||||
@ -27,9 +29,10 @@ class PointSource(BaseObject):
|
|||||||
name (str | None, optional): Can give the source a unique name.
|
name (str | None, optional): Can give the source a unique name.
|
||||||
Defaults to None, making the name sequential.
|
Defaults to None, making the name sequential.
|
||||||
(Source-1, Source-2, etc.).
|
(Source-1, Source-2, etc.).
|
||||||
color (str, optional): Matplotlib compatible color string. Defaults to "red".
|
color (str, optional): Matplotlib compatible color string.
|
||||||
|
Defaults to "red".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.id = PointSource._id_counter
|
self.id = PointSource._id_counter
|
||||||
PointSource._id_counter += 1
|
PointSource._id_counter += 1
|
||||||
|
|
||||||
@ -42,8 +45,13 @@ class PointSource(BaseObject):
|
|||||||
self.activity = activity
|
self.activity = activity
|
||||||
self.isotope = isotope
|
self.isotope = isotope
|
||||||
self.color = color
|
self.color = color
|
||||||
|
|
||||||
logger.debug(f"Source created: {self.name}")
|
logger.debug(f"Source created: {self.name}")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"PointSource(name={self.name}, pos={(self.x, self.y, self.z)}, isotope={self.isotope.name}, A={self.activity} MBq)"
|
repr_str = (f"PointSource(name={self.name}, "
|
||||||
|
+ f"pos={(self.x, self.y, self.z)}, "
|
||||||
|
+ f"A={self.activity} MBq), "
|
||||||
|
+ f"isotope={self.isotope.name}.")
|
||||||
|
|
||||||
|
return repr_str
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from pg_rad.exceptions import ConvergenceError
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PathSegment:
|
class PathSegment:
|
||||||
def __init__(self, a: tuple[float, float], b: tuple[float, float]):
|
def __init__(self, a: tuple[float, float], b: tuple[float, float]):
|
||||||
"""A straight Segment of a Path, from (x_a, y_a) to (x_b, y_b).
|
"""A straight Segment of a Path, from (x_a, y_a) to (x_b, y_b).
|
||||||
@ -18,18 +19,18 @@ class PathSegment:
|
|||||||
Args:
|
Args:
|
||||||
a (tuple[float, float]): The starting point (x_a, y_a).
|
a (tuple[float, float]): The starting point (x_a, y_a).
|
||||||
b (tuple[float, float]): The final point (x_b, y_b).
|
b (tuple[float, float]): The final point (x_b, y_b).
|
||||||
"""
|
"""
|
||||||
self.a = a
|
self.a = a
|
||||||
self.b = b
|
self.b = b
|
||||||
|
|
||||||
def get_length(self) -> float:
|
def get_length(self) -> float:
|
||||||
return math.dist(self.a, self.b)
|
return math.dist(self.a, self.b)
|
||||||
|
|
||||||
length = property(get_length)
|
length = property(get_length)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return str(f"({self.a}, {self.b})")
|
return str(f"({self.a}, {self.b})")
|
||||||
|
|
||||||
def __getitem__(self, index) -> float:
|
def __getitem__(self, index) -> float:
|
||||||
if index == 0:
|
if index == 0:
|
||||||
return self.a
|
return self.a
|
||||||
@ -38,26 +39,30 @@ class PathSegment:
|
|||||||
else:
|
else:
|
||||||
raise IndexError
|
raise IndexError
|
||||||
|
|
||||||
|
|
||||||
class Path:
|
class Path:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
coord_list: Sequence[tuple[float, float]],
|
coord_list: Sequence[tuple[float, float]],
|
||||||
z: float = 0,
|
z: float = 0,
|
||||||
path_simplify = False
|
path_simplify: bool = False
|
||||||
):
|
):
|
||||||
"""Construct a path of sequences based on a list of coordinates.
|
"""Construct a path of sequences based on a list of coordinates.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
coord_list (Sequence[tuple[float, float]]): List of x,y coordinates.
|
coord_list (Sequence[tuple[float, float]]): List of x,y
|
||||||
|
coordinates.
|
||||||
z (float, optional): Height of the path. Defaults to 0.
|
z (float, optional): Height of the path. Defaults to 0.
|
||||||
path_simplify (bool, optional): Whether to pg_rad.path.simplify_path(). Defaults to False.
|
path_simplify (bool, optional): Whether to
|
||||||
"""
|
pg_rad.path.simplify_path(). Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
if len(coord_list) < 2:
|
if len(coord_list) < 2:
|
||||||
raise ValueError("Must provide at least two coordinates as a list of tuples, e.g. [(x1, y1), (x2, y2)]")
|
raise ValueError("Must provide at least two coordinates as a \
|
||||||
|
of tuples, e.g. [(x1, y1), (x2, y2)]")
|
||||||
|
|
||||||
x, y = tuple(zip(*coord_list))
|
x, y = tuple(zip(*coord_list))
|
||||||
|
|
||||||
if path_simplify:
|
if path_simplify:
|
||||||
try:
|
try:
|
||||||
x, y = simplify_path(list(x), list(y))
|
x, y = simplify_path(list(x), list(y))
|
||||||
@ -69,7 +74,11 @@ class Path:
|
|||||||
|
|
||||||
coord_list = list(zip(x, y))
|
coord_list = list(zip(x, y))
|
||||||
|
|
||||||
self.segments = [PathSegment(i, ip1) for i, ip1 in zip(coord_list, coord_list[1:])]
|
self.segments = [
|
||||||
|
PathSegment(i, ip1)
|
||||||
|
for i, ip1 in
|
||||||
|
zip(coord_list, coord_list[1:])
|
||||||
|
]
|
||||||
|
|
||||||
self.z = z
|
self.z = z
|
||||||
|
|
||||||
@ -77,7 +86,7 @@ class Path:
|
|||||||
|
|
||||||
def get_length(self) -> float:
|
def get_length(self) -> float:
|
||||||
return sum([s.length for s in self.segments])
|
return sum([s.length for s in self.segments])
|
||||||
|
|
||||||
length = property(get_length)
|
length = property(get_length)
|
||||||
|
|
||||||
def __getitem__(self, index) -> PathSegment:
|
def __getitem__(self, index) -> PathSegment:
|
||||||
@ -85,40 +94,43 @@ class Path:
|
|||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return str([str(s) for s in self.segments])
|
return str([str(s) for s in self.segments])
|
||||||
|
|
||||||
def plot(self, **kwargs):
|
def plot(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Plot the path using matplotlib.
|
Plot the path using matplotlib.
|
||||||
"""
|
"""
|
||||||
plt.plot(self.x_list, self.y_list, **kwargs)
|
plt.plot(self.x_list, self.y_list, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def simplify_path(
|
def simplify_path(
|
||||||
x: Sequence[float],
|
x: Sequence[float],
|
||||||
y: Sequence[float],
|
y: Sequence[float],
|
||||||
keep_endpoints_equal: bool = False,
|
keep_endpoints_equal: bool = False,
|
||||||
n_breakpoints: int = 3
|
n_breakpoints: int = 3
|
||||||
):
|
):
|
||||||
"""From full resolution x and y arrays, return a piecewise linearly approximated/simplified pair of x and y arrays.
|
"""From full resolution x and y arrays, return a piecewise linearly
|
||||||
|
approximated/simplified pair of x and y arrays.
|
||||||
|
|
||||||
This function uses the `piecewise_regression` package. From a full set of
|
This function uses the `piecewise_regression` package. From a full set of
|
||||||
coordinate pairs, the function fits linear sections, automatically finding
|
coordinate pairs, the function fits linear sections, automatically finding
|
||||||
the number of breakpoints and their positions.
|
the number of breakpoints and their positions.
|
||||||
|
|
||||||
On why the default value of n_breakpoints is 3, from the `piecewise_regression`
|
On why the default value of n_breakpoints is 3, from the
|
||||||
docs:
|
`piecewise_regression` docs:
|
||||||
"If you do not have (or do not want to use) initial guesses for the number
|
"If you do not have (or do not want to use) initial guesses for the number
|
||||||
of breakpoints, you can set it to n_breakpoints=3, and the algorithm will
|
of breakpoints, you can set it to n_breakpoints=3, and the algorithm will
|
||||||
randomly generate start_values. With a 50% chance, the bootstrap restarting
|
randomly generate start_values. With a 50% chance, the bootstrap restarting
|
||||||
algorithm will either use the best currently converged breakpoints or
|
algorithm will either use the best currently converged breakpoints or
|
||||||
randomly generate new start_values, escaping the local optima in two ways in
|
randomly generate new start_values, escaping the local optima in two ways
|
||||||
order to find better global optima."
|
in order to find better global optima."
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Sequence[float]): Full list of x coordinates.
|
x (Sequence[float]): Full list of x coordinates.
|
||||||
y (Sequence[float]): Full list of y coordinates.
|
y (Sequence[float]): Full list of y coordinates.
|
||||||
keep_endpoints_equal (bool, optional): Whether or not to force start
|
keep_endpoints_equal (bool, optional): Whether or not to force start
|
||||||
and end to be exactly equal to the original. This will worsen the linear
|
and end to be exactly equal to the original. This will worsen the
|
||||||
approximation at the beginning and end of path. Defaults to False.
|
linear approximation at the beginning and end of path. Defaults to
|
||||||
|
False.
|
||||||
n_breakpoints (int, optional): Number of breakpoints. Defaults to 3.
|
n_breakpoints (int, optional): Number of breakpoints. Defaults to 3.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -129,25 +141,27 @@ def simplify_path(
|
|||||||
ConvergenceError: If the fitting algorithm failed to simplify the path.
|
ConvergenceError: If the fitting algorithm failed to simplify the path.
|
||||||
|
|
||||||
Reference:
|
Reference:
|
||||||
Pilgrim, C., (2021). piecewise-regression (aka segmented regression) in Python. Journal of Open Source Software, 6(68), 3859, https://doi.org/10.21105/joss.03859.
|
Pilgrim, C., (2021). piecewise-regression (aka segmented regression)
|
||||||
|
in Python. Journal of Open Source Software,
|
||||||
|
6(68), 3859, https://doi.org/10.21105/joss.03859.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger.debug(f"Attempting piecewise regression on path.")
|
logger.debug("Attempting piecewise regression on path.")
|
||||||
|
|
||||||
pw_fit = piecewise_regression.Fit(x, y, n_breakpoints=n_breakpoints)
|
pw_fit = piecewise_regression.Fit(x, y, n_breakpoints=n_breakpoints)
|
||||||
pw_res = pw_fit.get_results()
|
pw_res = pw_fit.get_results()
|
||||||
|
|
||||||
if pw_res == None:
|
if pw_res is None:
|
||||||
logger.warning("Piecewise regression failed to converge.")
|
logger.warning("Piecewise regression failed to converge.")
|
||||||
raise ConvergenceError("Piecewise regression failed to converge.")
|
raise ConvergenceError("Piecewise regression failed to converge.")
|
||||||
|
|
||||||
est = pw_res['estimates']
|
est = pw_res['estimates']
|
||||||
|
|
||||||
# extract and sort breakpoints
|
# extract and sort breakpoints
|
||||||
breakpoints_x = sorted(
|
breakpoints_x = sorted(
|
||||||
v['estimate'] for k, v in est.items() if k.startswith('breakpoint')
|
v['estimate'] for k, v in est.items() if k.startswith('breakpoint')
|
||||||
)
|
)
|
||||||
|
|
||||||
x_points = [x[0]] + breakpoints_x + [x[-1]]
|
x_points = [x[0]] + breakpoints_x + [x[-1]]
|
||||||
|
|
||||||
y_points = pw_fit.predict(x_points)
|
y_points = pw_fit.predict(x_points)
|
||||||
@ -158,28 +172,33 @@ def simplify_path(
|
|||||||
y_points[-1] = y[-1]
|
y_points[-1] = y[-1]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Piecewise regression reduced path from {len(x)-1} to {len(x_points)-1} segments."
|
f"Piecewise regression reduced path from \
|
||||||
|
{len(x)-1} to {len(x_points)-1} segments."
|
||||||
)
|
)
|
||||||
|
|
||||||
return x_points, y_points
|
return x_points, y_points
|
||||||
|
|
||||||
|
|
||||||
def path_from_RT90(
|
def path_from_RT90(
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
east_col: str = "East",
|
east_col: str = "East",
|
||||||
north_col: str = "North",
|
north_col: str = "North",
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""Construct a path from East and North formatted coordinates (RT90) in a Pandas DataFrame.
|
"""Construct a path from East and North formatted coordinates (RT90)
|
||||||
|
in a Pandas DataFrame.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
df (pandas.DataFrame): DataFrame containing at least the two columns noted in the cols argument.
|
df (pandas.DataFrame): DataFrame containing at least the two columns
|
||||||
|
noted in the cols argument.
|
||||||
east_col (str): The column name for the East coordinates.
|
east_col (str): The column name for the East coordinates.
|
||||||
north_col (str): The column name for the North coordinates.
|
north_col (str): The column name for the North coordinates.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path: A Path object built from the aquisition coordinates in the DataFrame.
|
Path: A Path object built from the aquisition coordinates in the
|
||||||
"""
|
DataFrame.
|
||||||
|
"""
|
||||||
|
|
||||||
east_arr = np.array(df[east_col]) - min(df[east_col])
|
east_arr = np.array(df[east_col]) - min(df[east_col])
|
||||||
north_arr = np.array(df[north_col]) - min(df[north_col])
|
north_arr = np.array(df[north_col]) - min(df[north_col])
|
||||||
|
|
||||||
@ -187,4 +206,4 @@ def path_from_RT90(
|
|||||||
|
|
||||||
path = Path(coord_pairs, **kwargs)
|
path = Path(coord_pairs, **kwargs)
|
||||||
logger.debug("Loaded path from provided RT90 coordinates.")
|
logger.debug("Loaded path from provided RT90 coordinates.")
|
||||||
return path
|
return path
|
||||||
|
|||||||
Reference in New Issue
Block a user