-
Notifications
You must be signed in to change notification settings - Fork 222
[ENH] New similarity search module #724
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
53 commits
Select commit
Hold shift + click to select a range
ba32f9d
base similarity search
TonyBagnall 1ad9eca
slow search example
TonyBagnall e798d68
Merge branch 'main' of https://github.com/aeon-toolkit/aeon into expe…
MatthewMiddlehurst f5fc6b7
[ENH] Similarity search base class and TopK search with naïve Euclide…
baraline b6491be
Merge branch 'main' into experimental/similarity_search
TonyBagnall 5ed662f
Merge branch 'experimental/similarity_search' of https://github.com/a…
TonyBagnall 015dc9f
Merge branch 'main' into experimental/similarity_search
TonyBagnall c258de7
format
TonyBagnall 1141032
add init
TonyBagnall ac04c74
Merge branch 'main' into experimental/similarity_search
TonyBagnall ae4b49a
call constructor
TonyBagnall 98f95db
Merge branch 'main' into experimental/similarity_search
TonyBagnall a64e29b
add similarity base to register
TonyBagnall 1a1858b
add similarity-search to tagging
TonyBagnall 33557bc
Bugfixes for constant case and input alteration during normalization
baraline ba70e87
Merge branch 'main' into experimental/similarity_search
TonyBagnall cf72421
typo
TonyBagnall fbda755
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] f440c3e
typo
TonyBagnall 2ea98a6
Merge branch 'main' into experimental/similarity_search
TonyBagnall 1f74035
Merge branch 'experimental/similarity_search' of https://github.com/a…
TonyBagnall 9fcd4d3
docstrings
TonyBagnall 55ebc86
docstrings
TonyBagnall bc75368
docstrings
TonyBagnall 7398eac
docstrings
TonyBagnall c7d927f
docstrings
TonyBagnall c5b4a33
Fixing typos
baraline 6923f5d
Merge branch 'experimental/similarity_search' of https://github.com/a…
baraline bc51503
Adding some docs, adding base class arguments to topk, more expressiv…
baraline 6bbe528
Change notation of query from Q to q
baraline 7b40df9
Merge branch 'main' into experimental/similarity_search
TonyBagnall 9fe08b5
Merge branch 'main' into experimental/similarity_search
TonyBagnall ed746d5
Adding example notebook and module img, updating docs and correcting …
baraline c11fd32
Merge branch 'experimental/similarity_search' of https://github.com/a…
baraline 810de65
Adding parameters for self matches, typos in example notebook
baraline 23f29cf
typo in import, replace Q with q
baraline 787fe10
switch test example for pipeline
TonyBagnall 174fff5
switch test example for pipeline
TonyBagnall 10f79f2
Merge branch 'main' of https://github.com/aeon-toolkit/aeon
TonyBagnall 2c66919
Add mask to distance profile, move exclusion zoneto base class, some …
baraline a63c21c
Merge branch 'main' into experimental/similarity_search
TonyBagnall b310735
Add distance profile and speedups notebooks, exclusion factor value c…
baraline e4d4b3e
Merge branch 'main' of https://github.com/aeon-toolkit/aeon
TonyBagnall bd75ab3
Merge branch 'main' of https://github.com/aeon-toolkit/aeon
TonyBagnall dc4d82c
Merge branch 'main' of https://github.com/aeon-toolkit/aeon
TonyBagnall 8d4a3bd
Merge branch 'main' into experimental/similarity_search
baraline e0c82fd
Fixing tests and docs that where not updated after previous changes
baraline a625f9f
Force float convertion of input to avoid issues with normalization of…
baraline 610ac00
Merge branch 'main' into experimental/similarity_search
TonyBagnall 60027c5
Merge branch 'main' of https://github.com/aeon-toolkit/aeon
TonyBagnall b0a38ba
Merge branch 'main' into experimental/similarity_search
TonyBagnall b6d322f
Adding dummy class and test, correcting some docstrings
baraline 6771ead
Fixes from Matthew review
baraline File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| """BaseSimilaritySearch.""" | ||
|
|
||
| __author__ = ["baraline"] | ||
| __all__ = ["BaseSimiliaritySearch", "TopKSimilaritySearch"] | ||
|
|
||
| from aeon.similarity_search.base import BaseSimiliaritySearch | ||
| from aeon.similarity_search.top_k_similarity import TopKSimilaritySearch |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| """Dummy similarity seach estimator.""" | ||
|
|
||
| __author__ = ["baraline"] | ||
| __all__ = ["DummySimilaritySearch"] | ||
|
|
||
|
|
||
| from aeon.similarity_search.base import BaseSimiliaritySearch | ||
|
|
||
|
|
||
| class DummySimilaritySearch(BaseSimiliaritySearch): | ||
| """ | ||
| DummySimilaritySearch for testing of the BaseSimiliaritySearch class. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| distance : str, default ="euclidean" | ||
| Name of the distance function to use. | ||
| normalize : bool, default = False | ||
| Whether the distance function should be z-normalized. | ||
| store_distance_profile : bool, default = =False. | ||
| Whether to store the computed distance profile in the attribute | ||
| "_distance_profile" after calling the predict method. | ||
|
|
||
| Attributes | ||
| ---------- | ||
| _X : array, shape (n_instances, n_channels, n_timestamps) | ||
| The input time series stored during the fit method. | ||
| distance_profile_function : function | ||
| The function used to compute the distance profile affected | ||
| during the fit method based on the distance and normalize | ||
| parameters. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> from aeon.similarity_search._dummy import DummySimilaritySearch | ||
| >>> from aeon.datasets import load_unit_test | ||
| >>> X_train, y_train = load_unit_test(split="train") | ||
| >>> X_test, y_test = load_unit_test(split="test") | ||
| >>> clf = DummySimilaritySearch() | ||
| >>> clf.fit(X_train, y_train) | ||
| DummySimilaritySearch(...) | ||
| >>> q = X_test[0, :, 5:15] | ||
| >>> y_pred = clf.predict(q) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, distance="euclidean", normalize=False, store_distance_profile=False | ||
| ): | ||
| super(DummySimilaritySearch, self).__init__( | ||
| distance=distance, | ||
| normalize=normalize, | ||
| store_distance_profile=store_distance_profile, | ||
| ) | ||
|
|
||
| def _fit(self, X, y): | ||
| """ | ||
| Private fit method, does nothing more than the base class. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X : array, shape (n_instances, n_channels, n_timestamps) | ||
| Input array to used as database for the similarity search | ||
| y : optional | ||
| Not used. | ||
|
|
||
| Returns | ||
| ------- | ||
| self | ||
|
|
||
| """ | ||
| return self | ||
|
|
||
| def _predict(self, q, mask): | ||
| """ | ||
| Private predict method for DummySimilaritySearch. | ||
|
|
||
| It compute the distance profiles and then returns the best match | ||
|
|
||
| Parameters | ||
| ---------- | ||
| q : array, shape (n_channels, q_length) | ||
| Input query used for similarity search. | ||
| mask : array, shape (n_instances, n_channels, n_timestamps - (q_length - 1)) | ||
| Boolean mask of the shape of the distance profile indicating for which part | ||
| of it the distance should be computed. | ||
|
|
||
| Returns | ||
| ------- | ||
| array | ||
| An array containing the index of the best match between q and _X. | ||
|
|
||
| """ | ||
| if self.normalize: | ||
| distance_profile = self.distance_profile_function( | ||
| self._X, | ||
| q, | ||
| mask, | ||
| self._X_means, | ||
| self._X_stds, | ||
| self._q_means, | ||
| self._q_stds, | ||
| ) | ||
| else: | ||
| distance_profile = self.distance_profile_function(self._X, q, mask) | ||
|
|
||
| if self.store_distance_profile: | ||
| self._distance_profile = distance_profile | ||
|
|
||
| # For now, deal with the multidimensional case as "dependent", so we sum. | ||
| search_size = distance_profile.shape[-1] | ||
| distance_profile = distance_profile.sum(axis=1) | ||
| _id_best = distance_profile.argmin(axis=None) | ||
|
|
||
| return [(_id_best // search_size, _id_best % search_size)] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,233 @@ | ||
| """Base class for similarity search.""" | ||
|
|
||
| __author__ = ["baraline"] | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from collections.abc import Iterable | ||
| from typing import final | ||
|
|
||
| import numpy as np | ||
|
|
||
| from aeon.base import BaseEstimator | ||
| from aeon.similarity_search.distance_profiles import ( | ||
| naive_euclidean_profile, | ||
| normalized_naive_euclidean_profile, | ||
| ) | ||
| from aeon.utils.numba.general import sliding_mean_std_one_series | ||
|
|
||
|
|
||
| class BaseSimiliaritySearch(BaseEstimator, ABC): | ||
| """ | ||
| BaseSimilaritySearch. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| distance : str, default ="euclidean" | ||
| Name of the distance function to use. | ||
| normalize : bool, default = False | ||
| Whether the distance function should be z-normalized. | ||
| store_distance_profile : bool, default = False. | ||
| Whether to store the computed distance profile in the attribute | ||
| "_distance_profile" after calling the predict method. | ||
|
|
||
| Attributes | ||
| ---------- | ||
| _X : array, shape (n_instances, n_channels, n_timestamps) | ||
| The input time series stored during the fit method. | ||
| distance_profile_function : function | ||
| The function used to compute the distance profile affected | ||
| during the fit method based on the distance and normalize | ||
| parameters. | ||
| """ | ||
|
|
||
| _tags = { | ||
| "capability:multivariate": True, | ||
| "capability:missing_values": False, | ||
| } | ||
|
|
||
| def __init__( | ||
| self, distance="euclidean", normalize=False, store_distance_profile=False | ||
| ): | ||
| self.distance = distance | ||
| self.normalize = normalize | ||
| self.store_distance_profile = store_distance_profile | ||
| super(BaseSimiliaritySearch, self).__init__() | ||
|
|
||
| def _get_distance_profile_function(self): | ||
| dist_profile = DISTANCE_PROFILE_DICT.get(self.distance) | ||
| if dist_profile is None: | ||
| raise ValueError( | ||
| f"Unknown or unsupported distance profile function {dist_profile}" | ||
| ) | ||
| return dist_profile[self.normalize] | ||
|
|
||
| def _store_mean_std_from_inputs(self, q_length): | ||
| n_instances, n_channels, X_length = self._X.shape | ||
| search_space_size = X_length - q_length + 1 | ||
|
|
||
| means = np.zeros((n_instances, n_channels, search_space_size)) | ||
| stds = np.zeros((n_instances, n_channels, search_space_size)) | ||
|
|
||
| for i in range(n_instances): | ||
| _mean, _std = sliding_mean_std_one_series(self._X[i], q_length, 1) | ||
| stds[i] = _std | ||
| means[i] = _mean | ||
|
|
||
| self._X_means = means | ||
| self._X_stds = stds | ||
|
|
||
| @final | ||
| def fit(self, X, y=None): | ||
| """ | ||
| Fit method: store the input data and get the distance profile function. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X : array, shape (n_instances, n_channels, n_timestamps) | ||
| Input array to used as database for the similarity search | ||
| y : optional | ||
| Not used. | ||
|
|
||
| Raises | ||
| ------ | ||
| TypeError | ||
| If the input X array is not 3D raise an error. | ||
|
|
||
| Returns | ||
| ------- | ||
| self | ||
|
|
||
| """ | ||
| # For now force (n_instances, n_channels, n_timestamps), we could convert 2D | ||
| # (n_channels, n_timestamps) to 3D with a warning | ||
| if not isinstance(X, np.ndarray) or X.ndim != 3: | ||
| raise TypeError( | ||
| "Error, only supports 3D numpy of shape" | ||
| "(n_instances, n_channels, n_timestamps)." | ||
| ) | ||
|
|
||
| # Get distance function | ||
| self.distance_profile_function = self._get_distance_profile_function() | ||
|
|
||
| self._X = X.astype(float) | ||
| self._fit(X, y) | ||
| return self | ||
|
|
||
| @final | ||
| def predict(self, q, q_index=None, exclusion_factor=2.0): | ||
| """ | ||
| Predict method: Check the shape of q and call _predict to perform the search. | ||
|
|
||
| If the distance profile function is normalized, it stores the mean and stds | ||
| from q and _X. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| q : array, shape (n_channels, q_length) | ||
| Input query used for similarity search. | ||
| q_index : Iterable, default=None | ||
| An Interable (tuple, list, array) used to specify the index of Q if it is | ||
| extracted from the input data X given during the fit method. | ||
| Given the tuple (id_sample, id_timestamp), the similarity search will define | ||
| an exclusion zone around the q_index in order to avoid matching q with | ||
| itself. If None, it is considered that the query is not extracted from X. | ||
| exclusion_factor : float, default=2. | ||
| The factor to apply to the query length to define the exclusion zone. The | ||
| exclusion zone is define from id_timestamp - q_length//exclusion_factor to | ||
| id_timestamp + q_length//exclusion_factor | ||
|
|
||
| Raises | ||
| ------ | ||
| TypeError | ||
| If the input q array is not 2D raise an error. | ||
| ValueError | ||
| If the length of the query is greater | ||
|
|
||
| Returns | ||
| ------- | ||
| array | ||
| An array containing the indexes of the matches between q and _X. | ||
| The decision of wheter a candidate of size q_length from _X is matched with | ||
| Q depends on the subclasses that implent the _predict method | ||
| (e.g. top-k, threshold, ...). | ||
|
|
||
| """ | ||
| if not isinstance(q, np.ndarray) or q.ndim != 2: | ||
| raise TypeError( | ||
| "Error, only supports 2D numpy atm. If q is univariate" | ||
| " do q.reshape(1,-1)." | ||
| ) | ||
|
|
||
| q_dim, q_length = q.shape | ||
| if q_length >= self._X.shape[-1]: | ||
| raise ValueError( | ||
| "The length of the query should be inferior or equal to the length of" | ||
| "data (X) provided during fit, but got {} for q and {} for X".format( | ||
| q_length, self._X.shape[-1] | ||
| ) | ||
| ) | ||
|
|
||
| if q_dim != self._X.shape[1]: | ||
| raise ValueError( | ||
| "The number of feature should be the same for the query q and the data" | ||
| "(X) provided during fit, but got {} for q and {} for X".format( | ||
| q_dim, self._X.shape[1] | ||
| ) | ||
| ) | ||
|
|
||
| n_instances, _, n_timestamps = self._X.shape | ||
| mask = np.ones((n_instances, q_dim, n_timestamps), dtype=bool) | ||
|
|
||
| if q_index is not None: | ||
| if isinstance(q_index, Iterable): | ||
| if len(q_index) != 2: | ||
| raise ValueError( | ||
| "The q_index should contain an interable of size 2 such as" | ||
| "(id_sample, id_timestamp), but got an iterable of" | ||
| "size {}".format(len(q_index)) | ||
| ) | ||
| else: | ||
| raise TypeError( | ||
| "If not None, the q_index parameter should be an iterable, here" | ||
| " q_index is of type {}".format(type(q_index)) | ||
| ) | ||
|
|
||
| if exclusion_factor <= 0: | ||
| raise ValueError( | ||
| "The value of exclusion_factor should be superior to 0, but got" | ||
| "{}".format(len(exclusion_factor)) | ||
| ) | ||
|
|
||
| i_instance, i_timestamp = q_index | ||
| profile_length = n_timestamps - (q_length - 1) | ||
| exclusion_LB = max(0, int(i_timestamp - q_length // exclusion_factor)) | ||
| exclusion_UB = min( | ||
| profile_length, int(i_timestamp + q_length // exclusion_factor) | ||
| ) | ||
| mask[i_instance, :, exclusion_LB:exclusion_UB] = False | ||
|
|
||
| if self.normalize: | ||
| self._q_means = np.mean(q, axis=-1) | ||
| self._q_stds = np.std(q, axis=-1) | ||
| self._store_mean_std_from_inputs(q_length) | ||
|
|
||
| return self._predict(q.astype(float), mask) | ||
|
|
||
| @abstractmethod | ||
| def _fit(self, X, y): | ||
| ... | ||
|
|
||
| @abstractmethod | ||
| def _predict(self, q): | ||
| ... | ||
|
|
||
|
|
||
| # Dictionary structure : | ||
| # 1st lvl key : distance function used | ||
| # 2nd lvl key : boolean indicating whether distance is normalized | ||
| DISTANCE_PROFILE_DICT = { | ||
| "euclidean": { | ||
| True: normalized_naive_euclidean_profile, | ||
| False: naive_euclidean_profile, | ||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.