Source code for src.graph

import copy

import numpy as np

from src.utils.math_utils import check_psd

class Node:
    """
    Represents a multidimensional Gaussian state distribution in a
    graph-structured belief roadmap.

    Parameters
    ----------
    mean: np.ndarray
        Mean of the Gaussian distribution
    covariance: np.ndarray
        Covariance of the Gaussian distribution, must be PSD
        (may be None, in which case no PSD check is run)
    is_start: bool
        True if start node, False otherwise
    is_goal: bool
        True if goal node, False otherwise

    Notes
    -----
    Equality is tolerance-based (``np.allclose``) while the hash is computed
    from the exact values, so two nodes that are only *approximately* equal
    compare equal with ``==`` but may land in different hash buckets.
    Set/dict membership therefore effectively requires matching values.
    """

    def __init__(self, mean, covariance, is_start=False, is_goal=False):
        self.mean = mean
        self.covariance = covariance
        self.is_start = is_start
        self.is_goal = is_goal

        # Check to make sure covariance is PSD
        # NOTE(review): presumably check_psd raises on a non-PSD matrix; its
        # return value is ignored here -- confirm against math_utils.
        if covariance is not None:
            check_psd(covariance)

    def __eq__(self, other):
        """Return True if both nodes have matching flags and close mean/covariance."""
        if not isinstance(other, Node):
            return NotImplemented

        # Cheap flag comparisons first; the array comparisons only run
        # when the flags already agree.
        if self.is_start != other.is_start or self.is_goal != other.is_goal:
            return False

        # Exactly one covariance missing -> the nodes differ.
        if (self.covariance is None) != (other.covariance is None):
            return False

        if not np.allclose(self.mean, other.mean):
            return False

        if self.covariance is None:
            return True
        return bool(np.allclose(self.covariance, other.covariance))

    def __hash__(self):
        """Hash the flattened mean/covariance values together with the flags."""
        # Flatten before converting to a tuple: ``tolist()`` on an array that
        # is not 1-D (e.g. an (n, 1) column vector) yields nested lists, which
        # are unhashable.
        mean_key = tuple(np.asarray(self.mean).ravel().tolist())
        if self.covariance is None:
            covariance_key = None
        else:
            covariance_key = tuple(np.asarray(self.covariance).ravel().tolist())
        return hash((mean_key, covariance_key, self.is_start, self.is_goal))
class Edge:
    """
    Directed edge of a belief roadmap, joining two Gaussian state
    distributions.

    An edge carries the feedback control policy that steers from its start
    node to its end node, together with the discrete-time Gaussian state
    trajectory followed in between.

    Parameters
    ----------
    start_node: Node
        Node the edge leaves from
    end_node: Node
        Node the edge arrives at
    mean: np.ndarray
        Array of means of the intermediate Gaussian states
    covariance: np.ndarray
        Array of covariances of the intermediate Gaussian states
    ff_ctrl: np.ndarray
        Open-loop (feed-forward) control between the two nodes
    fb_ctrl: np.ndarray
        Feedback control gain between the two nodes
    """

    def __init__(self, start_node, end_node, mean, covariance, ff_ctrl, fb_ctrl):
        # Endpoints of the edge
        self.start_node, self.end_node = start_node, end_node
        # Planned state trajectory along the edge
        self.mean, self.covariance = mean, covariance
        # Control policy along the edge
        self.ff_ctrl, self.fb_ctrl = ff_ctrl, fb_ctrl
class Graph:
    """
    Represents a graph-structured belief roadmap, where the nodes in the
    roadmap are Gaussian distributions in the state space, and the directed
    edges in the roadmap are control policies steering between nodes.

    Parameters
    ----------
    nodes: set
        Set of Node objects representing nodes
    edges: set
        Set of Edge objects representing edges

    Notes
    -----
    The given sets are stored by reference, not copied.
    """

    def __init__(self, nodes, edges):
        self.nodes = nodes
        self.edges = edges

    def add_node(self, node):
        """
        Add node to the graph.

        Parameters
        ----------
        node: Node
            Node to be added to the graph.
        """
        self.nodes.add(node)

    def add_edge(self, edge):
        """
        Add edge to the graph. The start and end nodes of the edge must
        already be in the graph.

        Parameters
        ----------
        edge: Edge
            Edge to be added to the graph.

        Raises
        ------
        Exception
            If the start or end node of the edge is not in the graph.
        """
        if edge.start_node not in self.nodes:
            raise Exception("Edge starts at a node that isn't in the graph")
        if edge.end_node not in self.nodes:
            raise Exception("Edge ends at a node that isn't in graph")
        self.edges.add(edge)

    def get_start_node(self):
        """
        Get the start node in the graph.

        Returns
        -------
        start_node: Node
            Start node of the graph

        Raises
        ------
        Exception
            Unless exactly one start node is in the graph.
        """
        start_nodes = [node for node in self.nodes if node.is_start]
        if not start_nodes:
            raise Exception("No start node found")
        if len(start_nodes) > 1:
            raise Exception("Multiple start nodes found")
        return start_nodes[0]

    def get_goal_node(self):
        """
        Get the goal node in the graph.

        Returns
        -------
        goal_node: Node
            Goal node of the graph

        Raises
        ------
        Exception
            Unless exactly one goal node is in the graph.
        """
        goal_nodes = [node for node in self.nodes if node.is_goal]
        if not goal_nodes:
            raise Exception("No goal node found")
        if len(goal_nodes) > 1:
            raise Exception("Multiple goal nodes found")
        return goal_nodes[0]

    def get_parent(self, node):
        """
        Get the parent of a node in the graph.

        If several edges end at the node, the start node of the first such
        edge encountered while iterating over the edge set is returned.

        Parameters
        ----------
        node: Node
            Child node to look up parent

        Returns
        -------
        parent: Node
            Parent of child node (None if child is start and has no
            incoming edge)

        Raises
        ------
        Exception
            If the node isn't in the graph, or if it has no incoming edge
            and is not the start node.
        """
        if node not in self.nodes:
            raise Exception("Node is not in graph")
        for edge in self.edges:
            if edge.end_node == node:
                return edge.start_node
        if node.is_start:
            # OK to have no parent
            return None
        raise Exception("Parent not found")

    def look_up_by_mean(self, node_mean):
        """
        Look up a node in a graph by its mean (compared with np.allclose).

        Parameters
        ----------
        node_mean: np.ndarray
            Mean of node to look up

        Returns
        -------
        node: Node
            Node in graph with node_mean as its mean

        Raises
        ------
        Exception
            Unless exactly one node with the given mean is in the graph.
        """
        matching_nodes = [
            node for node in self.nodes if np.allclose(node.mean, node_mean)
        ]
        if not matching_nodes:
            raise Exception("No matching nodes")
        if len(matching_nodes) > 1:
            raise Exception("Multiple nodes match")
        return matching_nodes[0]

    def get_children(self, node):
        """
        Get children of a node in the graph.

        Parameters
        ----------
        node: Node
            Node to look up children

        Returns
        -------
        children: set
            Set of child nodes of node (can be empty)

        Raises
        ------
        Exception
            If the node isn't in the graph.
        """
        if node not in self.nodes:
            raise Exception("Node is not in graph")
        return {edge.end_node for edge in self.edges if edge.start_node == node}

    def get_ancestors(self, node):
        """
        Get ancestors of a node in the graph by repeatedly following
        get_parent, all the way back to the start node.

        Parameters
        ----------
        node: Node
            Node to look up ancestors

        Returns
        -------
        ancestors: set
            Ancestors of node
        """
        ancestors = set()
        parent = self.get_parent(node)
        # Stop on an already-seen parent so a cycle cannot loop forever.
        while parent is not None and parent not in ancestors:
            ancestors.add(parent)
            parent = self.get_parent(parent)
        return ancestors

    def get_descendants(self, node):
        """
        Get descendants of a node in the graph, level by level, all the way
        down to leaf nodes.

        The returned set holds the node objects reached through the graph's
        edges (no copies are made).

        Parameters
        ----------
        node: Node
            Node to find descendants of

        Returns
        -------
        descendants: set
            Set of descendants of node
        """
        descendants = set()
        frontier = self.get_children(node)
        while frontier:
            descendants |= frontier
            next_frontier = set()
            for child in frontier:
                next_frontier |= self.get_children(child)
            # Only expand nodes not seen before: avoids redundant work on
            # shared descendants and guarantees termination on cycles.
            frontier = next_frontier - descendants
        return descendants

    def _trace_edges_to_goal(self):
        """Return the edges of one start-to-goal path, ordered from start to goal."""
        # If several edges enter a goal node, the first one found is used.
        goal_edge = next(
            (edge for edge in self.edges if edge.end_node.is_goal), None
        )
        if goal_edge is None:
            raise Exception("No edge into goal node found")

        path = [goal_edge]
        current_node = goal_edge.start_node
        while not current_node.is_start:
            incoming = next(
                (edge for edge in self.edges if edge.end_node == current_node),
                None,
            )
            if incoming is None:
                raise Exception("No path from start to goal")
            path.append(incoming)
            # A simple path cannot use more edges than the graph has, so
            # exceeding that count means an edge was revisited (a cycle).
            if len(path) > len(self.edges):
                raise Exception("Cycle detected while tracing path to goal")
            current_node = incoming.start_node

        path.reverse()
        return path

    def trim(self):
        """
        Return minimal subgraph which connects the start node to the goal
        node.

        Returns
        -------
        trimmed_graph: Graph
            Minimal subgraph connecting start to goal

        Raises
        ------
        Exception
            If no edge enters a goal node, or tracing back from the goal
            does not reach the start node.
        """
        path = self._trace_edges_to_goal()
        relevant_nodes = {edge.start_node for edge in path}
        relevant_nodes.add(path[-1].end_node)
        return Graph(relevant_nodes, set(path))

    def get_plan_to_goal(self):
        """
        Trace edges from start to goal, and extract planned control and
        state.

        Returns
        -------
        u_traj: list
            List of planned open-loop control for each edge from start to goal
        K_traj: list
            List of planned feedback control for each edge
        x_traj: list
            List of mean state trajectory for each edge
        P_traj: list
            List of state covariance for each edge

        Raises
        ------
        Exception
            If no edge enters a goal node, or tracing back from the goal
            does not reach the start node.
        """
        path = self._trace_edges_to_goal()
        u_traj = [edge.ff_ctrl for edge in path]
        K_traj = [edge.fb_ctrl for edge in path]
        x_mean_traj = [edge.mean for edge in path]
        P_traj = [edge.covariance for edge in path]
        return u_traj, K_traj, x_mean_traj, P_traj

    def deepcopy(self):
        """
        Deepcopy graph to a new Graph object.

        Nodes and edges are copied in a single deepcopy call so that the
        copied edges refer to the very same node objects that are stored in
        the copied node set.

        Returns
        -------
        graph_copy: Graph
            Deepcopy of self
        """
        node_copy, edge_copy = copy.deepcopy((self.nodes, self.edges))
        return Graph(node_copy, edge_copy)