Source code for gerrychain.updaters.county_splits

import collections
from enum import Enum
from typing import Callable, Dict


CountyInfo = collections.namedtuple("CountyInfo", ["split", "nodes", "contains"])
"""
Per-county record produced by the county-split updaters.

:param split: Split status of the county, expressed as a member of the
    :class:`.CountySplit` enum.
:type split: CountySplit
:param nodes: The graph nodes that belong to the county.
:type nodes: List
:param contains: The distinct assignment values (part IDs) that the county's
    nodes are assigned to. More than one entry means the county is split.
:type contains: Set
"""


class CountySplit(Enum):
    """
    Split status of a single county within a partition.

    :cvar NOT_SPLIT: All of the county's nodes share one assignment value.
    :cvar NEW_SPLIT: The county is split in the current partition and its
        status in the parent partition was anything other than ``OLD_SPLIT``.
        Never produced for a partition that has no parent.
    :cvar OLD_SPLIT: The county is split and either the partition has no
        parent, or the parent already recorded the county as ``OLD_SPLIT``.
    """

    NOT_SPLIT = 0
    NEW_SPLIT = 1
    OLD_SPLIT = 2
def county_splits(partition_name: str, county_field_name: str) -> Callable:
    """
    Build an updater that tracks which counties are split by a partition.

    :param partition_name: Name under which the :class:`.Partition` stores
        this updater's result. The same name is used to read the parent
        partition's previously computed result, so it must match the key
        the updater is registered under.
    :type partition_name: str
    :param county_field_name: Name of the county ID field on the graph.
    :type county_field_name: str

    :returns: A function taking a partition and returning a dictionary keyed
        on county ID. Each value is a :class:`CountyInfo` tuple
        ``(split, nodes, contains)``: ``split`` is a :class:`.CountySplit`
        member, ``nodes`` is the list of node IDs in the county, and
        ``contains`` is the set of assignment IDs found in the county.
    :rtype: Callable
    """

    def _get_county_splits(partition):
        # Keywords make the (county, partition) argument order explicit; it is
        # the reverse of this factory's own parameter order.
        return compute_county_splits(
            partition,
            county_field=county_field_name,
            partition_field=partition_name,
        )

    return _get_county_splits
def compute_county_splits(
    partition, county_field: str, partition_field: str
) -> Dict[str, CountyInfo]:
    """
    Group nodes by county and record whether each county is split.

    With no parent partition the county data is built from scratch by walking
    every node of the graph. Otherwise the node lists stored on the parent's
    result are reused and only the assignment values are recomputed.

    :param partition: The partition object to compute county splits for.
    :type partition: :class:`~gerrychain.partition.Partition`
    :param county_field: Name of the county ID field on the graph.
    :type county_field: str
    :param partition_field: Key under which the parent partition holds the
        result of this computation; it is read as ``parent[partition_field]``.
    :type partition_field: str

    :returns: A dict mapping each county ID to its :class:`CountyInfo`.
        Without a parent, a split county is marked ``OLD_SPLIT``. With a
        parent, a split county is marked ``OLD_SPLIT`` only if the parent
        also marked it ``OLD_SPLIT``, and ``NEW_SPLIT`` otherwise. Counties
        whose nodes share a single assignment value are ``NOT_SPLIT``.
    :rtype: Dict[str, CountyInfo]
    """
    parent = partition.parent

    if not parent:
        # Initial partition: build every county's record from the graph.
        graph = partition.graph
        initial = {}
        for node in graph.node_indices:
            county = graph.lookup(node, county_field)
            info = initial.get(county)
            if info is None:
                info = CountyInfo(CountySplit.NOT_SPLIT, [], set())

            # The list and set are mutated in place and carried into the
            # replacement tuple below.
            info.nodes.append(node)
            info.contains.add(partition.assignment.mapping[node])

            status = CountySplit.OLD_SPLIT if len(info.contains) > 1 else info.split
            initial[county] = CountyInfo(status, info.nodes, info.contains)
        return initial

    updated = {}
    for county, previous in parent[partition_field].items():
        districts = {partition.assignment.mapping[node] for node in previous.nodes}

        if len(districts) <= 1:
            status = CountySplit.NOT_SPLIT
        elif previous.split != CountySplit.OLD_SPLIT:
            status = CountySplit.NEW_SPLIT
        else:
            status = CountySplit.OLD_SPLIT

        # The node list is shared with the parent's record, not copied.
        updated[county] = CountyInfo(status, previous.nodes, districts)
    return updated
def tally_region_splits(reg_attr_lst):
    """
    Build a naive updater that tallies splits for several region attributes.

    :param reg_attr_lst: A list of region attribute names to tally splits for.
    :type reg_attr_lst: List[str]

    :returns: A function that takes a partition and returns a dictionary
        mapping each region attribute name to the value computed by
        :func:`total_reg_splits` for that attribute in the given partition.
        The function raises ``ValueError`` if the partition has no
        ``cut_edges`` updater.
    :rtype: Callable
    """

    def _get_splits(partition):
        # total_reg_splits reads partition["cut_edges"], so fail early with a
        # clear message when that updater is missing.
        if "cut_edges" not in partition.updaters:
            raise ValueError("The cut_edges updater must be attached to the partition")

        tallies = {}
        for reg_attr in reg_attr_lst:
            tallies[reg_attr] = total_reg_splits(partition, reg_attr)
        return tallies

    return _get_splits
def total_reg_splits(partition, reg_attr):
    """
    Count the regions of ``reg_attr`` that are split in the partition.

    A region is counted once if at least one cut edge joins two nodes that
    carry the same ``reg_attr`` value but have different assignments.
    Requires the ``cut_edges`` updater to be attached to the partition.

    :param partition: The partition to inspect.
    :param reg_attr: Name of the node attribute holding the region ID.

    :returns: The number of distinct ``reg_attr`` values that have at least
        one such cut edge.
    :rtype: int
    """
    node_data = partition.graph.nodes

    # Start every region found on the graph at zero.
    region_names = {node_data[node][reg_attr] for node in node_data}
    cut_counts = dict.fromkeys(region_names, 0)

    for node1, node2 in partition["cut_edges"]:
        if partition.assignment[node1] == partition.assignment[node2]:
            continue
        region = node_data[node1][reg_attr]
        if region != node_data[node2][reg_attr]:
            continue
        # Both endpoints are in the same region, so it is credited twice.
        cut_counts[region] += 2

    return sum(count > 0 for count in cut_counts.values())