diff --git a/src/road_gen/main.py b/src/road_gen/main.py index 138e97e..86cd43f 100644 --- a/src/road_gen/main.py +++ b/src/road_gen/main.py @@ -4,9 +4,8 @@ from matplotlib import pyplot as plt from .generators.random_road_generator import RandomRoadGenerator from .generators.segmented_road_generator import SegmentedRoadGenerator - from .plotting.plot_road import plot_road - +from .prefabs import prefabs from .utils import export def add_common_args(parser): @@ -30,48 +29,60 @@ def main(): random_parser.add_argument("--straight_section_max_rel_size", type=float, required=False, help="The maximum size that straight section(s) can have relative to the total length of the path. Defaults to 0.1.") add_common_args(random_parser) + segment_parser = subparsers.add_parser("segments", help="Generate a road according to a list of segments.") + segment_parser.add_argument("--segments", nargs="+", type=str, required=True, help=f"List of segments. Choose from {str(prefabs.PREFABS.keys())}") + segment_parser.add_argument("--alpha", type=float, required=False, help="Dirichlet distribution concentration parameter. A high alpha distributes total length more evenly across segments.") + add_common_args(segment_parser) + args = parser.parse_args() - if args.method == "random": - try: - if not all(v > 0 for v in (args.length, args.ds, args.velocity)): - raise ValueError("Length, step size, and velocity must be positive values.") + try: + if not all(v > 0 for v in (args.length, args.ds, args.velocity)): + raise ValueError("Length, step size, and velocity must be positive values.") - init_args = { - "length": args.length, - "ds": args.ds, - "velocity": args.velocity, - } + init_args = { + "length": args.length, + "ds": args.ds, + "velocity": args.velocity, + } - if args.mu: - init_args["mu"] = args.mu - if args.g: - init_args["g"] = args.g - if args.seed: - init_args["seed"] = args.seed + if args.mu: + init_args["mu"] = args.mu + if args.g: + init_args["g"] = args.g + if args.seed: + init_args["seed"] = args.seed + generate_args = {} + if args.method == "random": generator = RandomRoadGenerator(**init_args) - - generate_args = {} - + if args.straight_section_prob: generate_args["straight_section_prob"] = float(args.straight_section_prob) if args.straight_section_max_rel_size: generate_args["straight_section_max_rel_size"] = float(args.straight_section_max_rel_size) - x, y = generator.generate(**generate_args) + elif args.method == "segments": + generator = SegmentedRoadGenerator(**init_args) - if args.save: - basename = str(generator.seed) - plot_road(x, y, generator, save = True, filename = basename+".jpg") - export.coords_to_json(x, y, filename = basename+".json") - export.params_to_json(generator, filename = basename+".params.json") - else: - plot_road(x, y, generator) + if args.alpha: + generate_args["alpha"] = args.alpha + + generate_args["segments"] = list(args.segments) + + x, y = generator.generate(**generate_args) + + if args.save: + basename = str(generator.seed) + plot_road(x, y, generator, save = True, filename = basename+".jpg") + export.coords_to_json(x, y, filename = basename+".json") + export.params_to_json(generator, filename = basename+".params.json") + else: + plot_road(x, y, generator) - except ValueError as e: - print(f"Error: {e}") - exit(1) + except ValueError as e: + print(f"Error: {e}") + exit(1) if __name__ == "__main__": main() \ No newline at end of file