"""A module containing the data trace graph (DTG).
The data trace is a directed acyclic graph that traces the creation of the main data
entities:
- :class:`~dispel.data.core.Reading`,
- :class:`~dispel.data.core.Level`,
- :class:`~dispel.data.raw.RawDataSet`,
- :class:`~dispel.data.measures.MeasureValue`.
It links the creation of each entity with its creators.
"""
import warnings
from dataclasses import field
from functools import singledispatchmethod
from itertools import chain
from typing import Iterable, List, Optional, Union, cast
import networkx as nx
from networkx.classes.coreviews import MultiAdjacencyView
from networkx.classes.reportviews import NodeView
from dispel.data.core import EntityType, Reading
from dispel.data.flags import Flag, FlagMixIn, verify_flag_uniqueness
from dispel.data.levels import Level, LevelEpoch
from dispel.data.measures import MeasureValue
from dispel.data.raw import RawDataSet
from dispel.processing.core import (
FlagError,
ProcessingControlResult,
ProcessingResult,
ProcessingStep,
)
from dispel.processing.epochs import CreateLevelEpochStep
from dispel.processing.extract import ExtractStep
from dispel.processing.level import LevelProcessingControlResult
from dispel.processing.transform import ConcatenateLevels, TransformStep
from dispel.utils import multidispatchmethod, plural
class NodeNotFound(Exception):
    """Signal that an entity has no corresponding node in the data trace.

    Parameters
    ----------
    entity
        The entity for which no node could be found.
    """

    def __init__(self, entity: EntityType):
        message = (
            f"No node corresponding to {entity=} was found in the data trace."
        )
        super().__init__(message)
class EdgeNotFound(Exception):
    """Signal that no edge links two entities in the data trace.

    Parameters
    ----------
    source
        The entity of the edge's source node.
    result
        The entity of the edge's result node.
    """

    def __init__(self, source: EntityType, result: EntityType):
        node_pair = f"{source=} and {result=}"
        message = (
            f"No edge corresponding to the nodes {node_pair} in the data trace graph."
        )
        super().__init__(message)
class MissingSourceNode(Exception):
    """Signal that a trace was requested from a source absent from the graph.

    Parameters
    ----------
    entity
        The entity whose source node could not be found.
    """

    def __init__(self, entity: EntityType):
        description = f"No source node corresponding to {entity=}"
        super().__init__(f"{description} was found in the data trace graph.")
class NodeHasMultipleParents(Exception):
    """Signal that a node restricted to a single parent would get another one.

    Parameters
    ----------
    entity
        The entity of the result node that already exists in the data trace.
    """

    def __init__(self, entity: EntityType):
        subject = f"The result node corresponding to {entity=}"
        reason = (
            "already exists in the data trace and thus cannot have multiple "
            "parents."
        )
        super().__init__(f"{subject} {reason}")
class GraphNoCycle(Exception):
    """Signal that adding a trace would introduce a cycle in the data trace.

    Parameters
    ----------
    source
        The source entity of the edge that would close the cycle.
    result
        The result entity of the edge that would close the cycle.
    """

    def __init__(self, source: EntityType, result: EntityType):
        headline = "A data trace cycle has been detected."
        detail = f"The {source=} is already a predecessor of {result=}."
        super().__init__(f"{headline} {detail}")
class DataTrace:
    """A class representation of the data trace graph.

    The trace is stored as a :class:`networkx.MultiDiGraph` whose nodes are the
    traced entities. An edge ``source -> result`` records that ``result`` was
    created from ``source``. Edges created by a processing step carry that step
    under the ``step`` edge attribute and use ``hash(step)`` as edge key.
    """

    def __init__(self):
        self._graph = nx.MultiDiGraph()
        self._processing_step_count = 0
        # The root entity is assigned by :meth:`from_reading`. Do not use
        # ``dataclasses.field(init=False)`` here: outside of a dataclass it
        # merely stores a ``Field`` object in the attribute.
        self.root: Optional[EntityType] = None

    def __repr__(self):
        # Read the root directly rather than through ``get_reading`` so that
        # ``repr`` never raises, e.g. on a trace not built from a reading.
        return (
            f"<DataTrace of {self.root}: ("
            f'{plural("entity", self.entity_count, "entities")}, '
            f'{plural("processing step", self.processing_step_count)})>'
        )

    @singledispatchmethod
    def populate(
        self,
        entity: Union[
            Level,
            Reading,
            ProcessingResult,
            ProcessingControlResult,
            LevelProcessingControlResult,
        ],
    ) -> Optional[Exception]:
        """Populate the data trace with the provided entity.

        If the provided entity is a :class:`~dispel.data.core.Reading` or a
        :class:`~dispel.data.core.Level`, then the entity and its components are
        injected as data traces.

        If the provided entity is a processing result, then the associated result
        is added as data trace (one edge per source) and also set inside the
        corresponding reading.

        If the provided entity is a processing control result, flag errors are
        attached to their targets and the error is returned if its handling
        requires it to be raised.

        Parameters
        ----------
        entity
            The entity to be injected inside the data trace.

        Returns
        -------
        Optional[Exception]
            If existent, the exception to be raised at the end of the
            processing.

        Raises
        ------
        NotImplementedError
            If the provided entity type is not supported.
        """
        raise NotImplementedError(
            f"Unsupported type for data trace population: {type(entity)}"
        )

    @populate.register(Reading)
    def _reading(self, entity: Reading):
        """Populate data trace with newly created reading."""
        for measure in entity.measure_set.values():
            self.add_trace(entity, cast(MeasureValue, measure))
        for level in entity.levels:
            self.add_trace(entity, level)
            # Recurse so that the level's own components are traced as well.
            self.populate(level)

    @populate.register(Level)
    def _level(self, entity: Level):
        """Populate data trace with newly created level."""
        for raw_data_set in entity.raw_data_sets:
            self.add_trace(entity, raw_data_set)
        for measure in entity.measure_set.values():
            self.add_trace(entity, cast(MeasureValue, measure))

    @populate.register(ProcessingResult)
    def _processing_result(self, entity: ProcessingResult):
        """Populate data trace with processing result."""
        # Setting result entity in reading
        self.get_reading().set(entity.result, **entity.get_kwargs())

        # Setting data traces in data trace graph
        for source in entity.get_sources():
            self.add_trace(source, entity.result, step=entity.step)

        # Populate the data trace with the result if needed
        if isinstance(entity.result, Level):
            self.populate(entity.result)

        self._processing_step_count += 1

    @populate.register(ProcessingControlResult)
    def _processing_control_result(
        self, entity: ProcessingControlResult
    ) -> Optional[Exception]:
        """Handle the processing control results.

        Flag errors are attached to every target of the control result. The
        error itself is returned only if its error handling asks for raising.
        """
        if isinstance(entity.error, FlagError):
            for target in entity.get_targets():
                target.add_flag(entity.error.flag)

        if entity.error_handling.should_raise:
            return entity.error
        return None

    @classmethod
    def from_reading(cls, reading: Reading) -> "DataTrace":
        """Initialise a data trace graph from a reading.

        Parameters
        ----------
        reading
            The reading associated to the data trace.

        Returns
        -------
        DataTrace
            The initialised data trace graph.
        """
        data_trace = cls()
        data_trace.root = reading
        data_trace._graph.add_node(reading)
        data_trace.populate(reading)
        return data_trace

    def get_reading(self) -> Reading:
        """Get the reading associated with the data trace graph.

        Raises
        ------
        ValueError
            If the root entity is not a reading (e.g. the trace was not built
            with :meth:`from_reading`).
        """
        if isinstance(entity := self.root, Reading):
            return entity
        raise ValueError(f"The root node entity is not a reading but a {type(entity)}")

    def nodes(self) -> NodeView:
        """Retrieve all the nodes inside the data trace graph.

        Returns
        -------
        NodeView
            The list of nodes inside the data trace graph.
        """
        return self._graph.nodes

    def is_leaf(self, entity: EntityType) -> bool:
        """Determine whether the entity is a leaf (has no outgoing edges)."""
        return self._graph.out_degree(entity) == 0

    def leaves(self) -> Iterable[EntityType]:
        """Retrieve all leaves of the data trace graph.

        Returns
        -------
        Iterable[EntityType]
            The leaf nodes of the data trace graph.
        """
        return filter(self.is_leaf, self._graph.nodes)

    @property
    def entity_count(self) -> int:
        """Get the entity count in the data trace graph.

        Returns
        -------
        int
            The number of entities in the data trace graph.
        """
        return self._graph.number_of_nodes()

    @property
    def leaf_count(self) -> int:
        """Get the leaf count of the data trace graph.

        Returns
        -------
        int
            The number of leaf nodes in the data trace graph.
        """
        return sum(1 for _ in self.leaves())

    def has_node(self, entity: EntityType) -> bool:
        """Check whether a node exists in the data trace graph.

        Parameters
        ----------
        entity
            The entity of the node whose existence in the data trace graph is to
            be checked.

        Returns
        -------
        bool
            ``True`` if the corresponding node exists inside the graph. ``False``
            otherwise.
        """
        return self._graph.has_node(entity)

    def parents(self, entity: EntityType) -> MultiAdjacencyView:
        """Get direct predecessors of an entity."""
        return self._graph.pred[entity]

    def children(self, entity: EntityType) -> MultiAdjacencyView:
        """Get direct successors of an entity."""
        return self._graph.succ[entity]

    @multidispatchmethod
    def add_trace(self, source: EntityType, result: EntityType, step=None, **kwargs):
        """Add a single data trace to the graph.

        Parameters
        ----------
        source
            The source entity that led to the creation of the result entity.
        result
            The result entity created from the source entity.
        step
            The processing step that has led to the creation of the result if
            existent.
        kwargs
            Additional keywords arguments passed to downstream dispatchers.

        Raises
        ------
        NotImplementedError
            If the source and/or result types are not supported.
        """
        raise NotImplementedError("The two node types are not supported.")

    def _add_trace(
        self,
        source: EntityType,
        result: EntityType,
        step: Optional[ProcessingStep] = None,
        allow_multiple_parents: bool = True,
    ):
        """Add an edge from ``source`` to ``result`` to the data trace graph.

        Parameters
        ----------
        source
            The entity the result was created from. It must already be a node of
            the graph.
        result
            The created entity. It is added as a node if not present yet.
        step
            The processing step that created the result, if any. It is stored on
            the edge and its hash is used as edge key.
        allow_multiple_parents
            If ``False``, ``result`` must not already be present in the graph.

        Raises
        ------
        MissingSourceNode
            If ``source`` is not a node of the graph.
        NodeHasMultipleParents
            If ``result`` already exists and ``allow_multiple_parents`` is
            ``False``.
        GraphNoCycle
            If the edge would create a cycle. The graph is left unchanged.
        """
        if not self._graph.has_node(source):
            raise MissingSourceNode(source)

        if self._graph.has_node(result):
            if not allow_multiple_parents:
                raise NodeHasMultipleParents(result)
            # A new edge ``source -> result`` closes a cycle exactly when
            # ``result`` already reaches ``source`` (this includes the self-loop
            # case ``source is result``). Check *before* inserting so that the
            # graph stays acyclic when the error is raised.
            if nx.has_path(self._graph, result, source):
                raise GraphNoCycle(source, result)

        if step is not None:
            self._graph.add_edge(source, result, key=hash(step), step=step)
        else:
            self._graph.add_edge(source, result)

    @add_trace.register(Reading, Level)
    def _reading_level(self, source: Reading, result: Level):
        self._add_trace(source, result, allow_multiple_parents=False)

    @add_trace.register(Reading, MeasureValue)
    def _reading_measure(self, source: Reading, result: MeasureValue):
        self._add_trace(source, result, allow_multiple_parents=False)

    @add_trace.register(Level, RawDataSet)
    def _level_container(self, source: Level, result: RawDataSet):
        self._add_trace(source, result, allow_multiple_parents=False)

    @add_trace.register(Level, MeasureValue)
    def _level_measure(
        self, source: Level, result: MeasureValue, step: Optional[ProcessingStep] = None
    ):
        self._add_trace(source, result, step, allow_multiple_parents=False)

    @add_trace.register(Level, LevelEpoch)
    def _level_epoch(self, source: Level, result: LevelEpoch, step: ProcessingStep):
        self._add_trace(source, result, step, allow_multiple_parents=False)

    @add_trace.register(MeasureValue, Level)
    def _measure_level(self, source: MeasureValue, result: Level, step: ProcessingStep):
        self._add_trace(source, result, step)

    @add_trace.register(RawDataSet, RawDataSet)
    def _transform(self, source: RawDataSet, result: RawDataSet, step: TransformStep):
        self._add_trace(source, result, step)

    @add_trace.register(RawDataSet, MeasureValue)
    def _extract(self, source: RawDataSet, result: MeasureValue, step: ExtractStep):
        self._add_trace(source, result, step)

    @add_trace.register(RawDataSet, LevelEpoch)
    def _create_level_epoch(
        self, source: RawDataSet, result: LevelEpoch, step: CreateLevelEpochStep
    ):
        self._add_trace(source, result, step)

    @add_trace.register(LevelEpoch, MeasureValue)
    def _extract_level_epoch_measure(
        self, source: LevelEpoch, result: MeasureValue, step: ExtractStep
    ):
        self._add_trace(source, result, step)

    @add_trace.register(RawDataSet, Level)
    def _concatenate(self, source: RawDataSet, result: Level, step: ConcatenateLevels):
        self._add_trace(source, result, step)

    @add_trace.register(MeasureValue, MeasureValue)
    def _aggregate(
        self, source: MeasureValue, result: MeasureValue, step: ProcessingStep
    ):
        self._add_trace(source, result, step)

    @property
    def processing_step_count(self) -> int:
        """Get the processing steps count in the data trace graph.

        Returns
        -------
        int
            The number of processing results that have populated the data trace
            graph.
        """
        return self._processing_step_count

    def check_data_set_usage(self) -> List[RawDataSet]:
        r"""Check the usage of the raw data sets inside the data trace graph.

        It means checking if all leaf nodes are
        :class:`dispel.data.measures.MeasureValue` and raising warnings for every
        raw data set found as a leaf.

        Returns
        -------
        List[RawDataSet]
            The list of unused data sets.

        Warns
        -----
        UserWarning
            If leaf entities do not correspond to measure values.

        Raises
        ------
        ValueError
            If a leaf node is neither a measure value nor a raw data set.
        """
        unused_data_sets: List[RawDataSet] = []
        for node in self.leaves():
            if isinstance(node, RawDataSet):
                unused_data_sets.append(node)
                warnings.warn(
                    "No measure has been created from the following entity " f"{node=}",
                    UserWarning,
                )
            elif not isinstance(node, MeasureValue):
                raise ValueError(
                    "The created leaf nodes should either be raw data sets or measure "
                    f"values. Found {type(node)}"
                )
        return unused_data_sets

    def get_flags(self, entity: EntityType) -> List[Flag]:
        """Get all flags related to the entity.

        The flags of the entity itself and of all of its ancestors are collected.

        Parameters
        ----------
        entity
            The entity whose related flags are to be retrieved.

        Returns
        -------
        List[Flag]
            A list of all flags related to the entity.
        """
        flags: List[Flag] = []
        if isinstance(entity, FlagMixIn):
            flags += entity.get_flags()

        # collect all flags from its ancestors
        for node in nx.algorithms.ancestors(self._graph, entity):
            if isinstance(node, FlagMixIn):
                flags += node.get_flags()

        verify_flag_uniqueness(flags)
        return flags

    def get_measure_flags(self, entity: Optional[EntityType] = None) -> List[Flag]:
        """Get all measure flags related to the entity.

        Measure values among the entity's ancestors, the entity itself and its
        descendants contribute their flags. If no entity is provided, the
        default entity is the root of the data trace, so every measure flag
        reachable from the reading is returned.

        Parameters
        ----------
        entity
            An optional entity whose related measure flags are to be retrieved,
            if ``None`` is provided the default entity is the root of the data
            trace.

        Returns
        -------
        List[Flag]
            A list of all measure flags related to the entity. If no entity is
            provided, all measure flags are returned.
        """
        entity = self.root if entity is None else entity
        measure_flags: List[Flag] = []
        nodes = chain(
            nx.algorithms.ancestors(self._graph, entity),
            [entity],
            nx.algorithms.descendants(self._graph, entity),
        )

        for node in nodes:
            if isinstance(node, MeasureValue):
                measure_flags += node.get_flags()

        verify_flag_uniqueness(measure_flags)
        return measure_flags