Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Event detection for ic stride #60

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,055 changes: 4,055 additions & 0 deletions example_data/imu_sample_ic_stride.csv

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions example_data/stride_borders_sample_ic_stride.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
s_id,foot,start,end,gsd_id
0,left,696,812,1
1,left,812,926,1
2,left,1531,1647,1
3,left,1647,1765,1
4,left,2510,2627,1
5,left,2627,2746,1
6,left,3361,3479,1
7,right,757,870,1
8,right,870,986,1
9,right,1591,1709,1
10,right,2451,2570,1
11,right,2570,2686,1
12,right,3306,3424,1
13,right,3424,3544,1
23 changes: 13 additions & 10 deletions gaitmap/_event_detection_common/_event_detection_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
from joblib import Memory
from numpy.linalg import norm
from typing_extensions import Self
from typing_extensions import Literal, Self

from gaitmap.utils._algo_helper import invert_result_dictionary, set_params_from_dict
from gaitmap.utils._types import _Hashable
Expand All @@ -22,7 +22,7 @@
)
from gaitmap.utils.exceptions import ValidationError
from gaitmap.utils.stride_list_conversion import (
_segmented_stride_list_to_min_vel_single_sensor,
_stride_list_to_min_vel_single_sensor,
enforce_stride_list_consistency,
)

Expand All @@ -38,16 +38,19 @@ class _EventDetectionMixin:
data: SensorData
sampling_rate_hz: float
stride_list: pd.DataFrame
input_stride_type: Literal["segmented", "ic"]

def __init__(
self,
memory: Optional[Memory] = None,
enforce_consistency: bool = True,
detect_only: Optional[Tuple[str, ...]] = None,
) -> None:
input_stride_type: Literal["segmented", "ic"] = "segmented",
):
self.memory = memory
self.enforce_consistency = enforce_consistency
self.detect_only = detect_only
self.input_stride_type = input_stride_type

def detect(self, data: SensorData, stride_list: StrideList, *, sampling_rate_hz: float) -> Self:
"""Find gait events in data within strides provided by stride_list.
Expand Down Expand Up @@ -121,7 +124,9 @@ def _detect_single_dataset(
# find events in all segments
event_detection_func = self._select_all_event_detection_method()
event_detection_func = memory.cache(event_detection_func)
ic, tc, min_vel = event_detection_func(gyr, acc, stride_list, events=events, **detect_kwargs)
ic, tc, min_vel = event_detection_func(
gyr, acc, stride_list, events=events, input_stride_type=self.input_stride_type, **detect_kwargs
)

# build first dict / df based on segment start and end
segmented_event_list = {
Expand All @@ -132,13 +137,11 @@ def _detect_single_dataset(
for event, event_list in zip(("ic", "tc", "min_vel"), (ic, tc, min_vel)):
if event in events:
segmented_event_list[event] = event_list

segmented_event_list = pd.DataFrame(segmented_event_list).set_index("s_id")

if self.enforce_consistency:
# check for consistency, remove inconsistent strides
segmented_event_list, _ = enforce_stride_list_consistency(
segmented_event_list, stride_type="segmented", check_stride_list=False
segmented_event_list, input_stride_type=self.input_stride_type, check_stride_list=False
)

if "min_vel" not in events or self.enforce_consistency is False:
Expand All @@ -147,16 +150,16 @@ def _detect_single_dataset(
return {"segmented_event_list": segmented_event_list}

# convert to min_vel event list
min_vel_event_list, _ = _segmented_stride_list_to_min_vel_single_sensor(
segmented_event_list, target_stride_type="min_vel"
min_vel_event_list, _ = _stride_list_to_min_vel_single_sensor(
segmented_event_list, source_stride_type=self.input_stride_type, target_stride_type="min_vel"
)

output_order = [c for c in ["start", "end", "ic", "tc", "min_vel", "pre_ic"] if c in min_vel_event_list.columns]

# We enforce consistency again here, as a valid segmented stride list does not necessarily result in a valid
# min_vel stride list
min_vel_event_list, _ = enforce_stride_list_consistency(
min_vel_event_list[output_order], stride_type="min_vel", check_stride_list=False
min_vel_event_list[output_order], input_stride_type="min_vel", check_stride_list=False
)

return {"min_vel_event_list": min_vel_event_list, "segmented_event_list": segmented_event_list}
Expand Down
21 changes: 18 additions & 3 deletions gaitmap/event_detection/_herzer_event_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from joblib import Memory
from scipy import signal
from tpcp import cf
from typing_extensions import Literal

from gaitmap._event_detection_common._event_detection_mixin import _detect_min_vel_gyr_energy, _EventDetectionMixin
from gaitmap.base import BaseEventDetection
Expand Down Expand Up @@ -54,7 +55,8 @@ class HerzerEventDetection(_EventDetectionMixin, BaseEventDetection):
By default, all events ("ic", "tc", "min_vel") are detected.
If `min_vel` is not detected, the `min_vel_event_list_` output will not be available.
If "ic" is not detected, the `pre_ic` will also not be available in the output.

input_stride_type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a note here, that only segmented is supported by this algo

The stride list type that should be either "ic", or "segmented".

Attributes
----------
Expand Down Expand Up @@ -132,6 +134,9 @@ class HerzerEventDetection(_EventDetectionMixin, BaseEventDetection):
The window size can be adjusted via the `min_vel_search_win_size_ms` parameter.
This approach is identical to [1]_.

The :func:`~gaitmap.event_detection.HerzerEventDetection.detect` method is implemented only for "segmented" stride
type

The :func:`~gaitmap.event_detection.HerzerEventDetection.detect` method provides a stride list `min_vel_event_list`
with the gait events mentioned above and additionally `start` and `end` of each stride, which are aligned to the
`min_vel` samples.
Expand Down Expand Up @@ -188,6 +193,7 @@ class HerzerEventDetection(_EventDetectionMixin, BaseEventDetection):
ic_lowpass_filter: BaseFilter
memory: Optional[Memory]
enforce_consistency: bool
input_stride_type: Literal["segmented"]

def __init__(
self,
Expand All @@ -198,12 +204,18 @@ def __init__(
memory: Optional[Memory] = None,
enforce_consistency: bool = True,
detect_only: Optional[Tuple[str, ...]] = None,
) -> None:
input_stride_type: Literal["segmented"] = "segmented",
):
self.min_vel_search_win_size_ms = min_vel_search_win_size_ms
self.mid_swing_peak_prominence = mid_swing_peak_prominence
self.mid_swing_n_considered_peaks = mid_swing_n_considered_peaks
self.ic_lowpass_filter = ic_lowpass_filter
super().__init__(memory=memory, enforce_consistency=enforce_consistency, detect_only=detect_only)
super().__init__(
memory=memory,
enforce_consistency=enforce_consistency,
detect_only=detect_only,
input_stride_type=input_stride_type,
)

def _get_detect_kwargs(self) -> Dict[str, int]:
min_vel_search_win_size = int(self.min_vel_search_win_size_ms / 1000 * self.sampling_rate_hz)
Expand Down Expand Up @@ -234,8 +246,11 @@ def _find_all_events(
mid_swing_n_considered_peaks: int,
ic_lowpass_filter: BaseFilter,
sampling_rate_hz: float,
input_stride_type: Literal["segmented"],
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
"""Find events in provided data by looping over single strides."""
if input_stride_type != "segmented":
raise NotImplementedError("This method support only segmented stride type")
gyr_ml = gyr["gyr_ml"].to_numpy()
gyr = gyr.to_numpy()
# inverting acc, as this algorithm was developed assuming a flipped axis like the original Rampp algorithm
Expand Down
31 changes: 31 additions & 0 deletions gaitmap/example_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@ def get_healthy_example_imu_data():
return data


def get_healthy_example_imu_data_ic_stride():
"""Get example IMU data from a healthy subject doing a 4x10 gait test.

The sampling rate is 102.4 Hz
"""
test_data_path = _get_data("imu_sample_ic_stride.csv")
data = pd.read_csv(test_data_path, header=[0, 1], index_col=0)

# Get index in seconds
data.index /= 102.4
return data


def get_ms_example_imu_data():
"""Get example IMU data from a MS subject performing a longer uninterrupted walking sequence.

Expand Down Expand Up @@ -116,6 +129,24 @@ def get_healthy_example_stride_borders():
return data


def get_healthy_example_stride_borders_ic_stride():
"""Get hand labeled stride borders for :func:`get_healthy_example_imu_data_ic_stride`.

The stride borders are obtained from mocap where each stride starts with initial contact.
"""
test_data_path = _get_data("stride_borders_sample_ic_stride.csv")
data = pd.read_csv(test_data_path, header=0)

# Convert to dict with sensor name as key.
# Sensor name here is derived from the foot. In the real pipeline that would be provided to the algo.
data["sensor"] = data["foot"] + "_sensor"
data = data.set_index("sensor")
data = data.groupby(level=0)
data = {k: v.reset_index(drop=True) for k, v in data}

return data


def get_healthy_example_mocap_data():
"""Get 3D Mocap information of the foot synchronised with :func:`get_healthy_example_imu_data`.

Expand Down
38 changes: 25 additions & 13 deletions gaitmap/utils/stride_list_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
)


def convert_segmented_stride_list(stride_list: StrideList, target_stride_type: Literal["min_vel", "ic"]) -> StrideList:
"""Convert a segmented stride list with detected events into other types of stride lists.
def convert_stride_list(
stride_list: StrideList,
target_stride_type: Literal["min_vel", "ic"],
source_stride_type: Literal["segmented", "ic"] = "segmented",
) -> StrideList:
"""Convert a stride list with detected events into other types of stride lists.

During the conversion some strides might be removed.
For more information about the different types of stride lists see the :ref:`stride list guide <stride_list_guide>`.
Expand All @@ -30,6 +34,8 @@ def convert_segmented_stride_list(stride_list: StrideList, target_stride_type: L
Stride list to be converted
target_stride_type
The stride list type that should be converted to
source_stride_type
The stride list type that should be converted from

Returns
-------
Expand All @@ -39,17 +45,23 @@ def convert_segmented_stride_list(stride_list: StrideList, target_stride_type: L
"""
stride_list_type = is_stride_list(stride_list, stride_type="segmented")
if stride_list_type == "single":
return _segmented_stride_list_to_min_vel_single_sensor(stride_list, target_stride_type=target_stride_type)[0]
return _stride_list_to_min_vel_single_sensor(
stride_list, target_stride_type=target_stride_type, source_stride_type=source_stride_type
)[0]
return {
k: _segmented_stride_list_to_min_vel_single_sensor(v, target_stride_type=target_stride_type)[0]
k: _stride_list_to_min_vel_single_sensor(
v, target_stride_type=target_stride_type, source_stride_type=source_stride_type
)[0]
for k, v in stride_list.items()
}


def _segmented_stride_list_to_min_vel_single_sensor(
stride_list: SingleSensorStrideList, target_stride_type: Literal["min_vel", "ic"]
def _stride_list_to_min_vel_single_sensor(
stride_list: SingleSensorStrideList,
target_stride_type: Literal["min_vel", "ic"],
source_stride_type: Literal["segmented", "ic"],
) -> Tuple[SingleSensorStrideList, SingleSensorStrideList]:
"""Convert a segmented stride list with detected events into other types of stride lists.
"""Convert a stride list with detected events into other types of stride lists.

During the conversion some strides might be removed.
For more information about the different types of stride lists see the :ref:`stride list guide <stride_list_guide>`.
Expand Down Expand Up @@ -88,7 +100,8 @@ def _segmented_stride_list_to_min_vel_single_sensor(
converted_stride_list["pre_ic"] = converted_stride_list["ic"]
# ic of each stride is the ic in the subsequent segmented stride
converted_stride_list["ic"] = converted_stride_list["ic"].shift(-1)
if "tc" in converted_stride_list.columns:
if "tc" in converted_stride_list.columns and source_stride_type == "segmented":
# do not shift if source_stride_type is "ic"
# tc of each stride is the tc in the subsequent segmented stride
converted_stride_list["tc"] = converted_stride_list["tc"].shift(-1)

Expand All @@ -113,7 +126,7 @@ def _segmented_stride_list_to_min_vel_single_sensor(

def enforce_stride_list_consistency(
stride_list: SingleSensorStrideList,
stride_type=Literal["segmented", "min_vel", "ic"],
input_stride_type=Literal["segmented", "min_vel", "ic"],
check_stride_list: bool = True,
) -> Tuple[SingleSensorStrideList, SingleSensorStrideList]:
"""Exclude those strides where the gait events do not match the expected order or contain NaN.
Expand All @@ -130,7 +143,7 @@ def enforce_stride_list_consistency(
----------
stride_list
A single sensor stride list in a Dataframe format
stride_type
input_stride_type
Indicate which types of strides are expected to be in the stride list.
This changes the expected columns and order of events.
check_stride_list
Expand All @@ -148,9 +161,8 @@ def enforce_stride_list_consistency(

"""
if check_stride_list is True:
is_single_sensor_stride_list(stride_list, stride_type=stride_type, raise_exception=True)
order = SL_EVENT_ORDER[stride_type]

is_single_sensor_stride_list(stride_list, stride_type=input_stride_type, raise_exception=True)
order = SL_EVENT_ORDER[input_stride_type]
order = [c for c in order if c in stride_list.columns]

if len(order) == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from joblib import Memory
from tpcp import cf
from typing_extensions import Literal

from gaitmap.data_transform import BaseFilter, ButterworthFilter
from gaitmap_mad.event_detection._rampp_event_detection import RamppEventDetection
Expand Down Expand Up @@ -42,6 +43,8 @@ class FilteredRamppEventDetection(RamppEventDetection):
By default, all events ("ic", "tc", "min_vel") are detected.
If `min_vel` is not detected, the `min_vel_event_list_` output will not be available.
If "ic" is not detected, the `pre_ic` will also not be available in the output.
input_stride_type
The stride_list_type that should be either "ic" or "segmented".

Attributes
----------
Expand Down Expand Up @@ -97,14 +100,17 @@ def __init__(
memory: Optional[Memory] = None,
enforce_consistency: bool = True,
detect_only: Optional[Tuple[str, ...]] = None,
) -> None:
input_stride_type: Literal["segmented", "ic"] = "segmented",
):
self.ic_lowpass_filter = ic_lowpass_filter
self.input_stride_type = input_stride_type
super().__init__(
memory=memory,
enforce_consistency=enforce_consistency,
ic_search_region_ms=ic_search_region_ms,
min_vel_search_win_size_ms=min_vel_search_win_size_ms,
detect_only=detect_only,
input_stride_type=input_stride_type,
)

def _get_detect_kwargs(self) -> Dict:
Expand Down
Loading
Loading