Source code for dispel.providers.generic.tasks.sbt_utt.sbt

"""Static Balance Test module.

A module containing the functionality to process the *Static Balance* test
(SBT).
"""
from typing import Iterable, List, Optional

import pandas as pd
from scipy.signal import medfilt
from scipy.stats import iqr as scipy_iqr

from dispel.data.core import EntityType, Reading
from dispel.data.flags import FlagSeverity, FlagType
from dispel.data.levels import Level
from dispel.data.measures import MeasureValueDefinitionPrototype
from dispel.data.raw import ACCELEROMETER_COLUMNS, RawDataValueDefinition
from dispel.data.validators import GREATER_THAN_ZERO
from dispel.data.values import AbbreviatedValue as AV
from dispel.processing import ProcessingStep
from dispel.processing.data_set import FlagDataSetStep, transformation
from dispel.processing.extract import ExtractStep
from dispel.processing.flags import flag
from dispel.processing.level import LevelIdFilter, ProcessingStepGroup
from dispel.processing.modalities import LimbModality, SensorModality
from dispel.processing.transform import TransformStep
from dispel.providers.generic.flags.generic import (
    FrequencyLowerThanSBTThres,
    MaxTimeIncrement,
)
from dispel.providers.generic.flags.le_flags import (
    FlagAdjustmentNonBeltSBT,
    FlagNonBeltSBT,
    FlagPostAdjustExcessiveMotion,
    FlagStabilisationExcessiveMotion,
    PlacementClassificationGroup,
)
from dispel.providers.generic.preprocessing import (
    FilterPhysiologicalNoise,
    FilterSensorNoise,
    PreprocessingSteps,
)
from dispel.providers.generic.sensor import RenameColumns, TransformUserAcceleration
from dispel.providers.generic.tasks.sbt_utt.const import (
    FIXED_SIGN_AMP_THRESHOLD,
    KERNEL_WINDOW_SEGMENT,
    MIN_COVERAGE_THRESHOLD,
    TASK_NAME_SBT,
)
from dispel.providers.generic.tasks.sbt_utt.sbt_bouts import (
    SBTBoutExtractStep,
    SBTBoutStrategyModality,
)
from dispel.providers.generic.tasks.sbt_utt.sbt_func import (
    circle_area,
    data_coverage_fraction,
    ellipse_area,
    label_bouts,
    reject_short_bouts,
    sway_jerk,
    sway_total_excursion,
)
from dispel.providers.generic.tremor import TremorMeasures
from dispel.signal.accelerometer import (
    AP_ML_COLUMN_MAPPINGS,
    transform_ap_ml_acceleration,
)
from dispel.signal.core import derive_time_data_frame, euclidean_norm


class SBTTremorMeasuresGroup(ProcessingStepGroup):
    """Group the steps computing SBT tremor measures from the accelerometer.

    The ``userAcceleration{X,Y,Z}`` columns of the given data set are mapped
    to ``x``, ``y`` and ``z`` through :class:`RenameColumns`; the tremor
    measures are then configured to read the ``<data_set_id>_renamed`` data
    set.

    Parameters
    ----------
    data_set_id
        The id of the data set holding the user acceleration columns.
    **kwargs
        Additional keyword arguments forwarded to the parent step group.
    """

    def __init__(self, data_set_id, **kwargs):
        # Same mapping (and key order) as spelling out X, Y and Z explicitly.
        axis_mapping = {f"userAcceleration{axis.upper()}": axis for axis in "xyz"}
        rename_step = RenameColumns(data_set_id, **axis_mapping)
        tremor_step = TremorMeasures(
            sensor=SensorModality.ACCELEROMETER,
            data_set_id=f"{data_set_id}_renamed",
            add_norm=False,
        )
        super().__init__([rename_step, tremor_step], **kwargs)
class TransformAxisNames(TransformStep):
    """Transform step that changes the axis names of a raw data set.

    The renaming is done to match Martinez et al 2012. The resulting data set
    is stored under the id ``martinez_accelerations``.
    """

    new_data_set_id = "martinez_accelerations"
    transform_function = transform_ap_ml_acceleration

    def get_definitions(self) -> List[RawDataValueDefinition]:
        """Get the data set definitions.

        Returns
        -------
        List[RawDataValueDefinition]
            One ``float64`` definition with unit ``G`` for each value of
            ``AP_ML_COLUMN_MAPPINGS``.
        """
        column_definitions: List[RawDataValueDefinition] = []
        for name in AP_ML_COLUMN_MAPPINGS.values():
            column_definitions.append(
                RawDataValueDefinition(
                    name, f"{name} data", data_type="float64", unit="G"
                )
            )
        return column_definitions
class TransformJerkNorm(TransformStep):
    """Derive the jerk norm from the AP and ML accelerations.

    Reads the data set produced by :class:`TransformAxisNames` and stores a
    single ``jerk_norm`` column under the data set id ``jerk_norm``.
    """

    data_set_ids = TransformAxisNames.new_data_set_id  # pylint: disable=E1101
    new_data_set_id = "jerk_norm"
    definitions = [
        RawDataValueDefinition("jerk_norm", "Jerk norm (AP/ML)", data_type="float64")
    ]

    @transformation
    def jerk_norm(self, data: pd.DataFrame) -> pd.Series:
        """Compute the jerk norm.

        Parameters
        ----------
        data
            A data frame containing the ``ap`` and ``ml`` columns.

        Returns
        -------
        pandas.Series
            The euclidean norm of the time derivative of ``ap`` and ``ml``.
        """
        ap_ml_jerk = derive_time_data_frame(data[["ap", "ml"]])
        return euclidean_norm(ap_ml_jerk)
class FlagExcessiveMotion(TransformStep):
    """TransformStep that creates a flag signal to mark excessive motion.

    The flagging segments the parts of the jerk norm that can be considered
    excessive motion with respect to normal sway during balance (e.g., belt
    placing or re-gaining balance after tripping), based on the amplitude.
    Developed using the SBT_flag_scenarios.
    """

    data_set_ids = "jerk_norm"
    new_data_set_id = "flagged_motion"
    definitions = [RawDataValueDefinition("flag", "Excessive Motion Segments")]

    @staticmethod
    @transformation
    def detect_excessive_motion(
        data: pd.DataFrame,
        fixed_threshold: float = FIXED_SIGN_AMP_THRESHOLD,
        kernel_width: int = KERNEL_WINDOW_SEGMENT,
    ) -> pd.Series:
        """Flag segments of excessive motion in accelerometer jerk during sway.

        Parameters
        ----------
        data
            A data frame with the ``jerk_norm`` column.
        fixed_threshold
            Amplitude cap on the jerk norm: only samples that do not exceed
            it are used to derive the statistical threshold.
        kernel_width
            The width of the running window used to median filter the
            excessive motion mask. This prevents spikes from appearing when
            the signal surpasses the statistical threshold for only a few
            instances.

        Returns
        -------
        pandas.Series
            The flagged motion signal, which becomes a new raw data set.
        """
        jerk = data["jerk_norm"]

        # Quiet periods: every sample that is NOT above the fixed threshold.
        # The negated form is kept on purpose so NaN samples stay included.
        is_quiet = ~(jerk > fixed_threshold)

        # Adaptive threshold: 2.5 times the IQR of the quiet periods.
        statistical_threshold = 2.5 * scipy_iqr(data.loc[is_quiet, "jerk_norm"])

        # Samples exceeding the adaptive threshold are excessive motion.
        is_excessive = jerk > statistical_threshold

        # Median filtering smooths the mask by removing small gaps
        # (original note: 2+ secs , 50Hz+1 * 2).
        # TODO think if I can play with decorators here (perfect composition)
        smoothed = medfilt(is_excessive.astype(float), kernel_size=kernel_width)
        flag_smooth = pd.Series(smoothed, index=data.index)

        bout_labels = label_bouts(flag_smooth)
        return reject_short_bouts(bout_labels, flag_smooth)
class TransformMergeValidSegments(TransformStep):
    """Mask the AP-ML acceleration, under the flagged_motion.

    The output contains only the rows of the ``martinez_accelerations`` data
    set which were not identified as excessive motion in the
    ``flagged_motion`` data set. Outputs the merged valid AP/ML/Vertical
    acceleration data.
    """

    data_set_ids = ["martinez_accelerations", "flagged_motion"]
    new_data_set_id = "concat_valid_martinez_acc"
    definitions = [
        RawDataValueDefinition("v", "Continuous valid acceleration V data merged"),
        RawDataValueDefinition("ml", "Continuous valid acceleration ML data merged"),
        RawDataValueDefinition("ap", "Continuous valid acceleration AP data merged"),
    ]

    @transformation
    def _valid_concat_data(
        self, data: pd.DataFrame, flag: pd.DataFrame
    ) -> pd.DataFrame:
        """Take only valid_continuous_data segments and concatenate them.

        Parameters
        ----------
        data
            Martinez Accelerations (AP, ML, V).
        flag
            Data frame whose ``flag`` column masks the recording (truthy:
            excessive motion data, falsy: valid data).

        Returns
        -------
        pandas.DataFrame
            The ``v``, ``ml`` and ``ap`` columns of the rows that are not
            flagged as excessive motion.
        """
        # ``~`` is a bitwise NOT: it is only a logical negation on boolean
        # data. Coercing first keeps the mask correct should the flag column
        # hold 0/1 numbers; it is a no-op for an already boolean column.
        is_valid = ~flag["flag"].astype(bool)
        return data.loc[is_valid, ["v", "ml", "ap"]]
class ExtractValidSegmentCoverage(ExtractStep):
    """Extract the coverage of data without excessive motion.

    The coverage is a ratio whose numerator is the length of the valid data
    (not flagged as invalid) and whose denominator is the length of the data
    originally contained in the recording. The measure is declared with the
    unit ``%``.
    """

    # NOTE(review): ``FlagMinSegmentCoverage`` multiplies the output of
    # ``data_coverage_fraction`` by 100 to obtain a percentage, whereas here
    # the raw output is stored with unit "%" - confirm the intended scale.
    transform_function = data_coverage_fraction
    data_set_ids = "flagged_motion"
    definition = MeasureValueDefinitionPrototype(
        measure_name=AV("coverage valid data", "coverage_valid"),
        data_type="float64",
        unit="%",
        validator=GREATER_THAN_ZERO,
        description=(
            "The portion of data (being 100 % the complete recording "
            "length) which is valid, i.e., that does not contain "
            "a signal with excessive motion."
        ),
    )
class FlagMinSegmentCoverage(FlagDataSetStep):
    """Flag minimum length of valid recording without excessive motion.

    Attributes
    ----------
    acceptance_threshold
        The threshold below which the data set is to be flagged. If the valid
        signal does not last for at least ``acceptance_threshold`` of the
        total recorded test length, then the associated level is flagged.
        Should be within ``[0, 100] %``.
    """

    data_set_ids = ["flagged_motion"]
    flag_name = AV("Too short sway", "sway_too_short")
    flag_type = FlagType.BEHAVIORAL
    flag_severity = FlagSeverity.DEVIATION
    # A space was missing between "motion" and the parenthesis because of the
    # implicit concatenation of adjacent string literals.
    reason = (
        "After removing segments with excessive motion "
        "(i.e., behavioural flag bouts) the remaining data is "
        "too short {percentage_valid} % of the total, i.e., "
        "less than {threshold} %"
    )
    acceptance_threshold: float = MIN_COVERAGE_THRESHOLD

    # NOTE: the default below is bound once, when the class body is executed;
    # overriding ``acceptance_threshold`` in a subclass does not change it.
    @flag
    def _min_valid_length_flag(
        self, data: pd.Series, acceptance_threshold: float = acceptance_threshold
    ) -> bool:
        """Check that enough of the recording is free of excessive motion.

        Parameters
        ----------
        data
            The ``flagged_motion`` data set.
        acceptance_threshold
            The minimum accepted percentage of valid data.

        Returns
        -------
        bool
            ``True`` if the percentage of valid data is greater than or equal
            to ``acceptance_threshold``, ``False`` otherwise.
        """
        percentage_valid = data_coverage_fraction(data) * 100
        # Provide the values used to format ``reason``.
        self.set_flag_kwargs(
            threshold=acceptance_threshold,
            percentage_valid=percentage_valid,
        )
        return percentage_valid >= acceptance_threshold

    def get_flag_targets(
        self, reading: Reading, level: Optional[Level] = None, **kwargs
    ) -> Iterable[EntityType]:
        """Get flag targets.

        Returns
        -------
        Iterable[EntityType]
            A list containing only the provided ``level``.
        """
        assert level is not None, "Level cannot be null"
        return [level]
class ExtractSwayTotalExcursion(SBTBoutExtractStep):
    """Extract Sway Total Excursion.

    For details about the algorithm, check equation (eq.8) by Prieto(1996)
    https://doi.org/10.1109/10.532130

    Parameters
    ----------
    data_set_ids
        The data set ids that will be used to extract Sway Total Excursion
        measures.
    **kwargs
        Additional keyword arguments forwarded to :class:`SBTBoutExtractStep`.
    """

    def __init__(self, data_set_ids: List[str], **kwargs):
        # Each fragment ends with a space so that the implicitly concatenated
        # literals do not run sentences together ("movement.It", "Seealso").
        description = (
            "The Sway Total Excursion (a.k.a. TOTEX) computed with the eq.8 "
            "of Prieto (1996) algorithm. This measure quantifies the total "
            "amount of acceleration increments in the "
            "Anterior/Posterior-Medio/Lateral plane of the movement. "
            "It is computed with the bout strategy {bout_strategy_repr}. See "
            "also https://doi.org/10.1109/10.532130."
        )

        definition = MeasureValueDefinitionPrototype(
            measure_name=AV("sway total excursion", "totex"),
            data_type="float64",
            unit="mG",
            validator=GREATER_THAN_ZERO,
            description=description,
        )

        def _compute_totex(data):
            # Scaled by 1e3 to match the declared unit "mG".
            return sway_total_excursion(data.ap, data.ml) * 1e3

        super().__init__(data_set_ids, _compute_totex, definition, **kwargs)
class ExtractSwayJerk(SBTBoutExtractStep):
    """Extract Sway Jerk.

    For details about the algorithm, check Table 2 by Mancini(2012),
    https://doi.org/10.1186/1743-0003-9-59

    It is a complementary measure to the sway areas, as it covers also the
    amount of sway occurred within a given geometric area. It takes special
    relevance when algorithms to remove outliers are applied and the
    timeseries span used for different measures is different. In other words,
    a normalised version of the sway total excursion. See an example of
    concomitant use with sway total excursion in Mancini(2012).

    Parameters
    ----------
    data_set_ids
        The data set ids that will be used to extract Sway Jerk measures.
    **kwargs
        Additional keyword arguments forwarded to :class:`SBTBoutExtractStep`.
    """

    def __init__(self, data_set_ids: List[str], **kwargs):
        # Each fragment ends with a space so that the implicitly concatenated
        # literals do not run sentences together ("assessment.It").
        description = (
            "The Sway Jerk provides a means of quantifying the average "
            "sway occurred on a plane per time unit over the total of "
            "an assessment. "
            "It is computed with the bout strategy {bout_strategy_repr}. See "
            "also https://doi.org/10.1186/1743-0003-9-59."
        )

        definition = MeasureValueDefinitionPrototype(
            measure_name=AV("sway jerk", "jerk"),
            data_type="float64",
            unit="mG/s",
            validator=GREATER_THAN_ZERO,
            description=description,
        )

        def _compute_jerk(data):
            # Scaled by 1e3 to match the declared unit "mG/s".
            return sway_jerk(data.ap, data.ml) * 1e3

        super().__init__(data_set_ids, _compute_jerk, definition, **kwargs)
class ExtractCircleArea(SBTBoutExtractStep):
    """Extract Circle Area.

    For details about the algorithm, check equation (eq.12) by Prieto(1996)
    https://doi.org/10.1109/10.532130

    Parameters
    ----------
    data_set_ids
        The data set ids that will be considered to extract Circle Area
        measures.
    **kwargs
        Additional keyword arguments forwarded to :class:`SBTBoutExtractStep`.
    """

    def __init__(self, data_set_ids: List[str], **kwargs):
        # The description previously said "ellipse area", copied over from the
        # ellipse measure; this step computes the circle area.
        description = (
            "The circle area computed with the eq.12 of Prieto (1996) "
            "algorithm. It is computed with the bout strategy "
            "{bout_strategy_repr}. See also "
            "https://doi.org/10.1109/10.532130."
        )

        definition = MeasureValueDefinitionPrototype(
            measure_name=AV("circle area", "ca"),
            data_type="float64",
            unit="microG^2",
            validator=GREATER_THAN_ZERO,
            description=description,
        )

        def _compute_circle_area(data):
            # Scaled by 1e6 to match the declared unit "microG^2".
            return circle_area(data.ap, data.ml) * 1e6

        super().__init__(data_set_ids, _compute_circle_area, definition, **kwargs)
class ExtractEllipseArea(SBTBoutExtractStep):
    """Extract Ellipse Area.

    For details about the algorithm, check equation eq.18 of Schubert(2014),
    https://doi.org/10.1016/j.gaitpost.2013.09.001

    Parameters
    ----------
    data_set_ids
        The data set ids that will be considered to extract Ellipse Area
        measures.
    **kwargs
        Additional keyword arguments forwarded to :class:`SBTBoutExtractStep`.
    """

    def __init__(self, data_set_ids: List[str], **kwargs):
        def _compute_ellipse_area(data):
            # Scaled by 1e6 to match the declared unit "microG^2".
            area = ellipse_area(data.ap, data.ml)
            return area * 1e6

        measure_definition = MeasureValueDefinitionPrototype(
            measure_name=AV("ellipse area", "ea"),
            data_type="float64",
            unit="microG^2",
            validator=GREATER_THAN_ZERO,
            description=(
                "The ellipse area computed with the eq.18 of Schubert (2014) "
                "algorithm. It is computed with the bout strategy "
                "{bout_strategy_repr}. This is a PCA-based variation of "
                "https://doi.org/10.1016/j.gaitpost.2013.09.001."
            ),
        )

        super().__init__(
            data_set_ids, _compute_ellipse_area, measure_definition, **kwargs
        )
class ExtractSpatioTemporalMeasuresBatch(ProcessingStepGroup):
    """Extract postural adjustment measures given a bout strategy.

    Parameters
    ----------
    bout_strategy
        The bout strategy modality whose ``bout_cls`` is handed to every
        extraction step of the batch.
    **kwargs
        Additional keyword arguments forwarded to the parent step group.
    """

    def __init__(self, bout_strategy: SBTBoutStrategyModality, **kwargs):
        data_set_id = ["martinez_accelerations"]
        # Order matters: total excursion, jerk, circle area, ellipse area.
        extract_classes = (
            ExtractSwayTotalExcursion,
            ExtractSwayJerk,
            ExtractCircleArea,
            ExtractEllipseArea,
        )
        steps: List[ProcessingStep] = [
            extract_cls(data_set_id, bout_strategy=bout_strategy.bout_cls)
            for extract_cls in extract_classes
        ]
        super().__init__(steps, **kwargs)
class TechnicalFlagsGroup(ProcessingStepGroup):
    """Step group bundling the technical flag steps of the SBT.

    Runs :class:`FrequencyLowerThanSBTThres` followed by
    :class:`MaxTimeIncrement`.
    """

    steps = [FrequencyLowerThanSBTThres(), MaxTimeIncrement()]
class FormattingGroup(ProcessingStepGroup):
    """Step group bundling the formatting steps of the SBT.

    It contains a single :class:`TransformUserAcceleration` step.
    """

    steps = [
        TransformUserAcceleration(),
    ]
class PreProcessingStepsGroup(ProcessingStepGroup):
    """Step group bundling the accelerometer pre-processing steps of the SBT.

    The generic pre-processing is applied to the ``accelerometer`` data set,
    followed by the sensor noise filter and the physiological noise filter,
    each applied on the ``ACCELEROMETER_COLUMNS``.
    """

    steps = [
        # Generic pre-processing of the raw accelerometer data set.
        PreprocessingSteps(
            "accelerometer",
            LimbModality.LOWER_LIMB,
            SensorModality.ACCELEROMETER,
            columns=ACCELEROMETER_COLUMNS,
        ),
        # Sensor noise filtering.
        FilterSensorNoise(
            "acc_ts_rotated_resampled_detrend",
            columns=ACCELEROMETER_COLUMNS,
        ),
        # Physiological noise filtering.
        FilterPhysiologicalNoise(
            "acc_ts_rotated_resampled_detrend_svgf",
            columns=ACCELEROMETER_COLUMNS,
        ),
    ]
class TremorAndAxesGroup(ProcessingStepGroup):
    """Step group bundling tremor measures and axis transformations.

    Tremor measures are computed on ``acc_ts_rotated_resampled_detrend``,
    then the axes of ``acc_ts_rotated_resampled_detrend_svgf_bhpf`` are
    renamed and the jerk norm is derived.
    """

    steps = [
        # Tremor measures.
        SBTTremorMeasuresGroup("acc_ts_rotated_resampled_detrend"),
        # Axis renaming, then jerk norm.
        TransformAxisNames("acc_ts_rotated_resampled_detrend_svgf_bhpf"),
        TransformJerkNorm(),
    ]
class DetectExcessiveMotionGroup(ProcessingStepGroup):
    """Step group bundling the excessive motion detection steps.

    Excessive motion is flagged first, then the valid segments are merged
    and finally the coverage of valid data is extracted.
    """

    # Steps Group to obtain Sway Path
    steps = [
        FlagExcessiveMotion(),
        TransformMergeValidSegments(),
        ExtractValidSegmentCoverage(),
    ]
class ExtractSpatioTemporalMeasuresGroup(ProcessingStepGroup):
    """Step group extracting spatio-temporal measures for each bout strategy.

    One :class:`ExtractSpatioTemporalMeasuresBatch` is created per member of
    :class:`SBTBoutStrategyModality`; the member's ``av`` is used both as
    modality and as ``bout_strategy_repr``.
    """

    steps = [
        ExtractSpatioTemporalMeasuresBatch(
            bout_strategy=strategy,
            modalities=[strategy.av],
            bout_strategy_repr=strategy.av,
        )
        for strategy in SBTBoutStrategyModality
    ]
class BehaviouralFlagsGroup(ProcessingStepGroup):
    """Step group bundling the behavioural flag steps of the SBT.

    The placement classification runs first and is followed by the non-belt,
    adjustment, stabilisation and post-adjustment flag steps.
    """

    kwargs = {"task_name": TASK_NAME_SBT}
    steps = [
        # Detect Placement Steps
        PlacementClassificationGroup(),
        # Behavioural flags
        FlagNonBeltSBT(),
        FlagAdjustmentNonBeltSBT(),
        FlagStabilisationExcessiveMotion(),
        FlagPostAdjustExcessiveMotion(),
    ]
class SBTProcessingSteps(ProcessingStepGroup):
    """All processing steps to extract SBT measures.

    The groups are applied, in order, to the level with id ``sbt``.
    """

    level_filter = LevelIdFilter("sbt")
    kwargs = {"task_name": TASK_NAME_SBT}
    steps = [
        TechnicalFlagsGroup(),
        FormattingGroup(),
        PreProcessingStepsGroup(),
        TremorAndAxesGroup(),
        DetectExcessiveMotionGroup(),
        ExtractSpatioTemporalMeasuresGroup(),
        BehaviouralFlagsGroup(),
    ]