Source code for dispel.processing.trace

"""Module to inspect definitions of processing steps."""
import warnings
from abc import ABC
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
)

import networkx as nx
from mypy_extensions import KwArg

from dispel.data.epochs import EpochDefinition
from dispel.data.raw import RawDataSetDefinition
from dispel.data.values import DefinitionIdType, ValueDefinition
from dispel.processing import ProcessingStepsType
from dispel.processing.core import CoreProcessingStepGroup, Parameter, ProcessingStep
from dispel.processing.data_set import MutateDataSetProcessingStepBase
from dispel.processing.epochs import (
    CreateLevelEpochStep,
    LevelEpochExtractStep,
    LevelEpochProcessingStepMixIn,
)
from dispel.processing.extract import (
    AggregateMeasures,
    ExtractStep,
    MeasureDefinitionMixin,
)
from dispel.processing.level import (
    DefaultLevelFilter,
    LevelFilter,
    LevelFilterProcessingStepMixin,
)
from dispel.processing.transform import TransformStep


class TraceType(Enum):
    """The kinds of traces an inspection can produce.

    Each concrete trace class in this module pins its ``trace_type`` class
    attribute to one of these members.
    """

    # Used by ``OriginTrace``.
    ORIGIN = "origin"
    # Used by ``StepTrace``.
    STEP_GENERIC = "step"
    # Used by ``StepGroupTrace``.
    STEP_GROUP = "step-group"
    # Used by ``DataSetTrace``.
    DATA_SET = "data_set"
    # Used by ``MeasureTrace``.
    MEASURE = "measure"
    # Used by ``EpochTrace``.
    EPOCH = "epoch"
class TraceRelation(Enum):
    """The kinds of relationships stored on graph edges between traces.

    The member is stored under the ``type`` attribute of each edge.
    """

    # Edge from a step group trace to a trace visited inside the group.
    GROUP = "group"
    # Edge from the previously visited trace to the next step trace.
    SEQUENCE = "sequence"
    # Edge from an input (data set, measure, epoch) to a step or its output.
    INPUT = "input"
    # Edge from a step trace to something it produces.
    OUTPUT = "output"
@dataclass
class Trace(ABC):
    """Base class for every trace recorded during inspection.

    Attributes
    ----------
    trace_type
        The kind of trace, one of :class:`TraceType`. It is not an ``__init__``
        argument; derived trace classes provide it as a class attribute.
    """

    # Excluded from ``__init__``: subclasses assign the value at class level.
    trace_type: TraceType = field(init=False)
# Type variable bound to ``Trace``; lets helpers such as ``get_traces`` express
# "an iterable of whatever trace subclass was asked for".
TraceTypeT = TypeVar("TraceTypeT", bound=Trace)
@dataclass(unsafe_hash=True)
class OriginTrace(Trace):
    """The root of an inspection.

    It is added to the graph at the start of inspection and is the node to
    which all other nodes have a path.
    """

    trace_type = TraceType.ORIGIN
@dataclass(unsafe_hash=True)
class StepTrace(Trace):
    """A trace wrapping a generic processing step.

    Attributes
    ----------
    step
        The processing step that was traced.
    """

    trace_type = TraceType.STEP_GENERIC

    step: ProcessingStep
@dataclass(unsafe_hash=True)
class StepGroupTrace(StepTrace):
    """A trace wrapping a processing step group.

    It has the same fields as :class:`StepTrace`; only ``trace_type`` differs.
    """

    trace_type = TraceType.STEP_GROUP
@dataclass
class _LevelTraceBase(Trace):
    """Common base for traces of elements that belong to a level.

    Attributes
    ----------
    level
        The level filter relevant to the traced element.
    """

    level: LevelFilter
@dataclass
class DataSetTrace(_LevelTraceBase):
    """A trace representing a data set.

    Attributes
    ----------
    data_set_id
        The id of the traced data set.
    definition
        The raw data set definition, if known.
    """

    trace_type = TraceType.DATA_SET

    data_set_id: str
    definition: Optional[RawDataSetDefinition] = None

    def __hash__(self):
        """Hash on the level's ``repr`` and the data set id only."""
        # definition is excluded from the hash as it will not be set by
        # subsequent transform and extract steps.
        identity = (repr(self.level), self.data_set_id)
        return hash(identity)
@dataclass
class MeasureTrace(_LevelTraceBase):
    """A trace representing a measure.

    Attributes
    ----------
    step
        The step providing the measure definition.
    measure
        The definition of the measure.
    """

    trace_type = TraceType.MEASURE

    # FIXME: remove step -> inconsistent with DataTrace and EpochTrace, step is in edge!
    step: MeasureDefinitionMixin
    measure: ValueDefinition

    def __hash__(self):
        """Hash on the level's ``repr`` and the measure id only."""
        identity = (repr(self.level), self.measure.id)
        return hash(identity)
@dataclass
class EpochTrace(_LevelTraceBase):
    """A trace representing an epoch.

    Attributes
    ----------
    epoch
        The definition of the epoch.
    """

    trace_type = TraceType.EPOCH

    epoch: EpochDefinition

    def __hash__(self):
        """Hash on the level's ``repr`` and the epoch id only."""
        identity = (repr(self.level), self.epoch.id)
        return hash(identity)
# Type variable for helpers that are generic over level-bound trace classes.
LevelTraceTypeT = TypeVar("LevelTraceTypeT", bound=_LevelTraceBase)


def _add_trace_node(graph: nx.MultiDiGraph, trace: Trace):
    """Add ``trace`` as a node, exposing its fields as node attributes.

    The node attributes are ``trace_type`` plus everything found in the
    instance ``__dict__`` of the trace.
    """
    node_attributes = dict(trace.__dict__)
    graph.add_node(trace, trace_type=trace.trace_type, **node_attributes)


def _visit_processing_step_group(
    graph: nx.MultiDiGraph, group: CoreProcessingStepGroup, **kwargs
) -> Tuple[StepGroupTrace, StepTrace]:
    """Trace a step group and each of the steps it contains.

    Parameters
    ----------
    graph
        The graph to which nodes and edges are added.
    group
        The processing step group to be visited.
    kwargs
        Processing arguments; the group's own ``get_kwargs()`` are layered on
        top before visiting the member steps.

    Returns
    -------
    Tuple[StepGroupTrace, StepTrace]
        The trace of the group and the last trace visited within the group
        (the group trace itself if the group has no steps).
    """
    # Work on a copy so the caller's keyword arguments stay untouched.
    member_kwargs = kwargs.copy()
    member_kwargs.update(group.get_kwargs())

    group_trace = StepGroupTrace(group)
    _add_trace_node(graph, group_trace)

    last_trace: StepTrace = group_trace
    for member in group.steps:
        member_trace = _visit_processing_step(
            graph, last_trace, member, **member_kwargs
        )
        last_trace = member_trace

        # NOTE(review): for a nested group ``_visit_processing_step`` returns
        # the last trace inside that group, so this edge points there rather
        # than at the nested group's own trace -- confirm this is intended.
        graph.add_edge(
            group_trace,
            member_trace,
            type=TraceRelation.GROUP,
            group=group,
        )

    return group_trace, last_trace
def get_traces(
    graph: nx.MultiDiGraph, trace_type: Type[TraceTypeT]
) -> Iterable[TraceTypeT]:
    """Lazily select the nodes of a graph that are of a given trace class.

    Parameters
    ----------
    graph
        The graph whose nodes (traces) are examined.
    trace_type
        The class a node has to be an instance of to be returned.

    Returns
    -------
    Iterable[TraceTypeT]
        A lazy iterable over the nodes of ``graph`` that are instances of
        ``trace_type``.
    """
    return (node for node in graph.nodes if isinstance(node, trace_type))
def _find_or_add_trace(graph: nx.MultiDiGraph, trace: _LevelTraceBase):
    """Return a matching data set trace from the graph or add ``trace``.

    Only :class:`DataSetTrace` instances are looked up; a candidate matches
    when its ``data_set_id`` equals that of ``trace`` (the level is not
    compared). Any other trace, or a data set trace without a match, is added
    to the graph as a new node and returned.
    """
    if isinstance(trace, DataSetTrace):
        for candidate in get_traces(graph, DataSetTrace):
            if candidate.data_set_id == trace.data_set_id:
                return candidate

    _add_trace_node(graph, trace)
    return trace


def _get_level_filter(step: ProcessingStep) -> LevelFilter:
    """Get the level filter of a step, or a default filter if it has none."""
    if not isinstance(step, LevelFilterProcessingStepMixin):
        return DefaultLevelFilter()
    return step.get_level_filter()


def _find_or_add_input_data_set_traces(
    graph: nx.MultiDiGraph, step: MutateDataSetProcessingStepBase
) -> Sequence[DataSetTrace]:
    """Resolve one data set trace per id in ``step.get_data_set_ids()``.

    Traces that are not yet in the graph are added with the step's level
    filter.
    """
    level_filter = _get_level_filter(step)
    trace_inputs = [
        _find_or_add_trace(graph, DataSetTrace(level_filter, ds_id))
        for ds_id in step.get_data_set_ids()
    ]
    return trace_inputs


def _add_transform_function_output_traces(
    graph: nx.MultiDiGraph,
    step: MutateDataSetProcessingStepBase,
    output_callback: Callable[[KwArg(Any)], LevelTraceTypeT],
    step_trace: StepTrace,
    input_traces: Sequence[_LevelTraceBase],
    **kwargs,
) -> Sequence[LevelTraceTypeT]:
    """Create an output trace for each transform function of a step.

    Parameters
    ----------
    graph
        The graph to which nodes and edges are added.
    step
        The step whose ``get_transform_functions()`` are iterated.
    output_callback
        Called with the merged keyword arguments to build the output trace.
    step_trace
        The trace of ``step``; receives an OUTPUT edge to every output trace.
    input_traces
        Traces that receive an INPUT edge (annotated with ``step``) to every
        output trace.
    kwargs
        Base keyword arguments; the per-function keyword arguments are
        layered on top of a copy.

    Returns
    -------
    Sequence[LevelTraceTypeT]
        The output traces in the order of the transform functions.
    """
    output_traces = []
    for _, func_kwargs in step.get_transform_functions():
        # Copy first so the per-function kwargs do not leak between iterations.
        (updated_kwargs := kwargs.copy()).update(func_kwargs)
        output_trace = _find_or_add_trace(graph, output_callback(**updated_kwargs))
        output_traces.append(output_trace)

        graph.add_edge(step_trace, output_trace, type=TraceRelation.OUTPUT)

        for input_trace in input_traces:
            graph.add_edge(
                input_trace,
                output_trace,
                type=TraceRelation.INPUT,
                step=step,
            )

    return output_traces


def _visit_mutate_data_set_step(
    graph: nx.MultiDiGraph,
    step: MutateDataSetProcessingStepBase,
    output_callback: Callable[[KwArg(Any)], LevelTraceTypeT],
    **kwargs,
) -> Tuple[StepTrace, Sequence[_LevelTraceBase], Sequence[LevelTraceTypeT]]:
    """Trace a data-set-mutating step with its inputs and outputs.

    Adds the step trace, connects the input data set traces to it with INPUT
    edges and delegates the creation of output traces to
    ``_add_transform_function_output_traces``.

    Returns
    -------
    Tuple[StepTrace, Sequence[_LevelTraceBase], Sequence[LevelTraceTypeT]]
        The step trace, its input traces and its output traces.
    """
    _add_trace_node(graph, step_trace := StepTrace(step))

    input_traces = _find_or_add_input_data_set_traces(graph, step)
    for input_trace in input_traces:
        graph.add_edge(input_trace, step_trace, type=TraceRelation.INPUT)

    output_traces = _add_transform_function_output_traces(
        graph, step, output_callback, step_trace, input_traces, **kwargs
    )

    return step_trace, input_traces, output_traces


def _visit_transform_step(
    graph: nx.MultiDiGraph, step: TransformStep, **kwargs
) -> StepTrace:
    """Trace a transform step; its outputs are data set traces."""

    def _callback(**_kwargs) -> DataSetTrace:
        # The keyword arguments are not needed to describe the new data set.
        return DataSetTrace(
            _get_level_filter(step),
            step.get_new_data_set_id(),
            step.get_raw_data_set_definition(),
        )

    step_trace, *_ = _visit_mutate_data_set_step(graph, step, _callback, **kwargs)

    return step_trace


def _visit_extract_step(
    graph: nx.MultiDiGraph, step: ExtractStep, **kwargs
) -> StepTrace:
    """Trace an extract step; its outputs are measure traces."""

    def _callback(**_kwargs) -> MeasureTrace:
        return MeasureTrace(
            _get_level_filter(step),
            step,
            step.get_definition(**_kwargs),
        )

    step_trace, *_ = _visit_mutate_data_set_step(graph, step, _callback, **kwargs)

    return step_trace


def _create_measure_extract_traces(
    graph: nx.MultiDiGraph, step: MeasureDefinitionMixin, **kwargs
):
    """Add a step trace and its single measure trace, linked by an OUTPUT edge.

    Returns the tuple ``(step_trace, output_trace)``.
    """
    assert isinstance(step, ProcessingStep), "step must inherit from ProcessingStep"

    _add_trace_node(graph, step_trace := StepTrace(step))
    _add_trace_node(
        graph,
        output_trace := MeasureTrace(
            _get_level_filter(step),
            step,
            step.get_definition(**kwargs),
        ),
    )
    graph.add_edge(step_trace, output_trace, type=TraceRelation.OUTPUT)

    return step_trace, output_trace


def _visit_aggregate_measures_step(
    graph: nx.MultiDiGraph, step: AggregateMeasures, **kwargs
) -> StepTrace:
    """Trace a measure aggregation step.

    Every measure id from ``step.get_measure_ids()`` that already has a
    measure trace in the graph is connected via INPUT edges to both the step
    trace and the output measure trace. Ids without a trace trigger a
    ``UserWarning`` and are skipped.
    """
    step_trace, output_trace = _create_measure_extract_traces(graph, step, **kwargs)
    # Index of all measure traces currently in the graph by measure id; this
    # is built after the step's own output trace has been added.
    measure_trace_dict: Dict[DefinitionIdType, MeasureTrace] = {
        t.measure.id: t for t in get_traces(graph, MeasureTrace)
    }

    for measure_id in step.get_measure_ids(**kwargs):
        if measure_id not in measure_trace_dict:
            warnings.warn(
                f"{measure_id} not observed in inspection. Will not create "
                f"edges to step and output measure.",
                UserWarning,
            )
            continue

        graph.add_edge(
            measure_trace_dict[measure_id], step_trace, type=TraceRelation.INPUT
        )
        graph.add_edge(
            measure_trace_dict[measure_id],
            output_trace,
            type=TraceRelation.INPUT,
            step=step,
        )

    return step_trace


def _visit_measure_definition_mixin_step(
    graph: nx.MultiDiGraph, step: MeasureDefinitionMixin, **kwargs
) -> StepTrace:
    """Trace a step that only provides a measure definition (no inputs)."""
    step_trace, _ = _create_measure_extract_traces(graph, step, **kwargs)
    return step_trace


def _visit_processing_step_generic(
    graph: nx.MultiDiGraph, step: ProcessingStep, **_kwargs
) -> StepTrace:
    """Trace a step of unknown kind as a bare step node."""
    _add_trace_node(graph, step_trace := StepTrace(step))
    return step_trace


def _find_input_epoch_traces(
    graph: nx.MultiDiGraph, step: LevelEpochProcessingStepMixIn
) -> Sequence[EpochTrace]:
    """Collect the epoch traces in the graph that a step consumes.

    An epoch trace qualifies when its level equals the step's level filter
    and the step's epoch filter returns a truthy result for its epoch.
    """
    epoch_filter = step.get_epoch_filter()
    # assuming mixin was done with step that has level-filtering capabilities.
    assert isinstance(step, LevelFilterProcessingStepMixin)
    level_filter = step.get_level_filter()

    epoch_traces = []
    epoch_traces_candidates = get_traces(graph, EpochTrace)
    for trace in epoch_traces_candidates:
        if trace.level != level_filter or not epoch_filter([trace.epoch]):
            continue
        epoch_traces.append(trace)

    return epoch_traces


def _visit_level_epoch_extract_step(
    graph: nx.MultiDiGraph, step: LevelEpochExtractStep, **kwargs
) -> StepTrace:
    """Trace an epoch-based extract step.

    In addition to the data set inputs, every matching epoch trace is
    connected via INPUT edges to the step trace and to each output measure
    trace.
    """

    def _callback(**_kwargs):
        return MeasureTrace(
            _get_level_filter(step),
            step,
            step.get_definition(**_kwargs),
        )

    step_trace, _, output_traces = _visit_mutate_data_set_step(
        graph, step, _callback, **kwargs
    )

    for epoch_trace in _find_input_epoch_traces(graph, step):
        graph.add_edge(epoch_trace, step_trace, type=TraceRelation.INPUT)

        for output_trace in output_traces:
            graph.add_edge(
                epoch_trace,
                output_trace,
                step=step,
                type=TraceRelation.INPUT,
            )

    return step_trace


def _visit_create_level_epochs_step(
    graph: nx.MultiDiGraph, step: CreateLevelEpochStep, **kwargs
) -> StepTrace:
    """Trace a step that creates level epochs.

    The outputs are epoch traces. If the step has an ``epoch_data_set_id``, a
    data set trace for it is added as a further output. If the step also
    consumes epochs, the matching epoch traces are wired as inputs to the
    step, its output epochs and the epoch data set.
    """
    level_filter = _get_level_filter(step)

    def _callback(**_kwargs):
        return EpochTrace(
            level_filter,
            step.get_definition(**_kwargs),
        )

    step_trace, input_traces, output_traces = _visit_mutate_data_set_step(
        graph, step, _callback, **kwargs
    )

    # Add epoch data set if provided
    epoch_data_set_trace: Optional[DataSetTrace] = None
    if step.epoch_data_set_id:
        epoch_data_set_trace = _find_or_add_trace(
            graph, DataSetTrace(level_filter, step.epoch_data_set_id)
        )
        graph.add_edge(step_trace, epoch_data_set_trace, type=TraceRelation.OUTPUT)
        for input_trace in input_traces:
            graph.add_edge(
                input_trace,
                epoch_data_set_trace,
                step=step,
                type=TraceRelation.INPUT,
            )

    if isinstance(step, LevelEpochProcessingStepMixIn):
        for epoch_trace in _find_input_epoch_traces(graph, step):
            graph.add_edge(epoch_trace, step_trace, type=TraceRelation.INPUT)

            for output_trace in output_traces:
                graph.add_edge(
                    epoch_trace,
                    output_trace,
                    step=step,
                    type=TraceRelation.INPUT,
                )

            if epoch_data_set_trace:
                graph.add_edge(
                    epoch_trace,
                    epoch_data_set_trace,
                    step=step,
                    type=TraceRelation.INPUT,
                )

    return step_trace


def _visit_processing_step(
    graph: nx.MultiDiGraph, previous: Trace, step: ProcessingStep, **kwargs
) -> StepTrace:
    """Dispatch a step to its specialised visitor and link it in sequence.

    Parameters
    ----------
    graph
        The graph to which nodes and edges are added.
    previous
        The trace visited before ``step``; a SEQUENCE edge is added from it to
        the trace of ``step``.
    step
        The processing step to be visited.
    kwargs
        Processing arguments forwarded to the specialised visitor.

    Returns
    -------
    StepTrace
        The trace of ``step`` or, for step groups, the last trace visited
        inside the group.
    """
    # NOTE(review): the order of the isinstance checks is presumably
    # significant (more specific step classes before more general ones) --
    # verify against the class hierarchy before reordering.
    trace: StepTrace
    trace_ls: Optional[StepTrace] = None
    # Processing groups
    if isinstance(step, CoreProcessingStepGroup):
        trace, trace_ls = _visit_processing_step_group(graph, step, **kwargs)

    # Epoch specific transformation & extraction
    elif isinstance(step, LevelEpochExtractStep):
        trace = _visit_level_epoch_extract_step(graph, step, **kwargs)
    elif isinstance(step, CreateLevelEpochStep):
        trace = _visit_create_level_epochs_step(graph, step, **kwargs)

    # Extracts
    elif isinstance(step, ExtractStep):
        trace = _visit_extract_step(graph, step, **kwargs)
    elif isinstance(step, AggregateMeasures):
        trace = _visit_aggregate_measures_step(graph, step, **kwargs)
    elif isinstance(step, MeasureDefinitionMixin):
        trace = _visit_measure_definition_mixin_step(graph, step, **kwargs)

    # Transformations
    elif isinstance(step, TransformStep):
        trace = _visit_transform_step(graph, step, **kwargs)

    # Generics
    else:
        trace = _visit_processing_step_generic(graph, step)

    graph.add_edge(previous, trace, type=TraceRelation.SEQUENCE)

    # return the last step in the group if visited step is a group
    if isinstance(step, CoreProcessingStepGroup):
        assert trace_ls
        return trace_ls

    return trace
def inspect(steps: ProcessingStepsType, **kwargs) -> nx.MultiDiGraph:
    """Build a graph of the relationships defined by processing steps.

    Parameters
    ----------
    steps
        The processing steps to be inspected. A class is instantiated without
        arguments and a single step is wrapped into a list.
    kwargs
        Any additional processing arguments typically passed to the
        processing steps.

    Returns
    -------
    networkx.MultiDiGraph
        A graph representing the processing elements found in `steps`.
    """
    graph = nx.MultiDiGraph()

    # try to instantiate steps if class-referenced
    if isinstance(steps, type):
        steps = steps()

    if isinstance(steps, ProcessingStep):
        steps = [steps]

    # Every inspection starts from a single origin node.
    origin = OriginTrace()
    _add_trace_node(graph, origin)

    last_trace: Trace = origin
    for current_step in steps:
        last_trace = _visit_processing_step(
            graph, last_trace, current_step, **kwargs
        )

    return graph
def get_ancestors(
    graph: nx.MultiDiGraph,
    trace: Trace,
    rel_filter: Optional[Callable[[nx.MultiDiGraph, Trace, Trace], bool]] = None,
) -> Set[Trace]:
    """Get all ancestors of a trace.

    Parameters
    ----------
    graph
        The graph containing the relationship between the traces derived from
        :func:`inspect`.
    trace
        The trace for which to find all ancestors, i.e., all predecessors of
        all predecessors until the origin.
    rel_filter
        If provided, predecessors will only be considered if the return value
        is `True`. The function needs to accept the graph, the current trace,
        and the potential predecessor to be considered as ancestor.

    Returns
    -------
    Set[Trace]
        A set of traces of all predecessors to `trace`.
    """
    ancestors: Set[Trace] = set()

    # Walk the graph with an explicit stack rather than recursion: sequence
    # edges chain every inspected step, so the depth grows with the number of
    # steps and would otherwise be bounded by the interpreter recursion limit.
    pending = [trace]
    while pending:
        current = pending.pop()
        for predecessor in graph.predecessors(current):
            if predecessor in ancestors:
                continue
            if rel_filter and not rel_filter(graph, current, predecessor):
                continue
            ancestors.add(predecessor)
            pending.append(predecessor)

    return ancestors
def _rel_filter_sources(
    _graph: nx.MultiDiGraph, _trace: Trace, predecessor: Trace
) -> bool:
    """Accept only predecessors that are data set or measure traces."""
    if isinstance(predecessor, DataSetTrace):
        return True
    return isinstance(predecessor, MeasureTrace)
def get_ancestor_source_graph(
    graph: nx.MultiDiGraph, trace: MeasureTrace
) -> nx.MultiDiGraph:
    """Get a subgraph of all sources leading to the extraction of a measure.

    Parameters
    ----------
    graph
        The graph containing all traces.
    trace
        A trace of a measure for which to return the source graph.

    Returns
    -------
    networkx.MultiDiGraph
        A subgraph of `graph` comprised of nodes and edges only from data sets
        and measures leading to the extraction of the provided measure through
        `trace`.
    """
    # Ancestors restricted to data set and measure traces, plus the measure
    # trace itself.
    source_nodes = get_ancestors(graph, trace, _rel_filter_sources)
    source_nodes = source_nodes | {trace}
    return nx.subgraph(graph, source_nodes)
def get_edge_parameters(graph: nx.MultiDiGraph) -> Dict[ProcessingStep, Set[Parameter]]:
    """Get parameters defined in steps attributes of edges.

    Parameters
    ----------
    graph
        The graph whose edges are scanned for a ``step`` attribute.

    Returns
    -------
    Dict[ProcessingStep, Set[Parameter]]
        A mapping from each step found on an edge to its parameters. Steps
        whose ``get_parameters()`` result is empty are omitted.
    """
    parameters = defaultdict(set)
    for *_, step in graph.edges.data("step"):
        # Edges without a ``step`` attribute (e.g. sequence or group edges)
        # are reported as ``None`` and carry no parameters.
        if step is None or step in parameters:
            continue
        if step_param := step.get_parameters():
            parameters[step].update(step_param)

    return parameters
def collect_measure_value_definitions(
    steps: Iterable[ProcessingStep], **kwargs
) -> Iterator[Tuple[MeasureDefinitionMixin, ValueDefinition]]:
    """Collect all measure value definitions from a list of processing steps.

    Parameters
    ----------
    steps
        The steps from which to collect the measure value definitions.
    kwargs
        See :func:`inspect`.

    Yields
    ------
    Tuple[ProcessingStep, ValueDefinition]
        The processing step in question along with the value definition.
    """
    inspection_graph = inspect(steps, **kwargs)
    measure_traces = get_traces(inspection_graph, MeasureTrace)

    # FIXME: infer step from input nodes
    yield from ((mt.step, mt.measure) for mt in measure_traces)