# Source code for dispel.providers.generic.tasks.ft.steps

"""Finger tapping assessment related functionality.

This module contains functionality to extract measures for the
*Finger tapping* assessment.
"""
# pylint: disable=cell-var-from-loop
from typing import List, Optional

import pandas as pd

from dispel.data.measures import MeasureValueDefinition, MeasureValueDefinitionPrototype
from dispel.data.raw import RawDataValueDefinition
from dispel.data.validators import GREATER_THAN_ZERO
from dispel.data.values import AbbreviatedValue
from dispel.data.values import AbbreviatedValue as AV
from dispel.processing import ProcessingStep
from dispel.processing.data_set import transformation
from dispel.processing.extract import (
    DEFAULT_AGGREGATIONS_IQR,
    AggregateModalities,
    AggregateRawDataSetColumn,
    ExtractStep,
)
from dispel.processing.level import ProcessingStepGroup
from dispel.processing.transform import TransformStep
from dispel.providers.generic.tasks.ft.const import (
    ENRICHED_TOUCH_ATTRIBUTES,
    TASK_NAME,
    AllHandsModalities,
    TappingTarget,
)
from dispel.providers.generic.tasks.ft.flags import GenericFlagsStepGroup


class TransformInvalidTaps(TransformStep):
    """Remove the invalid taps.

    The first tap is always valid if it is inside one of the two zones. A left
    tap is valid if it follows a right tap and vice versa. If a tap event
    occurs twice or more in a row on the same zone, only the first event is
    kept.

    Generally speaking, this class returns the in-target consecutive tap
    events whose locations differ, enriched with an ``end`` column.
    """

    data_set_ids = "enriched_tap_events_ts"
    new_data_set_id = "valid_enriched_tap_events_ts"
    definitions = [
        RawDataValueDefinition(column, column)
        for column in ["end", "first_position"] + ENRICHED_TOUCH_ATTRIBUTES
    ]

    @staticmethod
    @transformation
    def filter_invalid_taps(data: pd.DataFrame) -> pd.DataFrame:
        """Return a filtered version of the tap dataset.

        Parameters
        ----------
        data
            The enriched tap events. They must provide a ``location`` and a
            ``tap_duration`` (in milliseconds) column. The index is assumed to
            be datetime-like (tap start), since a timedelta is added to it.

        Returns
        -------
        pandas.DataFrame
            The valid tap events with an additional ``end`` column equal to
            the tap start plus its duration.
        """
        # pylint: disable=no-member
        # Masking the taps outside the two zones turns their rows into NaN so
        # that ``dropna`` removes them. Note that ``dropna`` also removes rows
        # with a missing value in any other column.
        in_target_events = data.mask(
            data["location"] == TappingTarget.OUTSIDE.abbr
        ).dropna()

        # Filter the consecutive taps on the same zone and keep the first one.
        # The very first row always survives because ``shift`` yields NaN.
        is_new_zone = (
            in_target_events["location"].shift() != in_target_events["location"]
        )
        # Take an explicit copy: assigning a column on a boolean ``.loc``
        # selection would otherwise raise a SettingWithCopyWarning.
        valid_events = in_target_events.loc[is_new_zone].copy()
        valid_events["end"] = valid_events.index + pd.to_timedelta(
            valid_events["tap_duration"], unit="ms"
        )
        return valid_events
class ExtractValidTaps(ExtractStep):
    """Count the valid taps, optionally restricted to a single target zone.

    Parameters
    ----------
    target
        The target whose taps are counted. With ``None`` every valid tap is
        counted, regardless of its location.
    """

    data_set_ids = "valid_enriched_tap_events_ts"
    definition = MeasureValueDefinitionPrototype(
        measure_name=AV("valid tap events", "valtap"),
        data_type="int16",
        validator=GREATER_THAN_ZERO,
        description="The number of valid tap events on {target} target.",
        task_name=TASK_NAME,
    )

    def __init__(self, target: Optional[TappingTarget], *args, **kwargs):
        self.target = target
        super().__init__(*args, **kwargs)

    @transformation
    def count_valid_taps(self, data):
        """Return the number of valid taps on ``self.target`` (or anywhere)."""
        locations = data.location
        # No target: every non-missing location is a valid tap.
        if self.target is None:
            return locations.count()
        # Otherwise count the taps whose location matches the target.
        on_target = locations.eq(self.target.abbr)
        return on_target.sum()
class ExtractTotalTaps(ExtractStep):
    """Count the total number of tap events, valid or not.

    The step reads the unfiltered ``enriched_tap_events_ts`` data set, so taps
    outside the targets and repeated taps on the same zone are included.
    """

    data_set_ids = "enriched_tap_events_ts"
    definition = MeasureValueDefinitionPrototype(
        measure_name=AV("total tap events", "total_tap"),
        data_type="int16",
        validator=GREATER_THAN_ZERO,
        description="The total number of tap events.",
        task_name=TASK_NAME,
    )

    @transformation
    def count_valid_taps(self, data):
        """Return the number of tap events in ``data``.

        The method keeps its historical name for backward compatibility. No
        validity filtering is applied: each row counts as one tap.
        """
        return len(data)
class AggregateDoubleTaps(ExtractStep):
    """Quantify the share of tapping time spent double tapping.

    A double tap is detected when a valid tap starts before the previous valid
    tap has ended, i.e. the two taps overlap in time.
    """

    data_set_ids = "valid_enriched_tap_events_ts"
    definition = MeasureValueDefinitionPrototype(
        measure_name=AV("double tap percentage", "double_tap_percentage"),
        # The measure is a percentage rounded to two decimals, so it must not
        # be declared as an integer type.
        data_type="float",
        description="The percentage of time spent double tapping",
        task_name=TASK_NAME,
    )

    @transformation
    def quantify_double_taps(self, df):
        """Return the percentage of tap time overlapping the previous tap.

        Parameters
        ----------
        df
            The valid tap events, with an ``end`` column (tap end timestamp)
            and a ``tap_duration`` column in milliseconds. The index is the
            tap start timestamp.

        Returns
        -------
        float
            ``100 * overlap_time / total_tap_time`` rounded to two decimals,
            where both times are expressed in milliseconds.
        """
        data = df.copy()
        # Positive values: the previous tap ended after the current one began.
        data["delta_overshoot"] = (
            data["end"].shift() - data.index.to_series()
        ).dt.total_seconds()
        # Convert the overlap from seconds to milliseconds to match
        # ``tap_duration``.
        double_tap_total_time = (
            data.loc[data["delta_overshoot"] > 0, "delta_overshoot"].sum() * 1e3
        )
        total_time = data["tap_duration"].sum()
        return round((double_tap_total_time * 100) / total_time, 2)
class TransformTapIntervalDistribution(TransformStep):
    """Extract the distribution of intervals between consecutive taps.

    Parameters
    ----------
    valid_taps
        If ``True`` (default), the intervals are computed from
        ``valid_enriched_tap_events_ts`` and stored in
        ``valid_tap_interval_distribution``. Otherwise all tap events from
        ``enriched_tap_events_ts`` are used and the result is stored in
        ``tap_interval_distribution``.
    kwargs
        Additional keyword arguments forwarded to :class:`TransformStep`.
    """

    def __init__(self, valid_taps=True, **kwargs):
        self.valid_taps = valid_taps
        is_valid_prefix = "valid_" if self.valid_taps else ""
        super().__init__(
            data_set_ids=f"{is_valid_prefix}enriched_tap_events_ts",
            new_data_set_id=f"{is_valid_prefix}tap_interval_distribution",
            definitions=[
                RawDataValueDefinition(
                    "interval_with_previous_tap",
                    "Interval with previous tap",
                    data_type="float",
                )
            ],
            **kwargs,
        )

    @staticmethod
    @transformation
    def extract_tap_distribution(data: pd.DataFrame) -> pd.Series:
        """Return the time in seconds elapsed since the previous tap.

        Parameters
        ----------
        data
            The tap events. The index is assumed to hold the tap timestamps,
            since consecutive index values are subtracted.

        Returns
        -------
        pandas.Series
            One interval per tap, indexed like ``data``. The first value is
            NaN because the first tap has no predecessor.
        """
        return data.index.to_series().diff().dt.total_seconds()
class AggregateTapInterval(AggregateRawDataSetColumn):
    """An extraction processing step to extract tap interval measures.

    Parameters
    ----------
    valid_taps
        If ``True`` (default), aggregate ``valid_tap_interval_distribution``
        into ``valid_tap_inter`` measures. Otherwise aggregate
        ``tap_interval_distribution`` into ``tap_inter`` measures.
    kwargs
        Additional keyword arguments forwarded to
        :class:`AggregateRawDataSetColumn`.
    """

    def __init__(self, valid_taps=True, **kwargs):
        self.valid_taps = valid_taps
        is_valid_prefix = "valid_" if self.valid_taps else ""
        # Human-readable qualifier including its trailing space. This yields
        # "... of the valid tap interval ..." or "... of the tap interval ..."
        # without a doubled or missing space.
        qualifier = "valid " if self.valid_taps else ""
        super().__init__(
            f"{is_valid_prefix}tap_interval_distribution",
            "interval_with_previous_tap",
            aggregations=DEFAULT_AGGREGATIONS_IQR,
            definition=MeasureValueDefinitionPrototype(
                measure_name=AbbreviatedValue(
                    "tap interval", f"{is_valid_prefix}tap_inter"
                ),
                # ``{aggregation}`` stays a literal placeholder that the
                # prototype fills in for each aggregation.
                description="The {aggregation} of the "
                + f"{qualifier}tap interval distribution for the hand",
                unit="s",
                data_type="float",
            ),
            **kwargs,
        )
class AggregateTaps(AggregateModalities):
    """Combine the per-hand valid tap counts into a single test-level measure.

    ``modalities`` lists one single-hand modality path per entry of
    ``AllHandsModalities``. The values are combined with the built-in ``sum``.
    """

    definition = MeasureValueDefinition(
        task_name=TASK_NAME,
        measure_name=AV("valid tap events", "valtap"),
        description="The total number of valid tap events during the test",
    )
    modalities = [[hand_modality.av] for hand_modality in AllHandsModalities]
    aggregation_method = sum
class PreprocessingStepGroup(ProcessingStepGroup):
    """Build the intermediate finger tapping data sets.

    The steps are listed in this order: invalid taps are filtered first, then
    the tap interval distribution is derived over all taps and over the valid
    taps only.
    """

    steps = [
        # Produces ``valid_enriched_tap_events_ts``.
        TransformInvalidTaps(),
        # Interval distribution over every tap event.
        TransformTapIntervalDistribution(valid_taps=False),
        # Interval distribution over the valid tap events (the default).
        TransformTapIntervalDistribution(valid_taps=True),
    ]
# pylint: disable=no-member
class MeasureExtractionStepGroup(ProcessingStepGroup):
    """Generic measure extraction steps for finger tapping.

    Two families of groups are created:

    * for each target (left, then right) and each hand modality, a group that
      counts the valid taps on that target;
    * for each hand modality, a group that counts all valid taps, aggregates
      the tap interval distributions (all and valid), counts the total taps
      and quantifies double taps.
    """

    steps = [
        ProcessingStepGroup(
            [ExtractValidTaps(tap_target)],
            task_name=TASK_NAME,
            modalities=[hand_modality.av, str(tap_target.av)],
            target=tap_target.av,
            level_filter=hand_modality.abbr,
        )
        for tap_target in (TappingTarget.LEFT, TappingTarget.RIGHT)
        for hand_modality in AllHandsModalities
    ] + [
        ProcessingStepGroup(
            [
                ExtractValidTaps(target=None),
                AggregateTapInterval(valid_taps=False),
                AggregateTapInterval(),
                ExtractTotalTaps(),
                AggregateDoubleTaps(),
            ],
            task_name=TASK_NAME,
            modalities=[hand_modality.av],
            level_filter=hand_modality.abbr,
            target=None,
        )
        for hand_modality in AllHandsModalities
    ]
class MeasureAggregationStepGroup(ProcessingStepGroup):
    """Aggregate the per-hand finger tapping measures into test-level ones."""

    # Sum the valid tap counts of the individual hand levels.
    steps = [AggregateTaps()]
class GenericFingerTappingSteps(ProcessingStepGroup):
    """Complete generic processing pipeline for finger tapping.

    Chains the preprocessing, measure extraction, measure aggregation and
    flagging step groups.
    """

    steps: List[ProcessingStep] = [
        PreprocessingStepGroup(),
        MeasureExtractionStepGroup(),
        MeasureAggregationStepGroup(),
        GenericFlagsStepGroup(),
    ]
    # Group-level keyword arguments; presumably forwarded by
    # ProcessingStepGroup to the contained steps - confirm in the base class.
    kwargs = {"task_name": TASK_NAME}