"""Module to inspect definitions of processing steps."""
import warnings
from abc import ABC
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
)
import networkx as nx
from mypy_extensions import KwArg
from dispel.data.epochs import EpochDefinition
from dispel.data.raw import RawDataSetDefinition
from dispel.data.values import DefinitionIdType, ValueDefinition
from dispel.processing import ProcessingStepsType
from dispel.processing.core import CoreProcessingStepGroup, Parameter, ProcessingStep
from dispel.processing.data_set import MutateDataSetProcessingStepBase
from dispel.processing.epochs import (
    CreateLevelEpochStep,
    LevelEpochExtractStep,
    LevelEpochProcessingStepMixIn,
)
from dispel.processing.extract import (
    AggregateMeasures,
    ExtractStep,
    MeasureDefinitionMixin,
)
from dispel.processing.level import (
    DefaultLevelFilter,
    LevelFilter,
    LevelFilterProcessingStepMixin,
)
from dispel.processing.transform import TransformStep
[docs]
class TraceType(Enum):
    """An enum to indicate type of trace from inspection.

    Each concrete trace class in this module assigns one of these members to
    its ``trace_type`` class attribute.
    """

    # Assigned by ``OriginTrace``.
    ORIGIN = "origin"
    # Assigned by ``StepTrace``.
    STEP_GENERIC = "step"
    # Assigned by ``StepGroupTrace``.
    STEP_GROUP = "step-group"
    # Assigned by ``DataSetTrace``.
    DATA_SET = "data_set"
    # Assigned by ``MeasureTrace``.
    MEASURE = "measure"
    # Assigned by ``EpochTrace``.
    EPOCH = "epoch"
[docs]
class TraceRelation(Enum):
    """An enum to denote the relationship between nodes in edges.

    The members are stored in the ``type`` attribute of the edges created
    during inspection.
    """

    # Edge from a step group trace to a trace visited for one of its steps.
    GROUP = "group"
    # Edge from the previously visited trace to the trace visited next.
    SEQUENCE = "sequence"
    # Edge from an input trace (data set, measure, or epoch) to the consuming
    # step trace or to a trace produced by that step.
    INPUT = "input"
    # Edge from a step trace to a trace it produces.
    OUTPUT = "output"
[docs]
@dataclass
class Trace(ABC):
    """A trace of a processing element from the inspection.

    Attributes
    ----------
    trace_type
        The type of the trace. This is one of :class:`TraceType` and set in
        the derived trace classes.
    """

    # Excluded from ``__init__`` (``init=False``): the trace classes derived in
    # this module assign the value as a class attribute instead.
    trace_type: TraceType = field(init=False)
# Type variable bound to :class:`Trace`; lets :func:`get_traces` declare that
# it yields traces of the same class that was passed in as ``trace_type``.
TraceTypeT = TypeVar("TraceTypeT", bound=Trace)
[docs]
@dataclass(unsafe_hash=True)
class OriginTrace(Trace):
    """The origin of inspection.

    This trace is added to the graph at the start of inspection and serves as
    the first ``previous`` trace, i.e., it is the source of the sequence edge
    leading to the first inspected step.
    """

    # Plain class attribute (no annotation), so it is shared by all instances.
    trace_type = TraceType.ORIGIN
[docs]
@dataclass(unsafe_hash=True)
class StepTrace(Trace):
    """The trace of a generic processing step.

    Attributes
    ----------
    step
        The traced processing step.
    """

    # Plain class attribute (no annotation), so it is shared by all instances.
    trace_type = TraceType.STEP_GENERIC
    # The only ``__init__`` argument; with ``unsafe_hash=True`` it takes part
    # in the generated ``__hash__``, so the step has to be hashable for the
    # trace to be usable as a graph node.
    step: ProcessingStep
[docs]
@dataclass(unsafe_hash=True)
class StepGroupTrace(StepTrace):
    """The trace of a processing group step.

    The inherited ``step`` attribute holds the group itself; traces of this
    class are created when a ``CoreProcessingStepGroup`` is visited.
    """

    # Overrides the generic step type inherited from ``StepTrace``.
    trace_type = TraceType.STEP_GROUP
@dataclass
class _LevelTraceBase(Trace):
    """A base class for traces of elements belonging to a level.

    Attributes
    ----------
    level
        The level filter relevant to the traced element.
    """

    # First positional ``__init__`` argument of every level-bound trace; the
    # subclasses in this module hash it through its ``repr``.
    level: LevelFilter
[docs]
@dataclass
class DataSetTrace(_LevelTraceBase):
    """The trace of a data set.

    Attributes
    ----------
    data_set_id
        The id of the data set traced.
    definition
        The raw data set definition, if known. Defaults to ``None``.
    """

    trace_type = TraceType.DATA_SET
    data_set_id: str
    definition: Optional[RawDataSetDefinition] = None

    def __hash__(self):
        """Hash the trace by the representation of its level and its id."""
        # The definition is deliberately left out: traces created for data
        # sets consumed by subsequent transform and extract steps do not
        # carry it.
        identity = (repr(self.level), self.data_set_id)
        return hash(identity)
[docs]
@dataclass
class MeasureTrace(_LevelTraceBase):
    """The trace of a measure.

    Attributes
    ----------
    step
        The step providing the measure definition.
    measure
        The definition of the measure.
    """

    trace_type = TraceType.MEASURE
    # FIXME: remove step -> inconsistent with DataTrace and EpochTrace, step is in edge!
    step: MeasureDefinitionMixin
    measure: ValueDefinition

    def __hash__(self):
        """Hash the trace by the representation of its level and measure id."""
        identity = (repr(self.level), self.measure.id)
        return hash(identity)
[docs]
@dataclass
class EpochTrace(_LevelTraceBase):
    """The trace of an epoch.

    Attributes
    ----------
    epoch
        The definition of the epoch.
    """

    trace_type = TraceType.EPOCH
    epoch: EpochDefinition

    def __hash__(self):
        """Hash the trace by the representation of its level and epoch id."""
        identity = (repr(self.level), self.epoch.id)
        return hash(identity)
# Type variable bound to :class:`_LevelTraceBase`; used to type the output
# callbacks and the output trace sequences of the data set mutation visitors.
LevelTraceTypeT = TypeVar("LevelTraceTypeT", bound=_LevelTraceBase)
def _add_trace_node(graph: nx.MultiDiGraph, trace: Trace):
    """Add a trace as a node to the graph.

    The trace object itself is used as the node. Its ``trace_type`` and all of
    its instance attributes are stored as node attributes.

    Parameters
    ----------
    graph
        The graph to which the node is added.
    trace
        The trace to be added.
    """
    node_attributes = vars(trace)
    graph.add_node(trace, trace_type=trace.trace_type, **node_attributes)
def _visit_processing_step_group(
    graph: nx.MultiDiGraph, group: CoreProcessingStepGroup, **kwargs
) -> Tuple[StepGroupTrace, StepTrace]:
    """Visit a processing step group and all steps it contains.

    Parameters
    ----------
    graph
        The graph to which traces and edges are added.
    group
        The group to be visited.
    kwargs
        Processing arguments. They are passed on to the visited steps after
        being updated with the arguments returned by ``group.get_kwargs()``.

    Returns
    -------
    Tuple[StepGroupTrace, StepTrace]
        The trace of the group and the trace returned for the last visited
        step. For a group without steps the latter is the group trace itself.
    """
    step_kwargs = dict(kwargs)
    step_kwargs.update(group.get_kwargs())

    group_trace = StepGroupTrace(group)
    _add_trace_node(graph, group_trace)

    # Each step is chained to the trace returned for the step before it; the
    # group trace is the predecessor of the first step.
    last_trace: StepTrace = group_trace
    for member in group.steps:
        last_trace = _visit_processing_step(graph, last_trace, member, **step_kwargs)
        graph.add_edge(
            group_trace,
            last_trace,
            type=TraceRelation.GROUP,
            group=group,
        )

    return group_trace, last_trace
[docs]
def get_traces(
    graph: nx.MultiDiGraph, trace_type: Type[TraceTypeT]
) -> Iterable[TraceTypeT]:
    """Get all traces being an instance of a specific type.

    Parameters
    ----------
    graph
        The graph providing the nodes/traces.
    trace_type
        The class the returned traces have to be an instance of.

    Returns
    -------
    Iterable[TraceTypeT]
        A lazily evaluated iterable over the nodes of ``graph`` that are
        instances of ``trace_type``.
    """
    return (node for node in graph.nodes if isinstance(node, trace_type))
def _find_or_add_trace(graph: nx.MultiDiGraph, trace: _LevelTraceBase):
    """Return a matching data set trace from the graph or add the trace.

    Data set traces are matched against the data set traces already in the
    graph by ``data_set_id`` only. Any other kind of trace, or a data set
    trace without a match, is added to the graph and returned as is.

    Parameters
    ----------
    graph
        The graph to be searched and, if needed, extended.
    trace
        The trace to be found or added.

    Returns
    -------
    _LevelTraceBase
        The first data set trace in the graph with the same id, otherwise
        ``trace`` itself.
    """
    if isinstance(trace, DataSetTrace):
        known = next(
            (
                candidate
                for candidate in get_traces(graph, DataSetTrace)
                if candidate.data_set_id == trace.data_set_id
            ),
            None,
        )
        if known is not None:
            return known

    _add_trace_node(graph, trace)
    return trace
def _get_level_filter(step: ProcessingStep) -> LevelFilter:
    """Get the level filter of a step.

    Returns ``step.get_level_filter()`` for steps that are instances of
    ``LevelFilterProcessingStepMixin`` and a new ``DefaultLevelFilter``
    otherwise.
    """
    if isinstance(step, LevelFilterProcessingStepMixin):
        return step.get_level_filter()
    return DefaultLevelFilter()
def _find_or_add_input_data_set_traces(
    graph: nx.MultiDiGraph, step: MutateDataSetProcessingStepBase
) -> Sequence[DataSetTrace]:
    """Get the traces of all data sets a step takes as input.

    Parameters
    ----------
    graph
        The graph to be searched for, or extended with, data set traces.
    step
        The step providing the ids of its input data sets.

    Returns
    -------
    Sequence[DataSetTrace]
        One trace per id returned by ``step.get_data_set_ids()``, in order.
    """
    level_filter = _get_level_filter(step)

    input_traces = []
    for data_set_id in step.get_data_set_ids():
        candidate = DataSetTrace(level_filter, data_set_id)
        input_traces.append(_find_or_add_trace(graph, candidate))
    return input_traces
def _add_transform_function_output_traces(
    graph: nx.MultiDiGraph,
    step: MutateDataSetProcessingStepBase,
    output_callback: Callable[[KwArg(Any)], LevelTraceTypeT],
    step_trace: StepTrace,
    input_traces: Sequence[_LevelTraceBase],
    **kwargs,
) -> Sequence[LevelTraceTypeT]:
    """Add one output trace per transform function of a step.

    Parameters
    ----------
    graph
        The graph to which traces and edges are added.
    step
        The step providing the transform functions.
    output_callback
        A callable creating the output trace. It is called with ``kwargs``
        updated by the keyword arguments of the transform function.
    step_trace
        The trace of ``step``; it receives an output edge to each output
        trace.
    input_traces
        The input traces of ``step``; each one receives an input edge,
        annotated with ``step``, to each output trace.
    kwargs
        Processing arguments passed on to ``output_callback``.

    Returns
    -------
    Sequence[LevelTraceTypeT]
        The output traces in the order of the transform functions.
    """
    collected = []
    for _, function_kwargs in step.get_transform_functions():
        callback_kwargs = dict(kwargs)
        callback_kwargs.update(function_kwargs)

        produced = _find_or_add_trace(graph, output_callback(**callback_kwargs))
        collected.append(produced)

        graph.add_edge(step_trace, produced, type=TraceRelation.OUTPUT)
        for source in input_traces:
            graph.add_edge(
                source,
                produced,
                type=TraceRelation.INPUT,
                step=step,
            )
    return collected
def _visit_mutate_data_set_step(
    graph: nx.MultiDiGraph,
    step: MutateDataSetProcessingStepBase,
    output_callback: Callable[[KwArg(Any)], LevelTraceTypeT],
    **kwargs,
) -> Tuple[StepTrace, Sequence[_LevelTraceBase], Sequence[LevelTraceTypeT]]:
    """Visit a step that consumes data sets and produces outputs.

    Parameters
    ----------
    graph
        The graph to which traces and edges are added.
    step
        The step to be visited.
    output_callback
        A callable creating the output trace for a transform function.
    kwargs
        Processing arguments passed on to ``output_callback``.

    Returns
    -------
    Tuple[StepTrace, Sequence[_LevelTraceBase], Sequence[LevelTraceTypeT]]
        The trace of the step, the traces of its input data sets, and the
        traces of its outputs.
    """
    step_trace = StepTrace(step)
    _add_trace_node(graph, step_trace)

    input_traces = _find_or_add_input_data_set_traces(graph, step)
    for source in input_traces:
        graph.add_edge(source, step_trace, type=TraceRelation.INPUT)

    output_traces = _add_transform_function_output_traces(
        graph, step, output_callback, step_trace, input_traces, **kwargs
    )
    return step_trace, input_traces, output_traces
def _visit_transform_step(
    graph: nx.MultiDiGraph, step: TransformStep, **kwargs
) -> StepTrace:
    """Visit a transform step and trace the data set it creates.

    Parameters
    ----------
    graph
        The graph to which traces and edges are added.
    step
        The transform step to be visited.
    kwargs
        Processing arguments; they are not used to create the data set trace.

    Returns
    -------
    StepTrace
        The trace of the visited step.
    """

    def _new_data_set_trace(**_ignored) -> DataSetTrace:
        level_filter = _get_level_filter(step)
        data_set_id = step.get_new_data_set_id()
        definition = step.get_raw_data_set_definition()
        return DataSetTrace(level_filter, data_set_id, definition)

    visited = _visit_mutate_data_set_step(graph, step, _new_data_set_trace, **kwargs)
    return visited[0]
def _visit_extract_step(
    graph: nx.MultiDiGraph, step: ExtractStep, **kwargs
) -> StepTrace:
    """Visit an extract step and trace the measures it defines.

    Parameters
    ----------
    graph
        The graph to which traces and edges are added.
    step
        The extract step to be visited.
    kwargs
        Processing arguments used to obtain the measure definitions.

    Returns
    -------
    StepTrace
        The trace of the visited step.
    """

    def _new_measure_trace(**definition_kwargs) -> MeasureTrace:
        level_filter = _get_level_filter(step)
        definition = step.get_definition(**definition_kwargs)
        return MeasureTrace(level_filter, step, definition)

    visited = _visit_mutate_data_set_step(graph, step, _new_measure_trace, **kwargs)
    return visited[0]
def _create_measure_extract_traces(
    graph: nx.MultiDiGraph, step: MeasureDefinitionMixin, **kwargs
):
    """Add the traces of a measure-defining step and of its measure.

    Parameters
    ----------
    graph
        The graph to which traces and the output edge are added.
    step
        The step providing the measure definition. It has to be a
        ``ProcessingStep`` as well.
    kwargs
        Processing arguments passed to ``step.get_definition``.

    Returns
    -------
    Tuple[StepTrace, MeasureTrace]
        The trace of the step and the trace of the measure it defines.
    """
    assert isinstance(step, ProcessingStep), "step must inherit from ProcessingStep"

    step_trace = StepTrace(step)
    _add_trace_node(graph, step_trace)

    level_filter = _get_level_filter(step)
    definition = step.get_definition(**kwargs)
    output_trace = MeasureTrace(level_filter, step, definition)
    _add_trace_node(graph, output_trace)

    graph.add_edge(step_trace, output_trace, type=TraceRelation.OUTPUT)
    return step_trace, output_trace
def _visit_aggregate_measures_step(
    graph: nx.MultiDiGraph, step: AggregateMeasures, **kwargs
) -> StepTrace:
    """Visit a step aggregating measures into a new measure.

    The measures to be aggregated are looked up by id among the measure traces
    already in the graph. A warning is issued for ids without a trace and no
    edges are created for them.

    Parameters
    ----------
    graph
        The graph to which traces and edges are added.
    step
        The aggregation step to be visited.
    kwargs
        Processing arguments used to obtain the measure definition and the
        ids of the aggregated measures.

    Returns
    -------
    StepTrace
        The trace of the visited step.
    """
    step_trace, output_trace = _create_measure_extract_traces(graph, step, **kwargs)

    # Index the measure traces present in the graph by their measure id.
    known_measures: Dict[DefinitionIdType, MeasureTrace] = {}
    for measure_trace in get_traces(graph, MeasureTrace):
        known_measures[measure_trace.measure.id] = measure_trace

    for measure_id in step.get_measure_ids(**kwargs):
        if measure_id in known_measures:
            source = known_measures[measure_id]
            graph.add_edge(source, step_trace, type=TraceRelation.INPUT)
            graph.add_edge(
                source,
                output_trace,
                type=TraceRelation.INPUT,
                step=step,
            )
        else:
            warnings.warn(
                f"{measure_id} not observed in inspection. Will not create "
                f"edges to step and output measure.",
                UserWarning,
            )

    return step_trace
def _visit_measure_definition_mixin_step(
    graph: nx.MultiDiGraph, step: MeasureDefinitionMixin, **kwargs
) -> StepTrace:
    """Visit a step that defines a measure without traced inputs.

    Adds the traces of the step and of its measure and returns the former.
    """
    traces = _create_measure_extract_traces(graph, step, **kwargs)
    return traces[0]
def _visit_processing_step_generic(
    graph: nx.MultiDiGraph, step: ProcessingStep, **_kwargs
) -> StepTrace:
    """Visit a step for which only the step itself is traced.

    Any keyword arguments are accepted and ignored.
    """
    generic_trace = StepTrace(step)
    _add_trace_node(graph, generic_trace)
    return generic_trace
def _find_input_epoch_traces(
    graph: nx.MultiDiGraph, step: LevelEpochProcessingStepMixIn
) -> Sequence[EpochTrace]:
    """Find the traces of the epochs a step takes as input.

    Parameters
    ----------
    graph
        The graph providing the epoch traces.
    step
        The step providing the epoch filter and the level filter.

    Returns
    -------
    Sequence[EpochTrace]
        All epoch traces in the graph whose level is not different from the
        level filter of ``step`` and whose epoch passes its epoch filter.
    """
    epoch_filter = step.get_epoch_filter()
    # assuming mixin was done with step that has level-filtering capabilities.
    assert isinstance(step, LevelFilterProcessingStepMixin)
    level_filter = step.get_level_filter()

    def _is_input(candidate: EpochTrace) -> bool:
        if candidate.level != level_filter:
            return False
        # The filter is applied to a single-element list; a truthy result
        # means the epoch was accepted.
        return bool(epoch_filter([candidate.epoch]))

    return [
        candidate
        for candidate in get_traces(graph, EpochTrace)
        if _is_input(candidate)
    ]
def _visit_level_epoch_extract_step(
    graph: nx.MultiDiGraph, step: LevelEpochExtractStep, **kwargs
) -> StepTrace:
    """Visit an extract step working on level epochs.

    In addition to the data set inputs, every epoch trace accepted for the
    step is connected to the step trace and to each extracted measure.

    Parameters
    ----------
    graph
        The graph to which traces and edges are added.
    step
        The epoch extract step to be visited.
    kwargs
        Processing arguments used to obtain the measure definitions.

    Returns
    -------
    StepTrace
        The trace of the visited step.
    """

    def _new_measure_trace(**definition_kwargs):
        level_filter = _get_level_filter(step)
        definition = step.get_definition(**definition_kwargs)
        return MeasureTrace(level_filter, step, definition)

    step_trace, _, measure_traces = _visit_mutate_data_set_step(
        graph, step, _new_measure_trace, **kwargs
    )

    for epoch_trace in _find_input_epoch_traces(graph, step):
        graph.add_edge(epoch_trace, step_trace, type=TraceRelation.INPUT)
        for measure_trace in measure_traces:
            graph.add_edge(
                epoch_trace,
                measure_trace,
                step=step,
                type=TraceRelation.INPUT,
            )

    return step_trace
def _visit_create_level_epochs_step(
    graph: nx.MultiDiGraph, step: CreateLevelEpochStep, **kwargs
) -> StepTrace:
    """Visit a step creating level epochs.

    The created epochs are traced as outputs of the step. If the step has an
    ``epoch_data_set_id``, the respective data set is traced as an additional
    output. If the step is also a ``LevelEpochProcessingStepMixIn``, the epoch
    traces accepted for the step are connected as inputs to the step and to
    all of its outputs.

    Parameters
    ----------
    graph
        The graph to which traces and edges are added.
    step
        The epoch creation step to be visited.
    kwargs
        Processing arguments used to obtain the epoch definitions.

    Returns
    -------
    StepTrace
        The trace of the visited step.
    """
    level_filter = _get_level_filter(step)

    def _new_epoch_trace(**definition_kwargs):
        definition = step.get_definition(**definition_kwargs)
        return EpochTrace(level_filter, definition)

    step_trace, input_traces, output_traces = _visit_mutate_data_set_step(
        graph, step, _new_epoch_trace, **kwargs
    )

    # Add epoch data set if provided
    epoch_data_set_trace: Optional[DataSetTrace] = None
    if step.epoch_data_set_id:
        requested = DataSetTrace(level_filter, step.epoch_data_set_id)
        epoch_data_set_trace = _find_or_add_trace(graph, requested)
        graph.add_edge(step_trace, epoch_data_set_trace, type=TraceRelation.OUTPUT)
        for source in input_traces:
            graph.add_edge(
                source,
                epoch_data_set_trace,
                step=step,
                type=TraceRelation.INPUT,
            )

    # Only steps mixing in epoch processing have epochs as inputs.
    if not isinstance(step, LevelEpochProcessingStepMixIn):
        return step_trace

    for epoch_trace in _find_input_epoch_traces(graph, step):
        graph.add_edge(epoch_trace, step_trace, type=TraceRelation.INPUT)
        for produced in output_traces:
            graph.add_edge(
                epoch_trace,
                produced,
                step=step,
                type=TraceRelation.INPUT,
            )
        if epoch_data_set_trace:
            graph.add_edge(
                epoch_trace,
                epoch_data_set_trace,
                step=step,
                type=TraceRelation.INPUT,
            )

    return step_trace
def _visit_processing_step(
    graph: nx.MultiDiGraph, previous: Trace, step: ProcessingStep, **kwargs
) -> StepTrace:
    """Visit a processing step with the visitor matching its class.

    A sequence edge is added from ``previous`` to the trace of the visited
    step (for groups, to the trace of the group).

    Parameters
    ----------
    graph
        The graph to which traces and edges are added.
    previous
        The trace preceding ``step`` in the processing sequence.
    step
        The step to be visited.
    kwargs
        Processing arguments passed on to the visitor.

    Returns
    -------
    StepTrace
        The trace of the visited step. If ``step`` is a group, the second
        trace returned by the group visitor is returned instead.
    """
    # The classes are tested in this order and the first match wins, hence the
    # order of the entries must not be changed.
    visitors = (
        # Epoch specific transformation & extraction
        (LevelEpochExtractStep, _visit_level_epoch_extract_step),
        (CreateLevelEpochStep, _visit_create_level_epochs_step),
        # Extracts
        (ExtractStep, _visit_extract_step),
        (AggregateMeasures, _visit_aggregate_measures_step),
        (MeasureDefinitionMixin, _visit_measure_definition_mixin_step),
        # Transformations
        (TransformStep, _visit_transform_step),
    )

    is_group = isinstance(step, CoreProcessingStepGroup)
    trace: StepTrace
    last_in_group: Optional[StepTrace] = None

    if is_group:
        # Processing groups
        trace, last_in_group = _visit_processing_step_group(graph, step, **kwargs)
    else:
        for step_class, visitor in visitors:
            if isinstance(step, step_class):
                trace = visitor(graph, step, **kwargs)
                break
        else:
            # Generics
            trace = _visit_processing_step_generic(graph, step)

    graph.add_edge(previous, trace, type=TraceRelation.SEQUENCE)

    # return the last step in the group if visited step is a group
    if is_group:
        assert last_in_group
        return last_in_group
    return trace
[docs]
def inspect(steps: ProcessingStepsType, **kwargs) -> nx.MultiDiGraph:
    """Inspect the relationships defined in processing steps.

    Parameters
    ----------
    steps
        The processing steps to be inspected. This can be an iterable of
        steps, a single step, or a class, which is instantiated without
        arguments before inspection.
    kwargs
        Any additional processing arguments typically passed to the
        processing steps.

    Returns
    -------
    networkx.MultiDiGraph
        A graph representing the processing elements found in `steps`.
    """
    graph = nx.MultiDiGraph()

    # try to instantiate steps if class-referenced
    to_visit = steps() if isinstance(steps, type) else steps
    if isinstance(to_visit, ProcessingStep):
        to_visit = [to_visit]

    origin = OriginTrace()
    _add_trace_node(graph, origin)

    # Chain each step to the trace returned for the step visited before it.
    previous: Trace = origin
    for step in to_visit:
        previous = _visit_processing_step(graph, previous, step, **kwargs)

    return graph
[docs]
def get_ancestors(
    graph: nx.MultiDiGraph,
    trace: Trace,
    rel_filter: Optional[Callable[[nx.MultiDiGraph, Trace, Trace], bool]] = None,
) -> Set[Trace]:
    """Get all ancestors of a trace.

    The graph is traversed with an explicit stack rather than recursion, so
    long predecessor chains do not exceed the interpreter's recursion limit.

    Parameters
    ----------
    graph
        The graph containing the relationship between the traces derived
        from :func:`inspect`.
    trace
        The trace for which to find all ancestors, i.e., all predecessors of
        all predecessors until the origin.
    rel_filter
        If provided, predecessors will only be considered if the return
        value is `True`. The function needs to accept the graph, the current
        trace, and the potential predecessor to be considered as ancestor.

    Returns
    -------
    Set[Trace]
        A set of traces of all predecessors to `trace`. `trace` itself is
        only part of the set if it can be reached from itself, i.e., if it is
        part of a cycle.
    """
    ancestors: Set[Trace] = set()
    pending = [trace]

    while pending:
        current = pending.pop()
        for predecessor in graph.predecessors(current):
            # Already collected traces are neither filtered nor expanded
            # again, which also guarantees termination on cycles.
            if predecessor in ancestors:
                continue
            if rel_filter is not None and not rel_filter(graph, current, predecessor):
                continue
            ancestors.add(predecessor)
            pending.append(predecessor)

    return ancestors
def _rel_filter_sources(
    _graph: nx.MultiDiGraph, _trace: Trace, predecessor: Trace
) -> bool:
    """Accept only data set and measure traces as predecessors.

    The graph and the current trace are part of the relation filter signature
    but are not used.
    """
    source_classes = (DataSetTrace, MeasureTrace)
    return isinstance(predecessor, source_classes)
[docs]
def get_ancestor_source_graph(
    graph: nx.MultiDiGraph, trace: MeasureTrace
) -> nx.MultiDiGraph:
    """Get a subgraph of all sources leading to the extraction of a measure.

    Parameters
    ----------
    graph
        The graph containing all traces.
    trace
        A trace of a measure for which to return the source graph.

    Returns
    -------
    networkx.MultiDiGraph
        A subgraph of `graph` comprised of nodes and edges only from data
        sets and measures leading to the extraction of the provided measure
        through `trace`. The subgraph includes `trace` itself.
    """
    source_nodes = get_ancestors(graph, trace, _rel_filter_sources)
    source_nodes.add(trace)
    return nx.subgraph(graph, source_nodes)
[docs]
def get_edge_parameters(graph: nx.MultiDiGraph) -> Dict[ProcessingStep, Set[Parameter]]:
    """Get parameters defined in steps attributes of edges.

    Parameters
    ----------
    graph
        The graph whose edges are searched for a ``step`` attribute.

    Returns
    -------
    Dict[ProcessingStep, Set[Parameter]]
        A mapping from each step found on an edge to the set built from the
        result of its ``get_parameters()``. Steps for which that result is
        empty are not part of the mapping. The mapping is a
        :class:`collections.defaultdict`.
    """
    parameters: Dict[ProcessingStep, Set[Parameter]] = defaultdict(set)
    inspected = set()

    for *_, step in graph.edges.data("step"):
        # Edges without a ``step`` attribute (e.g. sequence, group, and output
        # edges) are reported as ``None`` and carry no parameters.
        if step is None or step in inspected:
            continue
        # Remember every inspected step so that steps without parameters are
        # not queried again for each further edge they appear on.
        inspected.add(step)

        if step_param := step.get_parameters():
            parameters[step].update(step_param)

    return parameters
[docs]
def collect_measure_value_definitions(
    steps: Iterable[ProcessingStep], **kwargs
) -> Iterator[Tuple[MeasureDefinitionMixin, ValueDefinition]]:
    """Collect all measure value definitions from a list of processing steps.

    The steps are inspected once iteration starts; one pair is yielded per
    measure trace found in the resulting graph.

    Parameters
    ----------
    steps
        The steps from which to collect the measure value definitions.
    kwargs
        See :func:`inspect`.

    Yields
    ------
    Tuple[ProcessingStep, ValueDefinition]
        The processing step in question along with the value definition.
    """
    measure_traces = get_traces(inspect(steps, **kwargs), MeasureTrace)
    # FIXME: infer step from input nodes
    yield from (
        (measure_trace.step, measure_trace.measure)
        for measure_trace in measure_traces
    )