From bbd81daeaf0dac527bd0afa16f70df7c9815c0b5 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 9 Jun 2026 14:15:15 +0200 Subject: [PATCH 1/8] Add the ProbeGroup._global_contact_order concept. --- src/probeinterface/probegroup.py | 145 ++++++++++++++++++++++--------- 1 file changed, 103 insertions(+), 42 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index d42906a4..e50b3338 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -7,12 +7,21 @@ class ProbeGroup: """ Class to handle a group of Probe objects and the global wiring to a device. - Optionally, it can handle the location of different probes. + Internally, this is represented as a list of Probe object. + + The ProbeGroup is the object saved in the json based probeinterface format, even if there only one probe. + + Tiny detail: when using `PropbeGroup.to_numpy()` / `PropbeGroup.to_dataframe()` by default the contact order + is the "natural" one (stacked order of each probe). But optionally, this order can be more complex, for instance + some contact of each probe are interleaved, in this case a optional reordering can be applied. + + """ def __init__(self): self.probes = [] + self._global_contact_order = None def add_probe(self, probe: Probe) -> None: """ @@ -114,6 +123,9 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: pg_arr.append(arr_ext) pg_arr = np.concatenate(pg_arr, axis=0) + + if self._global_contact_order is not None: + pg_arr = pg_arr[self._global_contact_order] return pg_arr @staticmethod @@ -121,6 +133,10 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": """Create ProbeGroup from a complex numpy array see ProbeGroup.to_numpy() + Note that if the contact_vector has several probe and some contact are interleaved, then the ProbeGroup will + have a non natural ordering (contact from probes are stack): in short ProbeGroup._global_contact_order + will be not None. + Parameters ---------- arr : np.array @@ -131,7 +147,13 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": probegroup : ProbeGroup The instantiated ProbeGroup object """ - from .probe import Probe + + # Check if contacts are interleaved + num_probes = np.unique(arr["probe_index"]).size + is_interleaved = (num_probes > 1) and np.any(np.diff(arr["probe_index"]) < 0) + print('is_interleaved', is_interleaved) + if is_interleaved: + global_contact_order = [] probes_indices = np.unique(arr["probe_index"]) probegroup = ProbeGroup() @@ -139,6 +161,14 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": mask = arr["probe_index"] == probe_index probe = Probe.from_numpy(arr[mask]) probegroup.add_probe(probe) + + if is_interleaved: + global_contact_order.append(np.flatnonzero(mask)) + + if is_interleaved: + # the argsort is for the 'reverse' order! + probegroup._global_contact_order = np.argsort(np.concatenate(global_contact_order)) + return probegroup def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": @@ -181,6 +211,11 @@ def to_dict(self, array_as_list: bool = False) -> dict: for probe_ind, probe in enumerate(self.probes): probe_dict = probe.to_dict(array_as_list=array_as_list) d["probes"].append(probe_dict) + if self._global_contact_order is not None: + global_contact_order = self._global_contact_order + if array_as_list: + global_contact_order = global_contact_order.to_list() + d["global_contact_order"] = global_contact_order return d @staticmethod @@ -201,6 +236,11 @@ def from_dict(d: dict) -> "ProbeGroup": for probe_dict in d["probes"]: probe = Probe.from_dict(probe_dict) probegroup.add_probe(probe) + + global_contact_order = d.get("global_contact_order", None) + if global_contact_order is not None: + probegroup._global_contact_order = np.asarray(global_contact_order) + return probegroup def get_global_device_channel_indices(self) -> np.ndarray: @@ -226,31 +266,40 @@ def get_global_device_channel_indices(self) -> np.ndarray: channels["device_channel_indices"] = arr["device_channel_indices"] return channels - def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None: + def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | list) -> None: """ - Set global indices for all probes + Set global indices for all probes. + + Important note : if the order of contacts is not "natural" then the device_channel_indices + is applied is the real/reordered contacts vector. In short, the device_channel_indices is ziped to + ProbeGroup.to_numpy() (always ordered). Parameters ---------- channels: np.ndarray | list The device channal indices to be set """ - channels = np.asarray(channels) - if channels.size != self.get_contact_count(): + device_channel_indices = np.asarray(device_channel_indices) + if device_channel_indices.size != self.get_contact_count(): raise ValueError( - f"Wrong channels size {channels.size} for the number of channels {self.get_contact_count()}" + f"Wrong channels size {device_channel_indices.size} for the number of channels {self.get_contact_count()}" ) # first reset previous indices for i, probe in enumerate(self.probes): n = probe.get_contact_count() probe.set_device_channel_indices([-1] * n) + + if self._global_contact_order is not None: + # this is tricky conceptually but needed but needed for consistency + rev_order = np.argsort(self._global_contact_order) + device_channel_indices = device_channel_indices[rev_order] # then set new indices ind = 0 for i, probe in enumerate(self.probes): n = probe.get_contact_count() - probe.set_device_channel_indices(channels[ind : ind + n]) + probe.set_device_channel_indices(device_channel_indices[ind : ind + n]) ind += n def get_global_contact_ids(self) -> np.ndarray: @@ -275,6 +324,8 @@ def get_global_contact_positions(self) -> np.ndarray: An array of the contact positions across all probes """ contact_positions = np.vstack([probe.contact_positions for probe in self.probes]) + if self._global_contact_order is not None: + contact_positions = contact_positions[self._global_contact_order] return contact_positions def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": @@ -295,55 +346,65 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": The sliced probe group """ + # TODO SAM order!!! n = self.get_contact_count() selection = np.asarray(selection) + if selection.dtype.kind not in ("b", "i"): + raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") + + if selection.dtype == "bool": assert selection.shape == ( n, ), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}" - (selection_indices,) = np.nonzero(selection) - elif selection.dtype.kind == "i": - assert np.unique(selection).size == selection.size - if len(selection) > 0: - assert ( - 0 <= np.min(selection) < n - ), f"An index within your selection is out of bounds {np.min(selection)}" - assert ( - 0 <= np.max(selection) < n - ), f"An index within your selection is out of bounds {np.max(selection)}" - selection_indices = selection - else: - selection_indices = [] - else: - raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") + selection_indices = np.flatnonzero(selection) - if len(selection_indices) == 0: - return ProbeGroup() + if len(selection) == 0: + raise ValueError("ProbeGroup.get_slice() with empty selection is not handled") + # return ProbeGroup() - # Map selection to indices of individual probes - ind = 0 - sliced_probes = [] - for probe in self.probes: - n = probe.get_contact_count() - probe_limits = (ind, ind + n) - ind += n + assert np.unique(selection).size == selection.size + assert ( + 0 <= np.min(selection) < n + ), f"An index within your selection is out of bounds {np.min(selection)}" + assert ( + 0 <= np.max(selection) < n + ), f"An index within your selection is out of bounds {np.max(selection)}" + selection_indices = selection - probe_selection_indices = selection_indices[ - (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) - ] - if len(probe_selection_indices) == 0: - continue - sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) - sliced_probes.append(sliced_probe) - sliced_probe_group = ProbeGroup() - for probe in sliced_probes: - sliced_probe_group.add_probe(probe) + contact_arr = self.to_numpy(complete=True) + contact_arr = contact_arr[selection] + sliced_probe_group = ProbeGroup.from_numpy(contact_arr) + + # TODO annoatation!! return sliced_probe_group + # # Map selection to indices of individual probes + # ind = 0 + # sliced_probes = [] + # for probe in self.probes: + # n = probe.get_contact_count() + # probe_limits = (ind, ind + n) + # ind += n + + # probe_selection_indices = selection_indices[ + # (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) + # ] + # if len(probe_selection_indices) == 0: + # continue + # sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) + # sliced_probes.append(sliced_probe) + + # sliced_probe_group = ProbeGroup() + # for probe in sliced_probes: + # sliced_probe_group.add_probe(probe) + + # return sliced_probe_group + def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() From 16695d74a5242afedb369f3452dbe6a31b3467b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:27:58 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/probeinterface/probegroup.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index e50b3338..8dc63b7c 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -134,7 +134,7 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": see ProbeGroup.to_numpy() Note that if the contact_vector has several probe and some contact are interleaved, then the ProbeGroup will - have a non natural ordering (contact from probes are stack): in short ProbeGroup._global_contact_order + have a non natural ordering (contact from probes are stack): in short ProbeGroup._global_contact_order will be not None. Parameters @@ -151,7 +151,7 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": # Check if contacts are interleaved num_probes = np.unique(arr["probe_index"]).size is_interleaved = (num_probes > 1) and np.any(np.diff(arr["probe_index"]) < 0) - print('is_interleaved', is_interleaved) + print("is_interleaved", is_interleaved) if is_interleaved: global_contact_order = [] @@ -164,7 +164,7 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": if is_interleaved: global_contact_order.append(np.flatnonzero(mask)) - + if is_interleaved: # the argsort is for the 'reverse' order! probegroup._global_contact_order = np.argsort(np.concatenate(global_contact_order)) @@ -236,7 +236,7 @@ def from_dict(d: dict) -> "ProbeGroup": for probe_dict in d["probes"]: probe = Probe.from_dict(probe_dict) probegroup.add_probe(probe) - + global_contact_order = d.get("global_contact_order", None) if global_contact_order is not None: probegroup._global_contact_order = np.asarray(global_contact_order) @@ -289,7 +289,7 @@ def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | for i, probe in enumerate(self.probes): n = probe.get_contact_count() probe.set_device_channel_indices([-1] * n) - + if self._global_contact_order is not None: # this is tricky conceptually but needed but needed for consistency rev_order = np.argsort(self._global_contact_order) @@ -354,7 +354,6 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": if selection.dtype.kind not in ("b", "i"): raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") - if selection.dtype == "bool": assert selection.shape == ( n, @@ -363,18 +362,13 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": if len(selection) == 0: raise ValueError("ProbeGroup.get_slice() with empty selection is not handled") - # return ProbeGroup() + # return ProbeGroup() assert np.unique(selection).size == selection.size - assert ( - 0 <= np.min(selection) < n - ), f"An index within your selection is out of bounds {np.min(selection)}" - assert ( - 0 <= np.max(selection) < n - ), f"An index within your selection is out of bounds {np.max(selection)}" + assert 0 <= np.min(selection) < n, f"An index within your selection is out of bounds {np.min(selection)}" + assert 0 <= np.max(selection) < n, f"An index within your selection is out of bounds {np.max(selection)}" selection_indices = selection - contact_arr = self.to_numpy(complete=True) contact_arr = contact_arr[selection] sliced_probe_group = ProbeGroup.from_numpy(contact_arr) From 10fac954fca1ce195de0aa4d62491a2f189295f0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 9 Jun 2026 14:55:43 +0200 Subject: [PATCH 3/8] Add some tests for ordering --- src/probeinterface/probegroup.py | 27 +------------- tests/test_probegroup.py | 63 +++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 34 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index e50b3338..5bc732e6 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -151,7 +151,6 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": # Check if contacts are interleaved num_probes = np.unique(arr["probe_index"]).size is_interleaved = (num_probes > 1) and np.any(np.diff(arr["probe_index"]) < 0) - print('is_interleaved', is_interleaved) if is_interleaved: global_contact_order = [] @@ -291,7 +290,7 @@ def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | probe.set_device_channel_indices([-1] * n) if self._global_contact_order is not None: - # this is tricky conceptually but needed but needed for consistency + # this is tricky conceptually but needed for consistency rev_order = np.argsort(self._global_contact_order) device_channel_indices = device_channel_indices[rev_order] @@ -346,7 +345,6 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": The sliced probe group """ - # TODO SAM order!!! n = self.get_contact_count() @@ -379,31 +377,10 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": contact_arr = contact_arr[selection] sliced_probe_group = ProbeGroup.from_numpy(contact_arr) - # TODO annoatation!! + # TODO annoatation probe per probe!! return sliced_probe_group - # # Map selection to indices of individual probes - # ind = 0 - # sliced_probes = [] - # for probe in self.probes: - # n = probe.get_contact_count() - # probe_limits = (ind, ind + n) - # ind += n - - # probe_selection_indices = selection_indices[ - # (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) - # ] - # if len(probe_selection_indices) == 0: - # continue - # sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) - # sliced_probes.append(sliced_probe) - - # sliced_probe_group = ProbeGroup() - # for probe in sliced_probes: - # sliced_probe_group.add_probe(probe) - - # return sliced_probe_group def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index ddd332d4..32ff7d6d 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -6,8 +6,9 @@ import numpy as np -@pytest.fixture -def probegroup(): + + +def _make_probegroup(): """Fixture: a ProbeGroup with 3 probes, each with device channel indices set.""" probegroup = ProbeGroup() nchan = 0 @@ -21,6 +22,11 @@ def probegroup(): return probegroup +@pytest.fixture +def probegroup(): + return _make_probegroup() + + def test_probegroup(probegroup): indices = probegroup.get_global_device_channel_indices() @@ -200,7 +206,7 @@ def test_copy_is_independent(probegroup): np.testing.assert_array_equal(probegroup.probes[0].contact_positions, original_positions) -# ── get_slice() tests ─────────────────────────────────────────────────────── +# ── get_slice() simple : natural order def test_get_slice_by_bool(probegroup): @@ -232,10 +238,10 @@ def test_get_slice_preserves_positions(probegroup): np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected) -def test_get_slice_empty_selection(probegroup): - sliced = probegroup.get_slice(np.array([], dtype=int)) - assert sliced.get_contact_count() == 0 - assert len(sliced.probes) == 0 +# def test_get_slice_empty_selection(probegroup): +# sliced = probegroup.get_slice(np.array([], dtype=int)) +# assert sliced.get_contact_count() == 0 +# assert len(sliced.probes) == 0 def test_get_slice_wrong_bool_size(probegroup): @@ -259,7 +265,46 @@ def test_get_slice_all_contacts(probegroup): probegroup.get_global_contact_positions(), ) +# ── global_contact_order : to_numpy/from_numpy, to_dict/from_dict, get_slice + +def test_reordred_probegroup(probegroup): + order = np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) + + contact_vector = probegroup.to_numpy(complete=True) + contact_vector = contact_vector[order] + + probegroup2 = ProbeGroup.from_numpy(contact_vector) + assert probegroup2._global_contact_order is not None + contact_vector2 = probegroup2.to_numpy(complete=True) + assert np.array_equal(contact_vector, contact_vector2) + + probegroup3 = ProbeGroup.from_dict(probegroup2.to_dict()) + assert probegroup3._global_contact_order is not None + contact_vector3 = probegroup3.to_numpy(complete=True) + assert np.array_equal(contact_vector2, contact_vector3) + + probegroup4 = probegroup.get_slice(order) + assert probegroup4._global_contact_order is not None + contact_vector4 = probegroup4.to_numpy(complete=True) + assert np.array_equal(contact_vector3, contact_vector4) + + probegroup5 = ProbeGroup.from_dict(probegroup4.to_dict()) + assert probegroup5._global_contact_order is not None + contact_vector5 = probegroup3.to_numpy(complete=True) + assert np.array_equal(contact_vector4, contact_vector5) + + # let go back to original order + rev_order = np.argsort(order) + probegroup6 = probegroup5.get_slice(rev_order) + assert probegroup6._global_contact_order is None + + if __name__ == "__main__": - test_probegroup() - # ~ test_probegroup_3d() + probegroup = _make_probegroup() + + # test_probegroup(probegroup) + # test_probegroup_3d() + test_reordred_probegroup(probegroup) + + From dbd0455d85ebd7c97b556557202a4ad76f3b6b99 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:56:54 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/probeinterface/probegroup.py | 1 - tests/test_probegroup.py | 15 ++++++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index c61d3ee5..9658b978 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -375,7 +375,6 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": return sliced_probe_group - def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 32ff7d6d..089c642a 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -6,8 +6,6 @@ import numpy as np - - def _make_probegroup(): """Fixture: a ProbeGroup with 3 probes, each with device channel indices set.""" probegroup = ProbeGroup() @@ -265,14 +263,16 @@ def test_get_slice_all_contacts(probegroup): probegroup.get_global_contact_positions(), ) + # ── global_contact_order : to_numpy/from_numpy, to_dict/from_dict, get_slice + def test_reordred_probegroup(probegroup): order = np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) - + contact_vector = probegroup.to_numpy(complete=True) contact_vector = contact_vector[order] - + probegroup2 = ProbeGroup.from_numpy(contact_vector) assert probegroup2._global_contact_order is not None contact_vector2 = probegroup2.to_numpy(complete=True) @@ -296,15 +296,12 @@ def test_reordred_probegroup(probegroup): # let go back to original order rev_order = np.argsort(order) probegroup6 = probegroup5.get_slice(rev_order) - assert probegroup6._global_contact_order is None - + assert probegroup6._global_contact_order is None if __name__ == "__main__": probegroup = _make_probegroup() - + # test_probegroup(probegroup) # test_probegroup_3d() test_reordred_probegroup(probegroup) - - From bfd44b8d939305959751c19599b468b26a4de205 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 9 Jun 2026 15:03:34 +0200 Subject: [PATCH 5/8] oups --- src/probeinterface/probegroup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index c61d3ee5..941f2df9 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -356,7 +356,7 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": assert selection.shape == ( n, ), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}" - selection_indices = np.flatnonzero(selection) + selection = np.flatnonzero(selection) if len(selection) == 0: raise ValueError("ProbeGroup.get_slice() with empty selection is not handled") @@ -365,7 +365,6 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": assert np.unique(selection).size == selection.size assert 0 <= np.min(selection) < n, f"An index within your selection is out of bounds {np.min(selection)}" assert 0 <= np.max(selection) < n, f"An index within your selection is out of bounds {np.max(selection)}" - selection_indices = selection contact_arr = self.to_numpy(complete=True) contact_arr = contact_arr[selection] From cc7320be0901bf056939187b47919a2795ec11ec Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Tue, 9 Jun 2026 16:10:40 +0200 Subject: [PATCH 6/8] Update src/probeinterface/probegroup.py Co-authored-by: Alessio Buccino --- src/probeinterface/probegroup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 119dcf5f..26f23bff 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -270,7 +270,7 @@ def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | Set global indices for all probes. Important note : if the order of contacts is not "natural" then the device_channel_indices - is applied is the real/reordered contacts vector. In short, the device_channel_indices is ziped to + is applied is the real/reordered contacts vector. In short, the device_channel_indices is zipped to ProbeGroup.to_numpy() (always ordered). Parameters From 8fb9402a1926d8d92f1f1c2cedcdf9fe8781232f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 26 Jun 2026 15:35:00 +0200 Subject: [PATCH 7/8] Enhance probegroup API (#1) * feat: enhance probegroup API * fix: update BIDS writer/reader to use built-in probe_ids * refac: remove auto_generate_probe_ids * fix: select_contacts should maintain requested contact order * fix: to/from numpy has probe_id, select_contacts behavior, add select_probes * Update src/probeinterface/io.py * tests: extend tests and fix docstring * fix: suggestions from Sam's review * other fixes * fix: ramon's suggestions * test: add test for probe id naming --- examples/ex_03_generate_probe_group.py | 61 ++++ src/probeinterface/__init__.py | 6 +- src/probeinterface/io.py | 27 +- src/probeinterface/probe.py | 8 +- src/probeinterface/probegroup.py | 240 +++++++++++---- src/probeinterface/wiring.py | 19 ++ tests/test_probegroup.py | 406 +++++++++++++++++++++++-- 7 files changed, 669 insertions(+), 98 deletions(-) diff --git a/examples/ex_03_generate_probe_group.py b/examples/ex_03_generate_probe_group.py index 8a640d3a..5986f13f 100644 --- a/examples/ex_03_generate_probe_group.py +++ b/examples/ex_03_generate_probe_group.py @@ -46,4 +46,65 @@ plot_probegroup(probegroup, same_axes=False, with_contact_id=True) +############################################################################## +# Identifying probes with a ``probe_id`` +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Each probe in a `ProbeGroup` can be given a human-readable ``probe_id`` when +# it is added. This is handy to keep track of which probe targets which brain +# area or hemisphere. If no ``probe_id`` is given, a default one +# (``"probe_1"``, ``"probe_2"``, ...) is generated automatically. + +probe0 = generate_dummy_probe(elec_shapes='square') +probe1 = generate_dummy_probe(elec_shapes='circle') +probe1.move([250, -90]) + +probegroup = ProbeGroup() +probegroup.add_probe(probe0, probe_id="left_hemisphere") +probegroup.add_probe(probe1, probe_id="right_hemisphere") + +print(probegroup) +print("probe_ids:", probegroup.probe_ids) + +############################################################################## +# `ProbeGroup.select_probes()` returns a new `ProbeGroup` with a sub-selection +# of probes given by probe_ids. + +left_hemisphere_probe = probegroup.select_probes(probe_ids=["left_hemisphere"]) +print(left_hemisphere_probe) + +############################################################################## +# We can also select by specific contacts from a probegroup with the +# ``select_contacts`` function. Note that if ``contact_ids`` are not +# unique across probes, you need to disambiguate the selection by specifying the +# probe_ids as well. Otherwise, a ValueError is raised. + +# check if any contact_id is not unique across probes +contact_ids = probegroup.get_global_contact_ids() +if len(contact_ids) != len(set(contact_ids)): + print("contact_ids are not unique across probes, you should provide probe_ids to disambiguate") + +############################################################################## +# Because the contact ids are not unique across probes, combining ``contact_ids`` +# with ``probe_ids`` lets us pull specific contacts from a single hemisphere: + +left_probegroup = probegroup.select_contacts( + contact_ids=["0", "1", "2"], + probe_ids=["left_hemisphere", "left_hemisphere", "left_hemisphere"] +) +print(left_probegroup) + +# Now select contacts from both hemispheres by providing the corresponding probe_ids for each contact_id: +left_and_right_probegroup = probegroup.select_contacts( + contact_ids=["0", "1", "2"], + probe_ids=["left_hemisphere", "right_hemisphere", "left_hemisphere"] +) +print(left_and_right_probegroup) + +# Without providing probe_ids, the selection is ambiguous and an error is raised: +try: + ambiguous_selection = probegroup.select_contacts(contact_ids=["0", "1", "2"]) +except ValueError as e: + print("Error raised for ambiguous selection:", e) + plt.show() diff --git a/src/probeinterface/__init__.py b/src/probeinterface/__init__.py index 45e102bf..ff52e8a3 100644 --- a/src/probeinterface/__init__.py +++ b/src/probeinterface/__init__.py @@ -52,4 +52,8 @@ cache_full_library, clear_cache, ) -from .wiring import get_available_pathways +from .wiring import ( + get_available_pathways, + get_pathway, + wire_probe +) diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index 90849e81..a6cccc5c 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -203,10 +203,9 @@ def read_BIDS_probe(folder: str | Path, prefix: str | None = None) -> ProbeGroup # create probe object and register with probegroup probe = Probe.from_dataframe(df=df_probe) - probe.annotate(probe_id=probe_id) probes[str(probe_id)] = probe - probegroup.add_probe(probe) + probegroup.add_probe(probe, probe_id=str(probe_id)) ignore_annotations = [ "probe_ids", @@ -326,7 +325,7 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup probegroup = probe_or_probegroup else: raise TypeError( - f"probe_or_probegroup has to be" "of type Probe or ProbeGroup " f"not type: {type(probe_or_probegroup)}" + f"probe_or_probegroup has to be of type Probe or ProbeGroup not type: {type(probe_or_probegroup)}" ) folder = Path(folder) @@ -337,22 +336,12 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup probes = probegroup.probes # Step 1: GENERATION OF PROBE.TSV - # ensure required keys (probe_id, probe_type) are present - - if any("probe_id" not in p.annotations for p in probes): - probegroup.auto_generate_probe_ids() + # ensure required keys (probe_type) are present for probe in probes: - if "probe_id" not in probe.annotations: - raise ValueError( - "Export to BIDS probe format requires " - "the probe id to be specified as an annotation " - "(probe_id). You can do this via " - "`probegroup.auto_generate_ids." - ) if "type" not in probe.annotations: raise ValueError( - "Export to BIDS probe format requires " "the probe type to be specified as an " "annotation (type)" + "Export to BIDS probe format requires the probe type to be specified as an annotation (type)" ) # extract all used annotation keys @@ -361,11 +350,12 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup annotation_keys = np.unique(keys_concatenated) # generate a tsv table capturing probe information - index = range(len([p.annotations["probe_id"] for p in probes])) + index = range(len(probes)) df = pd.DataFrame(index=index) for annotation_key in annotation_keys: df[annotation_key] = [p.annotations[annotation_key] for p in probes] df["n_shanks"] = [len(np.unique(p.shank_ids)) for p in probes] + df["probe_id"] = probegroup.probe_ids # Note: in principle it would also be possible to add the probe width and # depth here based on the probe contour information. However this would @@ -378,8 +368,7 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup # Step 2: GENERATION OF PROBE.JSON probes_dict = {} - for probe in probes: - probe_id = probe.annotations["probe_id"] + for probe_id, probe in zip(probegroup.probe_ids, probes): probes_dict[probe_id] = { "contour": probe.probe_planar_contour.tolist(), "units": probe.si_units, @@ -403,7 +392,7 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup index = range(sum([p.get_contact_count() for p in probes])) df.rename(columns=tsv_label_map_to_BIDS, inplace=True) - df["probe_id"] = [p.annotations["probe_id"] for p in probes for _ in p.contact_ids] + df["probe_id"] = [probe_id for probe_id, probe in zip(probegroup.probe_ids, probes) for _ in probe.contact_ids] df["coordinate_system"] = ["relative cartesian"] * len(index) channel_indices = [] diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index 26284289..0fb695b1 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -534,7 +534,7 @@ def set_device_channel_indices(self, channel_indices: np.ndarray | list): ) self.device_channel_indices = channel_indices if self._probe_group is not None: - self._probe_group.check_global_device_wiring_and_ids() + self._probe_group._check_global_device_wiring_and_ids() def wiring_to_device(self, pathway: str, channel_offset: int = 0): """ @@ -584,7 +584,7 @@ def set_contact_ids(self, contact_ids: np.ndarray | list): self._contact_ids = contact_ids if self._probe_group is not None: - self._probe_group.check_global_device_wiring_and_ids() + self._probe_group._check_global_device_wiring_and_ids() def set_shank_ids(self, shank_ids: np.ndarray | list): """ @@ -1140,8 +1140,10 @@ def from_numpy(arr: np.ndarray) -> "Probe": "plane_axis_y_1", "plane_axis_z_0", "plane_axis_z_1", - "probe_index", "si_units", + # these two are for ProbeGroup to avoid duplication of fields + "probe_index", + "probe_id", ] contact_annotation_fields = [f for f in fields if f not in main_fields] diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 26f23bff..6214599f 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from .utils import generate_unique_ids from .probe import Probe @@ -7,23 +9,30 @@ class ProbeGroup: """ Class to handle a group of Probe objects and the global wiring to a device. - Internally, this is represented as a list of Probe object. - - The ProbeGroup is the object saved in the json based probeinterface format, even if there only one probe. - - Tiny detail: when using `PropbeGroup.to_numpy()` / `PropbeGroup.to_dataframe()` by default the contact order - is the "natural" one (stacked order of each probe). But optionally, this order can be more complex, for instance - some contact of each probe are interleaved, in this case a optional reordering can be applied. - + Internally, this is represented as a list of Probe objects. + The ProbeGroup is the object saved in the json based probeinterface format, even if there is only one probe. + Tiny detail about contact order: ``ProbeGroup.to_numpy()`` / ``ProbeGroup.to_dataframe()`` return contacts in the + "natural" order (the contacts of each probe stacked one probe after another) unless contacts have become + interleaved across probes. Interleaving can arise from ``get_slice`` or ``select_contacts`` (e.g. selecting + contacts from different probes in an alternating order). When it does, the resulting ``ProbeGroup`` keeps a custom + contact order in the ``_global_contact_order`` attribute so the requested order is preserved. This order is only + ever set internally; there is no public method to set it. """ def __init__(self): - self.probes = [] + self._probes = [] + self._probe_ids = [] self._global_contact_order = None - def add_probe(self, probe: Probe) -> None: + def __repr__(self): + repr_str = f"ProbeGroup: {len(self._probes)} probes - {self.get_contact_count()} contacts" + if self._global_contact_order is not None: + repr_str += " (with custom global contact order)" + return repr_str + + def add_probe(self, probe: Probe, probe_id: str = None) -> None: """ Add an additional probe to the ProbeGroup @@ -31,14 +40,49 @@ def add_probe(self, probe: Probe) -> None: ---------- probe: Probe The probe to add to the ProbeGroup + probe_id: str, optional + The ID to assign to the probe. If None, a unique ID will be generated, + unless a probe_id is already present in the probe's annotations, + in which case that will be used. """ - if len(self.probes) > 0: + if len(self._probes) > 0: self._check_compatible(probe) - self.probes.append(probe) + probe_id_annotation = probe.annotations.get("probe_id", None) + + if probe_id is None: + if probe_id_annotation is not None: + probe_id = probe_id_annotation + else: + existing_int_ids = [int(pid) for pid in self._probe_ids if pid.isdigit()] + probe_id = str(max(existing_int_ids, default=-1) + 1) + else: + if probe_id_annotation is not None and probe_id != probe_id_annotation: + warnings.warn( + f"Provided probe_id '{probe_id}' does not match probe's annotation 'probe_id' " + f"({probe_id_annotation}). Using provided probe_id." + ) + + if probe_id in self._probe_ids: + raise ValueError(f"Probe ID '{probe_id}' is already used in this ProbeGroup.") + self._probe_ids.append(probe_id) + + self._probes.append(probe) probe._probe_group = self + @property + def probe_dict(self) -> dict: + return {probe_id: probe for probe_id, probe in zip(self._probe_ids, self._probes)} + + @property + def probes(self) -> list: + return self._probes + + @property + def probe_ids(self) -> list: + return self._probe_ids + def _check_compatible(self, probe: Probe) -> None: if probe._probe_group is not None: raise ValueError( @@ -51,9 +95,11 @@ def _check_compatible(self, probe: Probe) -> None: ) # check global channel maps - self.probes.append(probe) - self.check_global_device_wiring_and_ids() - self.probes = self.probes[:-1] + self._probes.append(probe) + self._probe_ids.append(f"{len(self._probes)-1}") + self._check_global_device_wiring_and_ids() + self._probes = self.probes[:-1] + self._probe_ids = self.probe_ids[:-1] @property def ndim(self) -> int: @@ -68,12 +114,7 @@ def copy(self) -> "ProbeGroup": copy: ProbeGroup A copy of the ProbeGroup """ - copy = ProbeGroup() - for probe in self.probes: - copy.add_probe(probe.copy()) - global_device_channel_indices = self.get_global_device_channel_indices()["device_channel_indices"] - copy.set_global_device_channel_indices(global_device_channel_indices) - return copy + return ProbeGroup.from_dict(self.to_dict(array_as_list=False)) def get_contact_count(self) -> int: """ @@ -102,7 +143,7 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: probe_arr = [] # loop over probes to get all fields - dtype = [("probe_index", "int64")] + dtype = [("probe_index", "int64"), ("probe_id", "U100")] fields = [] for probe_index, probe in enumerate(self.probes): arr = probe.to_numpy(complete=complete) @@ -117,6 +158,7 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: arr = probe_arr[probe_index] arr_ext = np.zeros(probe.get_contact_count(), dtype=dtype) arr_ext["probe_index"] = probe_index + arr_ext["probe_id"] = self._probe_ids[probe_index] for k in fields: if k in arr.dtype.fields: arr_ext[k] = arr[k] @@ -154,12 +196,13 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": if is_interleaved: global_contact_order = [] - probes_indices = np.unique(arr["probe_index"]) + probes_indices = np.sort(np.unique(arr["probe_index"])) probegroup = ProbeGroup() for probe_index in probes_indices: mask = arr["probe_index"] == probe_index + probe_id = arr["probe_id"][mask][0] probe = Probe.from_numpy(arr[mask]) - probegroup.add_probe(probe) + probegroup.add_probe(probe, probe_id=probe_id) if is_interleaved: global_contact_order.append(np.flatnonzero(mask)) @@ -207,13 +250,14 @@ def to_dict(self, array_as_list: bool = False) -> dict: """ d = {} d["probes"] = [] - for probe_ind, probe in enumerate(self.probes): + for probe in self.probes: probe_dict = probe.to_dict(array_as_list=array_as_list) d["probes"].append(probe_dict) + d["probe_ids"] = self.probe_ids if self._global_contact_order is not None: global_contact_order = self._global_contact_order if array_as_list: - global_contact_order = global_contact_order.to_list() + global_contact_order = global_contact_order.tolist() d["global_contact_order"] = global_contact_order return d @@ -232,9 +276,12 @@ def from_dict(d: dict) -> "ProbeGroup": The instantiated ProbeGroup object """ probegroup = ProbeGroup() - for probe_dict in d["probes"]: + probe_ids = d.get("probe_ids", None) + if probe_ids is None: + probe_ids = [str(i) for i in range(len(d["probes"]))] + for probe_id, probe_dict in zip(probe_ids, d["probes"]): probe = Probe.from_dict(probe_dict) - probegroup.add_probe(probe) + probegroup.add_probe(probe, probe_id=probe_id) global_contact_order = d.get("global_contact_order", None) if global_contact_order is not None: @@ -242,6 +289,7 @@ def from_dict(d: dict) -> "ProbeGroup": return probegroup + # TODO: this should only return the device_channel_indices, not the probe_index!!! def get_global_device_channel_indices(self) -> np.ndarray: """ Gets the global device channels indices and returns as @@ -267,15 +315,15 @@ def get_global_device_channel_indices(self) -> np.ndarray: def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | list) -> None: """ - Set global indices for all probes. + Set global device channel indices for all probes. - Important note : if the order of contacts is not "natural" then the device_channel_indices - is applied is the real/reordered contacts vector. In short, the device_channel_indices is zipped to + Important note: if the probegroup has ``_global_contact_order``, then the device_channel_indices + are reordered before being set. In short, the ``device_channel_indices`` is zipped to ProbeGroup.to_numpy() (always ordered). Parameters ---------- - channels: np.ndarray | list + device_channel_indices: np.ndarray | list The device channal indices to be set """ device_channel_indices = np.asarray(device_channel_indices) @@ -308,7 +356,7 @@ def get_global_contact_ids(self) -> np.ndarray: Returns ------- contact_ids: np.ndarray - An array of the contaact ids across all probes + An array of the contact ids across all probes """ contact_ids = self.to_numpy(complete=True)["contact_ids"] return contact_ids @@ -370,42 +418,118 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": contact_arr = contact_arr[selection] sliced_probe_group = ProbeGroup.from_numpy(contact_arr) - # TODO annoatation probe per probe!! + # Map annotations of the original probegroup to the sliced one + for probe_id, new_probe in zip(sliced_probe_group.probe_ids, sliced_probe_group.probes): + original_probe_index = self.probe_ids.index(probe_id) + orig_probe = self.probes[original_probe_index] - return sliced_probe_group + for k in orig_probe.annotations: + if k not in new_probe.annotations: + new_probe.annotate(**{k: orig_probe.annotations[k]}) - def check_global_device_wiring_and_ids(self) -> None: - # check unique device_channel_indices for !=-1 - chans = self.get_global_device_channel_indices() - keep = chans["device_channel_indices"] >= 0 - valid_chans = chans[keep]["device_channel_indices"] - - if valid_chans.size != np.unique(valid_chans).size: - raise ValueError("channel device indices are not unique across probes") + return sliced_probe_group - def auto_generate_probe_ids(self, *args, **kwargs) -> None: + def select_probes(self, probe_ids: str | np.ndarray | list) -> "ProbeGroup": """ - Annotate all probes with unique probe_id values. + Get a copy of the ProbeGroup with a sub selection of probes based on probe ids. Parameters ---------- - *args: will be forwarded to `probeinterface.utils.generate_unique_ids` - **kwargs: will be forwarded to - `probeinterface.utils.generate_unique_ids` + probe_ids : str | np.array or list + The probe id or ids to select. + + Returns + ------- + sliced_probe_group: ProbeGroup + The sliced probe group """ + if probe_ids is None: + raise ValueError("probe_ids must be provided for selection.") - if any("probe_id" in p.annotations for p in self.probes): - raise ValueError("Probe already has a `probe_id` annotation.") + if isinstance(probe_ids, str): + probe_ids = [probe_ids] - if not args: - args = 1e7, 1e8 - # 3rd argument has to be the number of probes - args = args[:2] + (len(self.probes),) + probe_ids = np.asarray(probe_ids) + if any(probe_id not in self.probe_ids for probe_id in probe_ids): + raise ValueError(f"Some probe_ids {probe_ids} are not present in the ProbeGroup.") + + # selection keeps the order of the to_numpy vector + all_probe_ids = self.to_numpy(complete=True)["probe_id"] + keep_inds = np.flatnonzero(np.isin(all_probe_ids, probe_ids)) + return self.get_slice(keep_inds) + + def select_contacts( + self, contact_ids: np.ndarray | list, probe_ids: np.ndarray | list | None = None + ) -> "ProbeGroup": + """ + Get a copy of the ProbeGroup with a sub selection of contacts based on contact ids and probe ids. + + Parameters + ---------- + contact_ids : np.array or list + The contact ids to select. + probe_ids : np.array or list or None, default: None + If multiple probes and contact ids not unique across probes, an array with the same length + as contact ids to specify which probe each contact id belongs to. - # creating unique probe ids in case probes do not have any yet - probe_ids = generate_unique_ids(*args, **kwargs).astype(str) - for pid, probe in enumerate(self.probes): - probe.annotate(probe_id=probe_ids[pid]) + Returns + ------- + sliced_probe_group: ProbeGroup + The sliced probe group + """ + # both arrays are in the global contact order + arr = self.to_numpy(complete=True) + all_contact_ids = arr["contact_ids"] + all_probe_ids = arr["probe_id"] + + contact_ids = np.asarray(contact_ids) + + if probe_ids is None: + # without probe_ids the request must be unambiguous: each requested contact + # id must appear once in the request and match a single contact in the group + unique_requested, counts = np.unique(contact_ids, return_counts=True) + duplicated = unique_requested[counts > 1] + if duplicated.size > 0: + raise ValueError( + f"contact_ids must be unique, but {duplicated.tolist()} appear more than once. " + "If the same contact id is on multiple probes, use probe_ids to disambiguate." + ) + probe_ids = [None] * len(contact_ids) + else: + if len(probe_ids) != len(contact_ids): + raise ValueError( + f"probe_ids must be the same length as contact_ids, but got {len(probe_ids)} probe_ids and {len(contact_ids)} contact_ids." + ) + + indices = [] + for contact_id, probe_id in zip(contact_ids, probe_ids): + # find the contact id within the specified probe + in_probe_mask = np.ones(all_contact_ids.size, dtype=bool) if probe_id is None else all_probe_ids == probe_id + matches = np.flatnonzero((all_contact_ids == contact_id) & in_probe_mask) + + if matches.size == 0: + raise ValueError(f"contact_id {contact_id} not found in probe {probe_id}") + elif matches.size > 1: + raise ValueError( + f"contact_id {contact_id} is not unique within probe {probe_id}, " + "this should not happen unless the probe has duplicate contact ids" + ) + if matches[0] in indices: + raise ValueError( + f"contact_id {contact_id} matches multiple probes; " + "pass probe_ids to disambiguate which probe each contact_id belongs to." + ) + indices.append(matches[0]) + return self.get_slice(indices) + + def _check_global_device_wiring_and_ids(self) -> None: + # check unique device_channel_indices for !=-1 + chans = self.get_global_device_channel_indices() + keep = chans["device_channel_indices"] >= 0 + valid_chans = chans[keep]["device_channel_indices"] + + if valid_chans.size != np.unique(valid_chans).size: + raise ValueError("channel device indices are not unique across probes") def auto_generate_contact_ids(self, *args, **kwargs) -> None: """ diff --git a/src/probeinterface/wiring.py b/src/probeinterface/wiring.py index 8378ad7b..6f8cb900 100644 --- a/src/probeinterface/wiring.py +++ b/src/probeinterface/wiring.py @@ -82,6 +82,25 @@ def get_available_pathways() -> list: return list(pathways.keys()) +def get_pathway(pathway: str) -> np.ndarray: + """Return the channel indices for a given pathway + + Parameters + ---------- + pathway : str + The pathway to use + + Returns + ------- + chan_indices : np.ndarray + The channel indices for the given pathway + """ + assert pathway in pathways, ( + f"{pathway} is not a currently supported pathway " f"run `get_available_pathways to see options" + ) + return np.array(pathways[pathway], dtype="int64") + + def wire_probe(probe: "Probe", pathway: str, channel_offset: int = 0): """Inplace wiring for a Probe using a pathway diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 089c642a..deeb9e9e 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -15,7 +15,7 @@ def _make_probegroup(): probe.move([i * 100, i * 80]) n = probe.get_contact_count() probe.set_device_channel_indices(np.arange(n) + nchan) - probegroup.add_probe(probe) + probegroup.add_probe(probe, probe_id=f"probe_00{i}") nchan += n return probegroup @@ -39,17 +39,7 @@ def test_probegroup(probegroup): d = probegroup.to_dict() other = ProbeGroup.from_dict(d) - - # checking automatic generation of ids with new dummy probes - probegroup.probes = [] - for i in range(3): - probegroup.add_probe(generate_dummy_probe()) - probegroup.auto_generate_contact_ids() - probegroup.auto_generate_probe_ids() - - for p in probegroup.probes: - assert p.contact_ids is not None - assert "probe_id" in p.annotations + assert probegroup.probe_ids == other.probe_ids def test_probegroup_3d(): @@ -96,7 +86,7 @@ def test_set_contact_ids_rejects_within_probe_duplicates(): probe = Probe(ndim=2, si_units="um") probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) - with pytest.raises(ValueError, match="unique within a Probe"): + with pytest.raises(ValueError): probe.set_contact_ids(["a", "a"]) @@ -108,7 +98,7 @@ def test_set_contact_ids_rejects_wrong_size(): probe = Probe(ndim=2, si_units="um") probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) - with pytest.raises(ValueError, match="do not have the same size"): + with pytest.raises(ValueError): probe.set_contact_ids(["a", "b", "c"]) @@ -264,10 +254,95 @@ def test_get_slice_all_contacts(probegroup): ) +# ── get_slice : probe annotations and probe_ids propagation ───────────────── + + +def _annotated_probegroup(): + """ProbeGroup with 3 probes, each carrying distinct annotations and probe_id.""" + pg = ProbeGroup() + for i in range(3): + probe = generate_dummy_probe() + probe.move([i * 200, 0]) + probe.annotate(brain_area=f"area_{i}", shank=f"s{i}") + pg.add_probe(probe, probe_id=f"probe_{i}") + return pg + + +def test_get_slice_propagates_annotations(): + """Annotations of each original probe are propagated to the sliced probe.""" + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + # take a few contacts from each of the 3 probes + sel = np.array([0, 1, n_each, n_each + 1, 2 * n_each, 2 * n_each + 1]) + sub = pg.get_slice(sel) + + assert len(sub.probes) == 3 + for i, probe in enumerate(sub.probes): + assert probe.annotations["brain_area"] == f"area_{i}" + assert probe.annotations["shank"] == f"s{i}" + + +def test_get_slice_maps_annotations_to_correct_probe_when_skipping(): + """ + When the selection skips a middle probe, annotations must still map to the + correct sliced probe (not shift by position). + """ + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + # contacts only from probe 0 and probe 2 (probe 1 is skipped entirely) + sel = np.zeros(pg.get_contact_count(), dtype=bool) + sel[0:3] = True + sel[2 * n_each : 2 * n_each + 4] = True + sub = pg.get_slice(sel) + + assert len(sub.probes) == 2 + # first sliced probe corresponds to original probe 0, second to original probe 2 + assert sub.probes[0].annotations["brain_area"] == "area_0" + assert sub.probes[1].annotations["brain_area"] == "area_2" + assert sub.probes[0].get_contact_count() == 3 + assert sub.probes[1].get_contact_count() == 4 + + +def test_get_slice_sets_probe_ids(): + """probe_ids are carried over to the sliced ProbeGroup.""" + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + sel = np.array([0, 1, n_each, 2 * n_each]) + sub = pg.get_slice(sel) + assert sub.probe_ids == ["probe_0", "probe_1", "probe_2"] + + +def test_get_slice_sets_probe_ids_when_skipping(): + """probe_ids reflect only the probes present in the selection, in order.""" + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + # contacts only from probe 0 and probe 2 + sel = np.array([0, 2 * n_each]) + sub = pg.get_slice(sel) + assert len(sub.probes) == 2 + assert sub.probe_ids == ["probe_0", "probe_2"] + + +def test_get_slice_single_probe_keeps_probe_id_and_annotations(): + """Slicing contacts from a single probe keeps that probe's id and annotations.""" + pg = _annotated_probegroup() + n_each = pg.probes[0].get_contact_count() + + sel = np.arange(n_each, n_each + 3) # only probe 1 + sub = pg.get_slice(sel) + assert len(sub.probes) == 1 + assert sub.probe_ids == ["probe_1"] + assert sub.probes[0].annotations["brain_area"] == "area_1" + + # ── global_contact_order : to_numpy/from_numpy, to_dict/from_dict, get_slice -def test_reordred_probegroup(probegroup): +def test_reordered_probegroup(probegroup): order = np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) contact_vector = probegroup.to_numpy(complete=True) @@ -290,7 +365,7 @@ def test_reordred_probegroup(probegroup): probegroup5 = ProbeGroup.from_dict(probegroup4.to_dict()) assert probegroup5._global_contact_order is not None - contact_vector5 = probegroup3.to_numpy(complete=True) + contact_vector5 = probegroup5.to_numpy(complete=True) assert np.array_equal(contact_vector4, contact_vector5) # let go back to original order @@ -299,9 +374,306 @@ def test_reordred_probegroup(probegroup): assert probegroup6._global_contact_order is None +def _interleaved_order(): + """An order interleaving contacts across probes (non-natural).""" + return np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) + + +def test_global_contact_order_natural_is_none(probegroup): + """A non-interleaved (natural) contact vector does not set a custom order.""" + pg = ProbeGroup.from_numpy(probegroup.to_numpy(complete=True)) + assert pg._global_contact_order is None + + +def test_global_contact_order_positions_reflect_order(probegroup): + """get_global_contact_positions follows the custom global contact order.""" + order = _interleaved_order() + natural_positions = probegroup.get_global_contact_positions().copy() + + pg = ProbeGroup.from_numpy(probegroup.to_numpy(complete=True)[order]) + assert pg._global_contact_order is not None + np.testing.assert_array_equal(pg.get_global_contact_positions(), natural_positions[order]) + + +def test_global_contact_order_ids_reflect_order(probegroup): + """get_global_contact_ids follows the custom global contact order.""" + order = _interleaved_order() + natural_ids = probegroup.get_global_contact_ids().copy() + + pg = ProbeGroup.from_numpy(probegroup.to_numpy(complete=True)[order]) + np.testing.assert_array_equal(pg.get_global_contact_ids(), natural_ids[order]) + + +def test_global_contact_order_device_channel_indices_roundtrip(probegroup): + """ + With a custom global contact order, device_channel_indices are zipped to the + (reordered) to_numpy() vector. Setting them must roundtrip through both + to_numpy() and get_global_device_channel_indices(). + """ + order = _interleaved_order() + pg = ProbeGroup.from_numpy(probegroup.to_numpy(complete=True)[order]) + assert pg._global_contact_order is not None + + n = pg.get_contact_count() + device_channel_indices = np.arange(n) + pg.set_global_device_channel_indices(device_channel_indices) + + got = pg.to_numpy(complete=True)["device_channel_indices"] + np.testing.assert_array_equal(got, device_channel_indices) + + got_getter = pg.get_global_device_channel_indices()["device_channel_indices"] + np.testing.assert_array_equal(got_getter, device_channel_indices) + + +# ── select_contacts() tests ───────────────────────────────────────────────── + + +def _probegroup_with_contact_ids(unique=True): + """ProbeGroup with 3 probes whose contact_ids are unique (or duplicated) across probes.""" + pg = ProbeGroup() + for i in range(3): + probe = generate_dummy_probe() + probe.move([i * 100, i * 80]) + n = probe.get_contact_count() + if unique: + probe.set_contact_ids([f"p{i}c{j}" for j in range(n)]) + else: + probe.set_contact_ids([f"c{j}" for j in range(n)]) + pg.add_probe(probe) + return pg + + +def test_select_contacts_unique_ids(): + """Selecting by globally unique contact ids returns exactly those contacts.""" + pg = _probegroup_with_contact_ids(unique=True) + selected_ids = ["p0c0", "p0c1", "p2c5"] + sub = pg.select_contacts(selected_ids) + + assert sub.get_contact_count() == 3 + # contacts come from two distinct probes + assert len(sub.probes) == 2 + assert set(sub.get_global_contact_ids()) == set(selected_ids) + + +def test_select_contacts_single_probe(): + """Selecting contacts from a single probe keeps a single probe.""" + pg = _probegroup_with_contact_ids(unique=True) + sub = pg.select_contacts(["p1c0", "p1c1", "p1c2"]) + assert sub.get_contact_count() == 3 + assert len(sub.probes) == 1 + + +def test_select_contacts_ambiguous_ids_without_probe_ids_raises(): + """ + Without probe_ids, a contact id that exists on more than one probe is + ambiguous and raises a ValueError naming the offending id(s). + """ + pg = _probegroup_with_contact_ids(unique=False) + with pytest.raises(ValueError, match="c0"): + pg.select_contacts(["c0"]) + + +def test_select_contacts_with_probe_ids(): + """probe_ids (paired with contact_ids) disambiguate duplicated contact ids.""" + pg = _probegroup_with_contact_ids(unique=False) + sub = pg.select_contacts(["c0", "c1"], probe_ids=["1", "1"]) + assert sub.get_contact_count() == 2 + assert len(sub.probes) == 1 + np.testing.assert_array_equal(sorted(sub.get_global_contact_ids()), ["c0", "c1"]) + + +def test_select_contacts_same_id_across_probes_with_probe_ids(): + """The same contact id can be selected from several probes using probe_ids.""" + pg = _probegroup_with_contact_ids(unique=False) + sub = pg.select_contacts(["c0", "c0"], probe_ids=["0", "2"]) + assert sub.get_contact_count() == 2 + assert len(sub.probes) == 2 + + +def test_select_contacts_probe_ids_length_mismatch_raises(): + """probe_ids must have the same length as contact_ids.""" + pg = _probegroup_with_contact_ids(unique=False) + with pytest.raises(ValueError): + pg.select_contacts(["c0", "c1"], probe_ids=["0"]) + + +def test_select_contacts_too_many_ids_without_probe_ids_raises(): + """ + Requesting more contact ids than the number of unique ids without probe_ids + raises a ValueError. + """ + pg = _probegroup_with_contact_ids(unique=False) + n_unique = len(np.unique(pg.get_global_contact_ids())) + too_many = [f"c{j}" for j in range(n_unique + 1)] + with pytest.raises(ValueError): + pg.select_contacts(too_many) + + +def test_select_contacts_follows_requested_order(): + """The selection follows the order of the provided contact_ids, even across probes.""" + pg = _probegroup_with_contact_ids(unique=True) + # interleave contacts from different probes in a non-natural order + selected_ids = ["p2c5", "p0c1", "p1c0", "p0c0"] + sub = pg.select_contacts(selected_ids) + + np.testing.assert_array_equal(sub.get_global_contact_ids(), selected_ids) + + # positions must follow the same order as the requested ids + all_ids = pg.get_global_contact_ids() + all_positions = pg.get_global_contact_positions() + expected = np.vstack([all_positions[all_ids == cid] for cid in selected_ids]) + np.testing.assert_array_equal(sub.get_global_contact_positions(), expected) + + +def test_select_probes_keeps_every_contact_of_matching_probes(): + """select_probes keeps every contact of the matching probes.""" + pg = _probegroup_with_contact_ids(unique=False) + n_per_probe = pg.probes[0].get_contact_count() + + sub_str = pg.select_probes("1") + assert sub_str.get_contact_count() == n_per_probe + assert len(sub_str.probes) == 1 + + sub_one = pg.select_probes(["1"]) + assert sub_one.get_contact_count() == n_per_probe + assert len(sub_one.probes) == 1 + + sub_two = pg.select_probes(["1", "2"]) + assert sub_two.get_contact_count() == 2 * n_per_probe + assert len(sub_two.probes) == 2 + + +def test_select_probes_keeps_array_order(): + """select_probes preserves the contact order.""" + pg = _probegroup_with_contact_ids(unique=False) + sub = pg.select_probes(["2", "0"]) + # even if we requested probes in a different order, the contacts are still ordered by their original global order + probe_index_per_contact = sub.to_numpy(complete=True)["probe_id"] + assert probe_index_per_contact[0] == "0" + + +def test_select_probes_single_probe(): + """Selecting a single probe keeps a single probe with its contact ids.""" + pg = _probegroup_with_contact_ids(unique=True) + sub = pg.select_probes(["1"]) + assert len(sub.probes) == 1 + assert sub.probe_ids == ["1"] + assert all(cid.startswith("p1") for cid in sub.get_global_contact_ids()) + + +def test_select_probes_preserves_probe_ids(): + """The selected ProbeGroup keeps the requested probe ids.""" + pg = _probegroup_with_contact_ids(unique=False) + sub = pg.select_probes(["2", "0"]) + assert set(sub.probe_ids) == {"0", "2"} + + +def test_select_probes_preserves_positions(): + """Contacts of the selected probes keep their global positions.""" + pg = _probegroup_with_contact_ids(unique=True) + + all_ids = pg.get_global_contact_ids() + all_positions = pg.get_global_contact_positions() + + sub = pg.select_probes(["0", "2"]) + sub_ids = sub.get_global_contact_ids() + sub_positions = sub.get_global_contact_positions() + for cid, pos in zip(sub_ids, sub_positions): + np.testing.assert_array_equal(pos, all_positions[all_ids == cid][0]) + + +def test_select_probes_none_raises(): + """Calling select_probes without probe_ids raises a ValueError.""" + pg = _probegroup_with_contact_ids(unique=False) + with pytest.raises(ValueError): + pg.select_probes(None) + + +def test_select_probes_all_probes(): + """Selecting all probes returns the whole ProbeGroup.""" + pg = _probegroup_with_contact_ids(unique=True) + sub = pg.select_probes(["0", "1", "2"]) + assert sub.get_contact_count() == pg.get_contact_count() + assert len(sub.probes) == len(pg.probes) + + +def test_select_contacts_duplicated_ids_raises(): + """Passing the same contact id more than once raises a ValueError.""" + pg = _probegroup_with_contact_ids(unique=True) + with pytest.raises(ValueError): + pg.select_contacts(["p0c0", "p0c1", "p0c0"]) + + +def test_select_contacts_preserves_order_in_array(): + """Selected contacts keep the order specified in the input array.""" + pg = _probegroup_with_contact_ids(unique=True) + contact_ids_list = [ + ["p0c1", "p0c0", "p2c5"], + ["p2c5", "p0c0", "p0c1"], + ["p0c1", "p2c5", "p0c0",] + ] + for selected_ids in contact_ids_list: + sub = pg.select_contacts(selected_ids) + contact_vector = sub.to_numpy(complete=True) + sub_ids = contact_vector["contact_ids"] + assert list(sub_ids) == selected_ids + + +def test_select_contacts_preserves_positions(): + """Selected contacts keep their global positions.""" + pg = _probegroup_with_contact_ids(unique=True) + selected_ids = ["p0c0", "p0c1", "p2c5"] + + all_ids = pg.get_global_contact_ids() + all_positions = pg.get_global_contact_positions() + expected = np.vstack([all_positions[all_ids == cid] for cid in selected_ids]) + + sub = pg.select_contacts(selected_ids) + sub_ids = sub.get_global_contact_ids() + sub_positions = sub.get_global_contact_positions() + got = np.vstack([sub_positions[sub_ids == cid] for cid in selected_ids]) + + np.testing.assert_array_equal(got, expected) + + +# ── add_probe : default probe_id generation ───────────────────────────────── + + +def test_add_probe_default_id_does_not_recycle_after_gap(): + """ + The default probe_id must not collide with an existing id after a selection + leaves a gap in the numeric ids. Using ``len(self._probes)`` would point back + at an id that is still in use; ``max(numeric ids) + 1`` is gap-proof. + """ + pg = ProbeGroup() + for _ in range(3): + pg.add_probe(generate_dummy_probe()) + assert pg.probe_ids == ["0", "1", "2"] + + # drop the middle probe -> ids become ["0", "2"], len is 2 (would collide with "2") + sub = pg.select_probes(["0", "2"]) + assert sub.probe_ids == ["0", "2"] + + sub.add_probe(generate_dummy_probe()) + assert sub.probe_ids == ["0", "2", "3"] + + +def test_add_probe_default_id_with_non_numeric_ids(): + """ + With only non-numeric ids present, the generated id starts from "0" and can + never collide with a non-numeric name. + """ + pg = ProbeGroup() + pg.add_probe(generate_dummy_probe(), probe_id="left") + pg.add_probe(generate_dummy_probe(), probe_id="right") + + pg.add_probe(generate_dummy_probe()) + assert pg.probe_ids == ["left", "right", "0"] + + if __name__ == "__main__": probegroup = _make_probegroup() # test_probegroup(probegroup) # test_probegroup_3d() - test_reordred_probegroup(probegroup) + test_reordered_probegroup(probegroup) From b1aef5b8ea9c049a619fa1031496b8a6cf1bd814 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Jun 2026 13:35:57 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/probeinterface/__init__.py | 6 +----- tests/test_probegroup.py | 6 +++++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/probeinterface/__init__.py b/src/probeinterface/__init__.py index b638b0c3..b15e7d78 100644 --- a/src/probeinterface/__init__.py +++ b/src/probeinterface/__init__.py @@ -53,8 +53,4 @@ cache_full_library, clear_cache, ) -from .wiring import ( - get_available_pathways, - get_pathway, - wire_probe -) +from .wiring import get_available_pathways, get_pathway, wire_probe diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index deeb9e9e..b457f903 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -610,7 +610,11 @@ def test_select_contacts_preserves_order_in_array(): contact_ids_list = [ ["p0c1", "p0c0", "p2c5"], ["p2c5", "p0c0", "p0c1"], - ["p0c1", "p2c5", "p0c0",] + [ + "p0c1", + "p2c5", + "p0c0", + ], ] for selected_ids in contact_ids_list: sub = pg.select_contacts(selected_ids)