"""A collection of sampled parameters from a statistic distribution.
Model parameters are random variables that are sampled together in a joint distribution
using a graph of dependencies. For example, the functions::

    f(a, b) = x
    g(c) = y
    h(x, y) = z

indicate that x depends on a and b, y depends on c, and z depends on x and y (and thus on a, b, and c
as well). These would form a graph that looks like::

    a --\\
         x --\\
    b --/     \\
               z
    c --- y --/

Within LightCurveLynx, variables are grouped into logical sets called nodes. The combination of node name
and variable name is used to identify specific values, allowing the same variable names to be used
in multiple nodes. For example, a node that generates samples from a Gaussian distribution may have internal
parameters called mean and scale that take on different values depending on what the node is generating.
We could have one mean for an object's brightness and another for its position relative to the center
of a host galaxy.
"""
import copy
import numpy as np
from astropy.io import ascii
from astropy.table import Table
class GraphState:
    """A class to hold the state(s) of each variable for one or more samples of the random
    variables in the graph. Each entry is indexed by a combination of the node's (unique) name and
    the variable's name. This allows nodes to have parameters with the same name, such as ra and dec.

    Attributes
    ----------
    states : dict
        A dictionary of dictionaries mapping the node's string and variable name to
        either a value or array of values for that parameter.
    num_samples : int
        A count of the number of samples stored in the GraphState. If num_samples > 1, then
        all parameters are stored as arrays of length num_samples. If num_samples == 1, then
        all parameters are stored as scalars.
    num_parameters : int
        The total number of parameters stored in a single sample within GraphState.
    fixed_vars : dict
        A dictionary mapping the node name to a set of the variable names that are fixed (not
        changed by resampling) in this GraphState instance. This is used for presetting certain
        parameters to fixed values (such as during automatic differentiation).
        Nodes are only included if they have at least one fixed variable.
    sample_offset : int
        An optional offset to add to the graph state for any stateful nodes.
        Default: 0
    sample_idx : int or None
        An optional index of the current sample within a larger set of samples. Only
        used when we have extracted a single sample from a multi-sample GraphState.
        Default: None
    """

    def __init__(self, num_samples=1, *, sample_offset=0):
        if num_samples < 1:
            raise ValueError(
                f"Invalid number of samples for GraphState ({num_samples}). Must be a positive integer."
            )
        self.num_samples = num_samples
        self.num_parameters = 0
        self.sample_offset = sample_offset

        # Every accessor below reads these, so they must exist from construction.
        self.states = {}
        self.fixed_vars = {}
        self.sample_idx = None

    def __len__(self):
        """Return the number of parameters stored per sample."""
        return self.num_parameters

    def __next__(self):
        """Return the first single-sample state.

        NOTE(review): each call builds a fresh iterator, so repeated calls to next()
        always return sample 0. Use iter() to step through all samples.
        """
        return next(self._iterate())

    def __iter__(self):
        """Iterate over the samples, yielding one single-sample GraphState per sample."""
        return self._iterate()

    def _iterate(self):
        """Yield a single sliced state (a GraphState with num_samples == 1 and
        all scalar values) for each sample in order."""
        for idx in range(self.num_samples):
            yield self.extract_single_sample(idx)

    def __contains__(self, key):
        """Check if the GraphState contains an entry.

        The key can be:
        1) the name of a node (in which case we return True if the node exists),
        2) the full name of a parameter (in which case we return True if the
           combination of node and parameter exists), or
        3) the name of a parameter in a GraphState with a single node (in which case we return True
           if the parameter exists in that node).

        Parameters
        ----------
        key : str
            The name of the entry to check.

        Returns
        -------
        bool
            Whether the entry exists.

        Raises
        ------
        KeyError
            If the key contains more than one '.' separator.
        """
        if key in self.states:
            # The key is a node name.
            return True
        if "." in key:
            # The key is a full name in node.param format.
            tokens = key.split(".")
            if len(tokens) != 2:
                raise KeyError(f"Invalid GraphState key: {key}")
            node_name, param_name = tokens
            return node_name in self.states and param_name in self.states[node_name]
        if len(self.states) == 1:
            # Special case: a bare parameter name is unambiguous when only one node is stored.
            node_state = next(iter(self.states.values()))
            return key in node_state
        return False

    def __str__(self):
        """Return a multi-line listing of every node and its variables' values."""
        str_lines = []
        for node_name, node_vars in self.states.items():
            str_lines.append(f"{node_name}:")
            for var_name, value in node_vars.items():
                str_lines.append(f"    {var_name}: {value}")
        return "\n".join(str_lines)

    @staticmethod
    def _values_match(first, second):
        """Compare two stored parameter values, numerically (with tolerance) when possible."""
        try:
            return bool(np.allclose(first, second))
        except TypeError:
            # Non-numeric values (e.g. strings or None) cannot be compared with a tolerance.
            return bool(np.array_equal(first, second))
        except ValueError:
            # Shapes that cannot be broadcast against each other are not equal.
            return False

    def __eq__(self, other):
        """Check that two GraphStates hold the same nodes, parameters, values, and fixed sets."""
        if not isinstance(other, GraphState):
            return NotImplemented
        if self.num_samples != other.num_samples:
            return False
        if len(self.states) != len(other.states):
            return False

        for node_name, node_params in self.states.items():
            # Check that this node exists in both GraphStates and has the same number
            # of parameters.
            other_params = other.states.get(node_name)
            if other_params is None or len(node_params) != len(other_params):
                return False

            # Check that the values of each parameter match.
            for var_name, var_value in node_params.items():
                if var_name not in other_params:
                    return False
                if not self._values_match(var_value, other_params[var_name]):
                    return False

        # Finally check that the 'fixed' dictionary is the same.
        return self.fixed_vars == other.fixed_vars

    def __getitem__(self, key):
        """Access an entry in the GraphState.

        The key can be:
        1) the name of a node (in which case we return that node's dictionary of
           parameter_name -> value),
        2) the full name of a parameter (in which case we return the values), or
        3) the name of a parameter in a GraphState with a single node (in which case we
           return that parameter's values).

        Parameters
        ----------
        key : str
            The name of the entry to access.

        Raises
        ------
        KeyError
            If the key is malformed or does not match any entry.
        """
        if key in self.states:
            return self.states[key]
        if "." in key:
            tokens = key.split(".")
            if len(tokens) != 2:
                raise KeyError(f"Invalid GraphState key: {key}")
            return self.states[tokens[0]][tokens[1]]
        if len(self.states) == 1:
            # Special case when we have only a single node stored in the graph state.
            node_state = next(iter(self.states.values()))
            if key in node_state:
                return node_state[key]
        # Previously a missing bare name in a single-node state fell through and
        # returned None; every unmatched key now raises.
        raise KeyError(f"Unknown GraphState key: {key}")

    def copy(self):
        """Create a deep copy of the GraphState.

        Returns
        -------
        GraphState
            The copied GraphState.
        """
        new_state = GraphState(num_samples=self.num_samples, sample_offset=self.sample_offset)
        new_state.num_parameters = self.num_parameters
        new_state.sample_idx = self.sample_idx

        for node_name, node_vars in self.states.items():
            if self.num_samples == 1:
                # Single-sample values are stored as-is (scalars).
                new_state.states[node_name] = dict(node_vars)
            else:
                # Multi-sample values are arrays, which must be copied to be independent.
                new_state.states[node_name] = {
                    var_name: var_value.copy() for var_name, var_value in node_vars.items()
                }

        # Copy over the sets of fixed variables (one set of strings per node).
        for node_name, var_set in self.fixed_vars.items():
            new_state.fixed_vars[node_name] = copy.deepcopy(var_set)
        return new_state

    def extract_single_sample(self, sample_num):
        """Create a new single-sample GraphState holding the values of one sample.

        Parameters
        ----------
        sample_num : int
            The index of the sample to extract.

        Returns
        -------
        GraphState
            A GraphState with num_samples == 1 whose values are those of sample
            ``sample_num`` and whose ``sample_idx`` is ``sample_num``.

        Raises
        ------
        ValueError
            If sample_num is outside [0, num_samples).
        """
        if sample_num < 0 or sample_num >= self.num_samples:
            raise ValueError(f"Invalid index {sample_num} in GraphState with {self.num_samples} entries.")

        # NOTE(review): the sample_offset is carried over unchanged; confirm whether
        # stateful nodes expect it to be shifted by sample_num instead.
        new_state = GraphState(num_samples=1, sample_offset=self.sample_offset)
        new_state.num_parameters = self.num_parameters
        new_state.sample_idx = sample_num
        for node_name in self.states:
            # Copy the dict so the extracted state never aliases this state's storage.
            new_state.states[node_name] = dict(self.get_node_state(node_name, sample_num))
        for node_name, var_set in self.fixed_vars.items():
            new_state.fixed_vars[node_name] = set(var_set)
        return new_state

    @staticmethod
    def extended_param_name(node_name, param_name):
        """A helper function to create the full parameter name.

        Parameters
        ----------
        node_name : str
            The name of the node.
        param_name : str
            The name of the parameter.

        Returns
        -------
        extended : str
            A name of the form {node_name}.{param_name}
        """
        return f"{node_name}.{param_name}"

    @classmethod
    def from_table(cls, input_table):
        """Create the GraphState from an AstroPy Table with columns for each parameter
        and column names of the form '{node_name}.{param_name}'.

        Parameters
        ----------
        input_table : astropy.table.Table
            The input table.

        Returns
        -------
        GraphState
            The corresponding graph state.

        Raises
        ------
        ValueError
            If a column name is not of the form 'node_name.param_name'.
        """
        num_samples = len(input_table)
        result = cls(num_samples=num_samples)
        for col in input_table.colnames:
            components = col.split(".")
            if len(components) != 2:
                raise ValueError(
                    f"Invalid name for entry '{col}'. Entries should be of the form 'node_name.param_name'."
                )

            # If we only have a single value then store that value instead of the np array.
            if num_samples == 1:
                result.set(components[0], components[1], input_table[col].data[0])
            else:
                result.set(components[0], components[1], input_table[col].data)
        return result

    @classmethod
    def from_file(cls, filename):
        """Create the GraphState from a saved (ECSV) file.

        Parameters
        ----------
        filename : str or Path
            The name of the file.

        Returns
        -------
        GraphState
            The corresponding graph state.
        """
        data_table = ascii.read(filename, format="ecsv")
        return cls.from_table(data_table)

    @classmethod
    def from_dict(cls, data, num_samples=1):
        """Create a GraphState from either a flattened dictionary, where the keys of the
        dictionary are {node_name}.{param_name}, or a nested dictionary, where
        data[node_name][param_name] = value.

        Parameters
        ----------
        data : dict
            The dictionary mapping the parameter identifier (node name and parameter name)
            to their values.
        num_samples : int
            The number of samples.
            Default: 1

        Returns
        -------
        GraphState
            The corresponding graph state.

        Raises
        ------
        ValueError
            If an entry is neither a valid flattened key nor a nested dictionary.
        """
        state = cls(num_samples=num_samples)
        for key, value in data.items():
            if "." in key:
                # Handle the flattened dictionary by splitting the key.
                tokens = key.split(".")
                if len(tokens) != 2:
                    raise ValueError(
                        f"Invalid name for entry '{key}'. Entries should be of the form 'node_name.param_name'."
                    )
                state.set(tokens[0], tokens[1], value, force_copy=True, fixed=False)
            elif isinstance(value, dict):
                # Handle the nested dictionary by iterating over the inner dictionary's entries.
                for param_name, param_values in value.items():
                    state.set(key, param_name, param_values, force_copy=True, fixed=False)
            else:
                raise ValueError("Input dictionary must either be flattened or nested.")
        return state

    @classmethod
    def from_pyarrow_struct_array(cls, struct_array, num_samples=1):
        """Create a GraphState from a PyArrow StructArray with fields for each parameter
        and field names of the form '{node_name}.{param_name}'.

        Parameters
        ----------
        struct_array : pyarrow.StructArray
            The input StructArray.
        num_samples : int
            The number of samples.
            Default: 1

        Returns
        -------
        GraphState
            The corresponding graph state.

        Raises
        ------
        ValueError
            If a field name is not of the form 'node_name.param_name'.
        """
        state = cls(num_samples=num_samples)
        for col in struct_array.type:
            components = col.name.split(".")
            if len(components) != 2:
                raise ValueError(
                    f"Invalid name for entry '{col.name}'. Entries should be of the form "
                    f"'node_name.param_name'."
                )

            # If we only have a single value then store that value instead of the np array.
            if num_samples == 1:
                state.set(components[0], components[1], struct_array.field(col.name)[0].as_py())
            else:
                state.set(components[0], components[1], struct_array.field(col.name).to_numpy())
        return state

    @classmethod
    def from_list(cls, data):
        """Concatenate a list of GraphStates or single state dictionaries into a single GraphState.
        All the entries in the data must have the same set of parameters (keys).

        Parameters
        ----------
        data : list of GraphState or dict
            A list of the individual GraphState information to combine.

        Returns
        -------
        GraphState
            The corresponding graph state.

        Raises
        ------
        ValueError
            If the list is empty or the entries' parameter sets differ.
        TypeError
            If an entry is neither a GraphState nor a dict.
        """
        if len(data) == 0:
            raise ValueError("Cannot concatenate an empty list.")

        # Convert everything into GraphStates (if they are not already) and extract
        # the basic information.
        all_param_full_names = None
        graph_states = []
        total_samples = 0
        for current in data:
            if isinstance(current, dict):
                current = cls.from_dict(current)
            elif not isinstance(current, GraphState):
                raise TypeError(f"Concatenate takes either GraphState or dict. Got {type(current)}")

            # Check that this is either the first GraphState we have seen or has the same parameters
            # as the earlier GraphStates we have seen.
            current_full_names = set(current.get_all_params_names())
            if all_param_full_names is None:
                all_param_full_names = current_full_names
            elif all_param_full_names != current_full_names:
                raise ValueError(
                    f"The sets of parameters do not match. Expected {all_param_full_names}. "
                    f"Received {current_full_names}."
                )

            total_samples += current.num_samples
            graph_states.append(current)

        # Allocate space for the concatenated states and fill that result.
        result = cls(num_samples=total_samples)
        for full_name in all_param_full_names:
            node_name, param_name = full_name.split(".")
            if total_samples == 1:
                # Single-sample states store scalars, so keep the (copied) original value
                # rather than wrapping it in a length-1 array.
                values = copy.copy(graph_states[0][full_name])
            else:
                # Concatenate the values from each of the GraphStates.
                values = np.concatenate([np.atleast_1d(state[full_name]) for state in graph_states])
            result.set(node_name, param_name, values, force_copy=False, fixed=False)
        return result

    def get_all_params_names(self):
        """Get the full name of all the parameters.

        Returns
        -------
        names : list
            A list of all the parameter names, in the form {node_name}.{param_name}.
        """
        return [
            self.extended_param_name(node_name, param_name)
            for node_name, params in self.states.items()
            for param_name in params
        ]

    def get_node_state(self, node_name, sample_num=0):
        """Get a dictionary of all parameters local to the given node
        for a single sample state.

        Parameters
        ----------
        node_name : str
            The parent node whose variables to extract.
        sample_num : int
            The number of sample to extract.

        Returns
        -------
        values : dict
            A dictionary mapping the parameter name to its value. For single-sample
            states this is the GraphState's own internal dictionary (not a copy).

        Raises
        ------
        KeyError
            If the node is not present.
        ValueError
            If sample_num is outside [0, num_samples).
        """
        if node_name not in self.states:
            raise KeyError(f"Node name '{node_name}' not found in GraphState.")
        if sample_num < 0 or sample_num >= self.num_samples:
            raise ValueError(f"Invalid index {sample_num} in GraphState with {self.num_samples} entries.")

        if self.num_samples == 1:
            return self.states[node_name]
        return {var_name: val[sample_num] for var_name, val in self.states[node_name].items()}

    def _prepare_value(self, node_name, var_name, value, force_copy):
        """Convert a value to the storage form used for this GraphState's number of samples."""
        if self.num_samples == 1:
            # If this GraphState holds only a single sample, store the given value directly.
            return value
        if np.isscalar(value) or (isinstance(value, np.ndarray) and value.ndim == 0):
            # Expand scalars (including 0-d arrays, which have no len()) to one per sample.
            return np.full(self.num_samples, value)
        if len(value) != self.num_samples:
            raise ValueError(
                f"Incompatible number of samples when setting GraphState for node={node_name}, "
                f"variable={var_name}: {self.num_samples} vs {len(value)}."
            )
        # np.array always copies; it also works for sequences without a .copy() method (tuples).
        return np.array(value) if force_copy else np.asarray(value)

    def set(self, node_name, var_name, value, force_copy=False, fixed=False):
        """Set a (new) parameter's value(s) in the GraphState from a given constant value
        or an array of length num_samples (to set all the values at once).

        Parameters
        ----------
        node_name : str
            The parent node holding this variable.
        var_name : str
            The parameter's name.
        value : any
            The new value of the parameter.
        force_copy : bool
            Make a copy of data in an array. If set to False, this will link
            to the array, saving memory and computation time.
            Default: False
        fixed : bool
            Treat this parameter as fixed and do not change it during subsequent calls to set.
            It is recommended not to manually set this to True as it can cause difficult to debug
            issues. It is primarily intended for use in automatic differentiation.
            Default: False

        Raises
        ------
        ValueError
            If a name contains '.' or an array value has the wrong length.
        """
        # Check that the names do not use the separator value.
        if "." in node_name or "." in var_name:
            raise ValueError("GraphState names (node or variable) cannot contain the character '.'")

        # Check if this parameter is fixed. If so, skip the set. We do this instead of raising
        # an error, because fixed variables are used to preset values during automatic differentiation
        # or similar processes. We want to keep the preset values and fill in the rest of the graph.
        if var_name in self.fixed_vars.get(node_name, ()):
            return

        # Validate/convert before touching any metadata so a failed set leaves no partial entry.
        stored_value = self._prepare_value(node_name, var_name, value, force_copy)

        node_vars = self.states.setdefault(node_name, {})
        if var_name not in node_vars:
            self.num_parameters += 1
        node_vars[var_name] = stored_value

        # Mark the variable as fixed if needed. We create the entry for this node the
        # first time we need it.
        if fixed:
            self.fixed_vars.setdefault(node_name, set()).add(var_name)

    def update(self, inputs, force_copy=False, all_fixed=False):
        """Set multiple parameters' value in the GraphState from a GraphState or a
        (nested) dictionary of the same form.

        Note
        ----
        The number of samples in input must either match the number of samples in the
        current object or be 1.

        Parameters
        ----------
        inputs : GraphState or dict
            Values to copy.
        force_copy : bool
            Make a copy of data in an array. If set to False, this will link
            to the array, saving memory and computation time.
            Default: False
        all_fixed : bool
            Treat all the parameters in inputs as fixed.
            Default: False

        Raises
        ------
        ValueError
            If the input has an invalid number of samples.
        """
        if isinstance(inputs, GraphState):
            if self.num_samples != inputs.num_samples and inputs.num_samples != 1:
                raise ValueError(
                    f"GraphStates must have the same number of samples. "
                    f"Received {self.num_samples} and {inputs.num_samples}."
                )
            new_states = inputs.states
        else:
            new_states = inputs

        # Set the values one by one. The set function takes care of expanding
        # any values that are constants (e.g. float or int) to match the correct
        # number of samples.
        for node_name, node_vars in new_states.items():
            for var_name, value in node_vars.items():
                self.set(node_name, var_name, value, force_copy=force_copy, fixed=all_fixed)

    def to_table(self):
        """Flatten the graph state to an AstroPy Table with columns for each parameter.

        The column names are: {node_name}.{param_name}

        Returns
        -------
        values : astropy.table.Table
            The resulting Table (one row per sample).
        """
        values = Table()
        for node_name, node_params in self.states.items():
            for param_name, param_value in node_params.items():
                # Single-sample values are scalars; wrap them so each column has one row
                # (AstroPy refuses to add a 0-d column to an empty table).
                column = [param_value] if self.num_samples == 1 else param_value
                values[self.extended_param_name(node_name, param_name)] = np.array(column)
        return values

    def to_dict(self):
        """Flatten the graph state to a dictionary with columns for each parameter.

        The column names are: {node_name}.{param_name}

        Returns
        -------
        values : dict
            The resulting dictionary. Values are scalars for single-sample states and
            lists otherwise.
        """
        values = {}
        for node_name, node_params in self.states.items():
            for param_name, param_value in node_params.items():
                full_name = self.extended_param_name(node_name, param_name)
                values[full_name] = param_value if self.num_samples == 1 else list(param_value)
        return values

    def to_pyarrow_struct_array(self):
        """Flatten the graph state to a PyArrow StructArray with fields for each parameter.

        The column names are: {node_name}.{param_name}

        Returns
        -------
        values : pyarrow.StructArray
            The resulting StructArray.

        Raises
        ------
        ImportError
            If PyArrow is not installed.
        """
        try:
            import pyarrow as pa
        except ImportError as err:  # pragma: no cover
            raise ImportError(
                "PyArrow is required to convert the GraphState to a PyArrow StructArray. "
                "Please install it with 'pip install pyarrow'."
            ) from err

        names = []
        arrays = []
        for node_name, node_params in self.states.items():
            for param_name, param_value in node_params.items():
                names.append(self.extended_param_name(node_name, param_name))

                if self.num_samples == 1:
                    arrays.append(pa.array([param_value]))
                elif np.ndim(param_value) < 2:
                    arrays.append(pa.array(param_value))
                else:
                    # Multi-dimensional per-sample values become fixed-size lists of the
                    # flattened trailing dimensions.
                    inner_size = int(np.prod(param_value.shape[1:]))
                    flat_arrow = pa.array(np.reshape(param_value, (-1,)))
                    list_arrow = pa.FixedSizeListArray.from_arrays(
                        values=flat_arrow,
                        list_size=inner_size,
                    )
                    arrays.append(list_arrow)

        return pa.StructArray.from_arrays(arrays, names=names)

    def save_to_file(self, filename, overwrite=False):
        """Save the GraphState to an ECSV file.

        Parameters
        ----------
        filename : str
            The name of the file to save.
        overwrite : bool
            Whether to overwrite an existing file.
            Default: False
        """
        data_table = self.to_table()
        ascii.write(data_table, filename, format="ecsv", overwrite=overwrite)
class DependencyGraph:
    """A class to hold the dependencies between parameters in a model. Used for
    analysis, documentation, testing, and visualization of the model structure.

    The full parameter names are in the same form used by GraphState.

    Attributes
    ----------
    all_params : set
        A set of all (full) parameter names in the graph.
    all_nodes : set
        A set of all node names in the graph.
    incoming : dict
        A dictionary mapping each parameter to the set of parameters that it depends on
        (the incoming edges).
    outgoing : dict
        A dictionary mapping each parameter to the set of parameters that depend on it
        (the outgoing edges).
    num_constants : int
        The number of constant parameters in the graph.
    """

    def __init__(self):
        self.all_params = set()
        # These are read/written by add_parameter and add_constant, so they must
        # exist from construction.
        self.all_nodes = set()
        self.incoming = {}
        self.outgoing = {}
        self.num_constants = 0

    def __len__(self):
        """Return the number of parameters in the graph."""
        return len(self.all_params)

    def __contains__(self, full_param_name):
        """Return True if the (full) parameter name is in the graph."""
        return full_param_name in self.all_params

    def add_parameter(self, param_name, node_name=None):
        """Add a parameter to the dependency graph if it is not already present.

        Parameters
        ----------
        param_name : str
            The name of the parameter to add.
        node_name : str, optional
            The name of the node holding this parameter. If provided, the full parameter
            name will be in the same form used by GraphState for storage.
            Default: None
        """
        # If a node name is provided, record the node and create the expanded parameter name.
        if node_name is not None:
            self.all_nodes.add(node_name)
            param_name = GraphState.extended_param_name(node_name, param_name)

        # If we haven't seen the parameter before, add it with no edges.
        if param_name not in self.all_params:
            self.all_params.add(param_name)
            self.incoming[param_name] = set()
            self.outgoing[param_name] = set()

    def add_constant(self, value):
        """Add a constant parameter to the dependency graph.

        Parameters
        ----------
        value : any
            The value of the constant.

        Returns
        -------
        const_name : str
            The name of the constant parameter added to the graph, of the form
            "const_{index}={value}".
        """
        const_name = f"const_{self.num_constants}={value}"
        self.num_constants += 1
        self.add_parameter(const_name)
        return const_name

    def add_edge(self, from_param, to_param):
        """Add a directed edge to the dependency graph.

        Parameters
        ----------
        from_param : str
            The name of the parameter that the edge is coming from (the dependency).
        to_param : str
            The name of the parameter that the edge is going to (the dependent).

        Raises
        ------
        KeyError
            If either parameter has not been added to the graph.
        """
        if from_param not in self.all_params or to_param not in self.all_params:
            raise KeyError("Both parameters must be added to the graph before adding an edge.")
        self.incoming[to_param].add(from_param)
        self.outgoing[from_param].add(to_param)

    def build_subgraph(self, param_name, incoming=True, outgoing=True):
        """Get the DAG subgraph that contains this parameter. This can be:
        1) the parameters on which this parameter depends (incoming=True, outgoing=False),
        2) the parameters that depend on this parameter (incoming=False, outgoing=True), or
        3) all parameters in the same connected component as this parameter
           (incoming=True, outgoing=True).

        Parameters
        ----------
        param_name : str
            The name of the parameter to get the subgraph for.
        incoming : bool
            If True, follow incoming edges (dependencies) from nodes in the subgraph.
            Default: True
        outgoing : bool
            If True, follow outgoing edges (dependents) from nodes in the subgraph.
            Default: True

        Returns
        -------
        subgraph : DependencyGraph
            The resulting subgraph.

        Raises
        ------
        KeyError
            If the parameter is not in the graph.
        """
        if param_name not in self.all_params:
            raise KeyError(f"Parameter '{param_name}' not found in the graph.")

        # Traverse the graph (depth-first, via an explicit stack) collecting every parameter
        # reachable along the requested edge directions. The visit order does not matter.
        subgraph = DependencyGraph()
        to_visit = [param_name]
        while to_visit:
            current = to_visit.pop()
            if current in subgraph.all_params:
                continue
            subgraph.add_parameter(current)
            if incoming:
                to_visit.extend(self.incoming[current])
            if outgoing:
                to_visit.extend(self.outgoing[current])

        # Add all edges where both endpoints are in the subgraph.
        for param in subgraph.all_params:
            for dep in self.incoming[param]:
                if dep in subgraph.all_params:
                    subgraph.add_edge(dep, param)
        return subgraph

    def build_connected_components(self):
        """Get the DAG subgraphs that are the connected components of the graph.

        Returns
        -------
        components : list of DependencyGraph
            The resulting subgraphs.
        """
        components = []
        visited = set()
        for param in self.all_params:
            if param not in visited:
                component = self.build_subgraph(param, incoming=True, outgoing=True)
                components.append(component)
                visited.update(component.all_params)
        return components

    def _compute_depths_helper(self, param, depths):
        """Helper function to compute the depth of a parameter in the graph.

        Depths are memoized in ``depths``. Assumes the graph is acyclic; a cycle
        would recurse without terminating.

        Parameters
        ----------
        param : str
            The name of the parameter to compute the depth for.
        depths : dict
            A dictionary mapping parameter names to their depth (updated in place).

        Returns
        -------
        depth : int
            The depth of the parameter.
        """
        if param in depths:
            return depths[param]
        if len(self.incoming[param]) == 0:
            # This is a root node.
            depths[param] = 0
        else:
            # The depth is one more than the maximum depth of the dependencies.
            max_depth = max(self._compute_depths_helper(dep, depths) for dep in self.incoming[param])
            depths[param] = max_depth + 1
        return depths[param]

    def compute_depths(self):
        """Compute the depth of each parameter in the graph (0 for parameters with
        no dependencies, otherwise one more than the deepest dependency).

        Note
        ----
        This function is primarily used for visualization purposes.

        Returns
        -------
        depths : dict
            A dictionary mapping parameter names to their depth.
        """
        depths = {}
        for param in self.all_params:
            self._compute_depths_helper(param, depths)
        return depths

    def to_networkx(self):
        """Create a NetworkX graph from the dependency graph.

        Returns
        -------
        graph : networkx.DiGraph
            The resulting directed graph. Each node has a "layer" attribute holding
            its depth.

        Raises
        ------
        ImportError
            If NetworkX is not installed.
        """
        try:
            import networkx as nx
        except ImportError as err:  # pragma: no cover
            raise ImportError(
                "NetworkX is required to convert the dependency graph to a NetworkX graph. "
                "Please install it with 'pip install networkx'."
            ) from err

        # Create the graph and add each node with its depth as an attribute.
        graph = nx.DiGraph()
        for param, depth in self.compute_depths().items():
            graph.add_node(param, layer=depth)

        # Add the edges after the nodes so every endpoint already carries its layer attribute.
        for param in self.all_params:
            for dep in self.incoming[param]:
                graph.add_edge(dep, param)
        return graph

    def _make_readable_graph_labels(self):
        """Create a mapping from parameter names to their labels for visualization.

        Returns
        -------
        labels : dict
            A dictionary mapping parameter names to their labels.
        """
        labels = {}
        for param in self.all_params:
            curr_label = param

            # Remove the "const_NN=" prefix from constant parameters.
            if "=" in param and param.startswith("const_"):
                curr_label = curr_label.split("=")[1]

            # Remove function node names and "function_node_result" from labels.
            if ":" in param:
                curr_label = "FN:" + curr_label.split(":")[-1]
            if curr_label.endswith(".function_node_result"):
                curr_label = curr_label[: -len(".function_node_result")]

            labels[param] = curr_label
        return labels

    def draw(self, param_name=None):
        """Draw the connected components of the graph using NetworkX and Matplotlib.

        Parameters
        ----------
        param_name : str, optional
            If provided, only draw the connected component containing this parameter.
            Default: None

        Raises
        ------
        ImportError
            If NetworkX or Matplotlib is not installed.
        KeyError
            If param_name is given but not in the graph.
        """
        try:
            import matplotlib.pyplot as plt
            import networkx as nx
        except ImportError as err:  # pragma: no cover
            raise ImportError(
                "NetworkX and Matplotlib are required to draw the graph. "
                "Please install them with 'pip install networkx matplotlib'."
            ) from err

        # If a parameter name is provided, check that it exists in the graph.
        if param_name is not None and param_name not in self.all_params:
            raise KeyError(f"Parameter '{param_name}' not found in the graph.")

        components = self.build_connected_components()
        if param_name is not None:
            # Only draw the component that contains that parameter.
            components = [comp for comp in components if param_name in comp.all_params]

        num_components = len(components)
        if num_components == 0:
            # An empty graph has nothing to draw (and plt.subplots rejects zero rows).
            return

        # squeeze=False always yields a 2D array of axes, even for a single component.
        _, axes = plt.subplots(num_components, 1, figsize=(12, 6 * num_components), squeeze=False)
        for idx, component in enumerate(components):
            labels = component._make_readable_graph_labels()
            graph = component.to_networkx()

            # Seed a spring layout with a layered layout so depth is roughly preserved.
            layout = nx.multipartite_layout(graph, subset_key="layer")
            layout = nx.spring_layout(graph, pos=layout)
            nx.draw(
                graph,
                layout,
                arrows=True,
                labels=labels,
                edge_color="gray",
                with_labels=True,
                node_color="lightgray",
                node_size=2000,
                font_size=10,
                font_color="black",
                font_weight="bold",
                ax=axes[idx, 0],
            )
        plt.show()
def transpose_dict_of_list(input_dict, num_elem):
    """Turn a dictionary of equal-length iterables into a list of per-index dictionaries.

    Entry ``i`` of the result maps every key of ``input_dict`` to ``input_dict[key][i]``.

    Parameters
    ----------
    input_dict : dict
        A dictionary of iterables, each of which is length num_elem.
    num_elem : int
        The length of the iterables.

    Returns
    -------
    output_list : list
        A length num_elem list of dictionaries, each with the same keys mapping
        to a single value.

    Raises
    ------
    ValueError
        If num_elem is less than 1 or any of the iterables has a different length.
    """
    if num_elem < 1:
        raise ValueError(f"Trying to transpose a dictionary with {num_elem} elements")

    # Verify every entry's length (in key order) before assembling the output.
    for key, values in input_dict.items():
        if len(values) != num_elem:
            raise ValueError(f"Entry '{key}' has length {len(values)}. Expected {num_elem}.")

    return [{key: values[idx] for key, values in input_dict.items()} for idx in range(num_elem)]