Source code for gerrychain.tree

import networkx as nx
from networkx.algorithms import tree

from .random import random
from collections import deque, namedtuple


def predecessors(h, root):
    """Return a dict mapping nodes of ``h`` to their BFS predecessor from ``root``.

    The pairs come straight from :func:`networkx.bfs_predecessors`, started
    at ``root``.

    :param h: The networkx graph to search.
    :param root: The node the breadth-first search starts from.
    """
    return dict(nx.bfs_predecessors(h, root))


def random_spanning_tree(graph):
    """Draw a spanning tree of ``graph`` with Kruskal's method on random weights.

    Every edge of ``graph`` receives a fresh random number under the
    ``"random_weight"`` attribute, and the maximum spanning tree with respect
    to that attribute is returned.

    The attribute is deliberately not called ``"weight"``: networkx uses
    ``"weight"`` as the default edge-weight key in many routines (adjacency
    and Laplacian matrices, shortest paths, ...), so writing random numbers
    there would overwrite any weights the caller stored and silently change
    the results of those routines each time a tree is drawn.

    :param graph: The networkx graph to span. Its edge data is modified in
        place: ``"random_weight"`` is (re)written on every edge.
    :return: The maximum spanning tree computed by networkx.
    """
    for edge in graph.edges:
        graph.edges[edge]["random_weight"] = random.random()

    return tree.maximum_spanning_tree(
        graph, algorithm="kruskal", weight="random_weight"
    )


class PopulatedGraph:
    """Bookkeeping wrapper around a graph used while contracting leaves.

    For every node it records the accumulated population, the set of original
    nodes that have been merged into it, and a degree counter.  Contraction
    only updates these tables; the wrapped graph object is not touched.
    """

    def __init__(self, graph, populations, ideal_pop, epsilon):
        """Build the per-node tables for ``graph``.

        :param graph: Object that iterates over its nodes and provides
            ``degree(node)``.
        :param populations: Mapping from node to population. A copy is stored.
        :param ideal_pop: The target population.
        :param epsilon: Allowed deviation, as a fraction of ``ideal_pop``.
        """
        self.graph = graph
        self.ideal_pop = ideal_pop
        self.epsilon = epsilon
        self.population = populations.copy()

        # Each node starts out as its own singleton subset with its full degree.
        self.subsets = {}
        self._degrees = {}
        for vertex in graph:
            self.subsets[vertex] = {vertex}
            self._degrees[vertex] = graph.degree(vertex)

    def __iter__(self):
        """Iterate over the nodes of the wrapped graph."""
        return iter(self.graph)

    def degree(self, node):
        """Return the degree currently recorded for ``node``."""
        return self._degrees[node]

    def contract_node(self, node, parent):
        """Fold ``node`` into ``parent``.

        ``parent`` absorbs the population and the subset of ``node`` and its
        recorded degree drops by one.
        """
        self.population[parent] += self.population[node]
        self.subsets[parent].update(self.subsets[node])
        self._degrees[parent] -= 1

    def has_ideal_population(self, node):
        """Tell whether ``node`` is strictly within tolerance of the target.

        The tolerance is ``epsilon * ideal_pop``; a deviation exactly equal to
        it does not count as ideal.
        """
        deviation = abs(self.population[node] - self.ideal_pop)
        tolerance = self.epsilon * self.ideal_pop
        return deviation < tolerance


def contract_leaves_until_balanced_or_none(h, choice=random.choice):
    """Contract leaves towards a chosen root until one has an ideal population.

    :param h: The :class:`PopulatedGraph` to work on; its tables are updated
        in place as leaves are contracted.
    :param choice: Function used to pick the root from the list of nodes with
        degree greater than one. Can be substituted for testing.
    :return: The subset of original nodes merged into the first leaf found
        with an ideal population, or ``None`` if the leaves run out first.
    """
    # this used to be greater than 2 but failed on small grids:(
    root_candidates = [x for x in h if h.degree(x) > 1]
    root = choice(root_candidates)
    # BFS predecessors for iteratively contracting leaves
    parent_of = predecessors(h.graph, root)

    pending = deque(x for x in h if h.degree(x) == 1)
    while pending:
        current = pending.popleft()
        if h.has_ideal_population(current):
            return h.subsets[current]
        # Contract the leaf:
        above = parent_of[current]
        h.contract_node(current, above)
        # The root is never queued, so it is never contracted away.
        if above != root and h.degree(above) == 1:
            pending.append(above)
    return None


# A balanced cut: ``edge`` is the (leaf, predecessor) pair to remove and
# ``subset`` is the set of nodes merged into that leaf.
Cut = namedtuple("Cut", ["edge", "subset"])


def find_balanced_edge_cuts(h, choice=random.choice):
    """Collect every balanced cut met while contracting all leaves to a root.

    :param h: The :class:`PopulatedGraph` to work on; its tables are updated
        in place as leaves are contracted.
    :param choice: Function used to pick the root from the list of nodes with
        degree greater than one. Can be substituted for testing.
    :return: A list of :class:`Cut` tuples, one per contracted node whose
        accumulated population was ideal; empty if there was none.
    """
    # this used to be greater than 2 but failed on small grids:(
    root = choice([v for v in h if h.degree(v) > 1])
    # BFS predecessors for iteratively contracting leaves
    parent_of = predecessors(h.graph, root)

    pending = deque(v for v in h if h.degree(v) == 1)
    found = []
    while pending:
        current = pending.popleft()
        above = parent_of[current]
        if h.has_ideal_population(current):
            # Copy the subset: later contractions keep mutating the sets in h.
            snapshot = h.subsets[current].copy()
            found.append(Cut(edge=(current, above), subset=snapshot))
        # Contract the leaf:
        h.contract_node(current, above)
        if above != root and h.degree(above) == 1:
            pending.append(above)
    return found


def bipartition_tree(
    graph,
    pop_col,
    pop_target,
    epsilon,
    node_repeats=1,
    spanning_tree=None,
    choice=random.choice,
):
    """This function finds a balanced 2 partition of a graph by drawing a
    spanning tree and finding an edge to cut that leaves at most an epsilon
    imbalance between the populations of the parts. If a root fails, new roots
    are tried until node_repeats in which case a new tree is drawn.

    Builds up a connected subgraph with a connected complement whose population
    is ``epsilon * pop_target`` away from ``pop_target``.

    Returns a subset of nodes of ``graph`` (whose induced subgraph is connected).
    The other part of the partition is the complement of this subset.

    Note that the search only stops once a balanced cut has been found.

    :param graph: The graph to partition
    :param pop_col: The node attribute holding the population of each node
    :param pop_target: The target population for the returned subset of nodes
    :param epsilon: The allowable deviation from ``pop_target`` (as a percentage
        of ``pop_target``) for the subgraph's population
    :param node_repeats: A parameter for the algorithm: how many different
        choices of root to use before drawing a new spanning tree.
    :param spanning_tree: The spanning tree for the algorithm to use (used when
        the algorithm chooses a new root and for testing)
    :param choice: :func:`random.choice`. Can be substituted for testing.
    :return: The set of nodes forming one side of the balanced cut.
    """
    populations = {node: graph.nodes[node][pop_col] for node in graph}

    if spanning_tree is None:
        spanning_tree = random_spanning_tree(graph)

    balanced_subtree = None
    restarts = 0
    while balanced_subtree is None:
        # ``>=`` rather than ``==``: identical for node_repeats >= 1, but with
        # node_repeats < 1 an equality test would match at most once and the
        # same tree would then be retried forever.
        if restarts >= node_repeats:
            spanning_tree = random_spanning_tree(graph)
            restarts = 0
        # A fresh PopulatedGraph per attempt: contraction mutates its tables.
        h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon)
        balanced_subtree = contract_leaves_until_balanced_or_none(h, choice=choice)
        restarts += 1

    return balanced_subtree
def bipartition_tree_random(
    graph,
    pop_col,
    pop_target,
    epsilon,
    node_repeats=1,
    spanning_tree=None,
    choice=random.choice,
):
    """This is like :func:`bipartition_tree` except it chooses a random balanced
    cut, rather than the first cut it finds.

    This function finds a balanced 2 partition of a graph by drawing a
    spanning tree and finding an edge to cut that leaves at most an epsilon
    imbalance between the populations of the parts. If a root fails, new roots
    are tried until node_repeats in which case a new tree is drawn.

    Builds up a connected subgraph with a connected complement whose population
    is ``epsilon * pop_target`` away from ``pop_target``.

    Returns a subset of nodes of ``graph`` (whose induced subgraph is connected).
    The other part of the partition is the complement of this subset.

    Note that the search only stops once a balanced cut has been found.

    :param graph: The graph to partition
    :param pop_col: The node attribute holding the population of each node
    :param pop_target: The target population for the returned subset of nodes
    :param epsilon: The allowable deviation from ``pop_target`` (as a percentage
        of ``pop_target``) for the subgraph's population
    :param node_repeats: A parameter for the algorithm: how many different
        choices of root to use before drawing a new spanning tree.
    :param spanning_tree: The spanning tree for the algorithm to use (used when
        the algorithm chooses a new root and for testing)
    :param choice: :func:`random.choice`. Can be substituted for testing.
    :return: The node subset of one balanced cut, picked with ``choice`` among
        all balanced cuts found for the successful root.
    """
    populations = {node: graph.nodes[node][pop_col] for node in graph}

    if spanning_tree is None:
        spanning_tree = random_spanning_tree(graph)

    possible_cuts = []
    restarts = 0
    while not possible_cuts:
        # Only redraw after ``node_repeats`` roots have failed on the current
        # tree, so a caller-supplied ``spanning_tree`` is actually used and
        # ``node_repeats`` is honoured, as in bipartition_tree.
        if restarts >= node_repeats:
            spanning_tree = random_spanning_tree(graph)
            restarts = 0
        # A fresh PopulatedGraph per attempt: contraction mutates its tables.
        h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon)
        possible_cuts = find_balanced_edge_cuts(h, choice=choice)
        restarts += 1

    return choice(possible_cuts).subset
def recursive_tree_part(
    graph, parts, pop_target, pop_col, epsilon, node_repeats=1, method=bipartition_tree
):
    """Uses :func:`~gerrychain.tree.bipartition_tree` recursively to partition a
    tree into ``len(parts)`` parts of population ``pop_target`` (within
    ``epsilon``). Can be used to generate initial seed plans or to implement
    ReCom-like "merge walk" proposals.

    :param graph: The graph
    :param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``)
    :param pop_target: Target population for each part of the partition
    :param pop_col: Node attribute key holding population data
    :param epsilon: How far (as a percentage of ``pop_target``) from
        ``pop_target`` the parts of the partition can be
    :param node_repeats: Parameter for :func:`~gerrychain.tree.bipartition_tree`
        to use.
    :param method: The bipartition function called to carve off each part. It
        is called with the remaining subgraph and the keyword arguments
        ``pop_col``, ``pop_target``, ``epsilon`` and ``node_repeats``, and must
        return the set of nodes of the new part.
    :return: New assignments for the nodes of ``graph``.
    :rtype: dict
    """
    # Materialise the labels: they are sliced and indexed below, which a
    # general iterable (generator, set, ...) does not support.
    parts = list(parts)

    flips = {}
    remaining_nodes = set(graph.nodes)

    # We keep a running tally of deviation from ``pop_target`` at each
    # partition and use it to tighten the population constraints on a
    # per-partition basis such that every partition, including the last
    # partition, has a population within +/-``epsilon`` of the target
    # population.
    # For instance, if district n's population exceeds the target by 2%
    # with a +/-2% epsilon, then district n+1's population should be between
    # 98% of the target population and the target population.
    debt = 0

    lower_bound = pop_target * (1 - epsilon)
    upper_bound = pop_target * (1 + epsilon)

    for part in parts[:-1]:
        min_pop = max(lower_bound, lower_bound - debt)
        max_pop = min(upper_bound, upper_bound - debt)
        new_pop_target = (min_pop + max_pop) / 2

        nodes = method(
            graph.subgraph(remaining_nodes),
            pop_col=pop_col,
            pop_target=new_pop_target,
            # ``method`` measures epsilon relative to the target it is given,
            # so normalise the half-width of [min_pop, max_pop] by the adjusted
            # target. Dividing by ``pop_target`` would accept populations
            # outside the window whenever debt is negative.
            epsilon=(max_pop - min_pop) / (2 * new_pop_target),
            node_repeats=node_repeats,
        )

        part_pop = 0
        for node in nodes:
            flips[node] = part
            part_pop += graph.nodes[node][pop_col]
        debt += part_pop - pop_target
        remaining_nodes -= nodes

    # All of the remaining nodes go in the last part
    for node in remaining_nodes:
        flips[node] = parts[-1]

    return flips