Skip to content

Commit b6d322f

Browse files
committed
Adding dummy class and test, correcting some docstrings
1 parent b0a38ba commit b6d322f

File tree

5 files changed

+163
-7
lines changed

5 files changed

+163
-7
lines changed

aeon/similarity_search/_dummy.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Dummy similarity seach estimator."""
2+
3+
__author__ = ["baraline"]
4+
__all__ = ["DummySimilaritySearch"]
5+
6+
7+
from aeon.similarity_search.base import BaseSimiliaritySearch
8+
9+
10+
class DummySimilaritySearch(BaseSimiliaritySearch):
11+
"""
12+
DummySimilaritySearch for testing of the BaseSimiliaritySearch class.
13+
14+
Parameters
15+
----------
16+
distance : str, default ="euclidean"
17+
Name of the distance function to use.
18+
normalize : bool, default = False
19+
Whether the distance function should be z-normalized.
20+
store_distance_profile : bool, default = =False.
21+
Whether to store the computed distance profile in the attribute
22+
"_distance_profile" after calling the predict method.
23+
"""
24+
25+
def __init__(
26+
self, distance="euclidean", normalize=False, store_distance_profile=False
27+
):
28+
super(DummySimilaritySearch, self).__init__(
29+
distance=distance,
30+
normalize=normalize,
31+
store_distance_profile=store_distance_profile,
32+
)
33+
34+
def _fit(self, X, y):
35+
"""
36+
Private fit method, does nothing more than the base class.
37+
38+
Parameters
39+
----------
40+
X : array, shape (n_instances, n_channels, n_timestamps)
41+
Input array to used as database for the similarity search
42+
y : optional
43+
Not used.
44+
45+
Returns
46+
-------
47+
self
48+
49+
"""
50+
return self
51+
52+
def _predict(self, q, mask):
53+
"""
54+
Private predict method for DummySimilaritySearch.
55+
56+
It compute the distance profiles and then returns the best match
57+
58+
Parameters
59+
----------
60+
q : array, shape (n_channels, q_length)
61+
Input query used for similarity search.
62+
mask : array, shape (n_instances, n_channels, n_timestamps - (q_length - 1))
63+
Boolean mask of the shape of the distance profile indicating for which part
64+
of it the distance should be computed.
65+
66+
Returns
67+
-------
68+
array
69+
An array containing the index of the best match between q and _X.
70+
71+
"""
72+
if self.normalize:
73+
distance_profile = self.distance_profile_function(
74+
self._X,
75+
q,
76+
mask,
77+
self._X_means,
78+
self._X_stds,
79+
self._q_means,
80+
self._q_stds,
81+
)
82+
else:
83+
distance_profile = self.distance_profile_function(self._X, q, mask)
84+
85+
if self.store_distance_profile:
86+
self._distance_profile = distance_profile
87+
88+
# For now, deal with the multidimensional case as "dependent", so we sum.
89+
search_size = distance_profile.shape[-1]
90+
distance_profile = distance_profile.sum(axis=1)
91+
_id_best = distance_profile.argmin(axis=None)
92+
93+
return [(_id_best // search_size, _id_best % search_size)]

aeon/similarity_search/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
class BaseSimiliaritySearch(BaseEstimator, ABC):
2020
"""BaseSimilaritySearch.
2121
22-
Attributes
22+
Parameters
2323
----------
2424
distance : str, default ="euclidean"
2525
Name of the distance function to use.
@@ -28,6 +28,15 @@ class BaseSimiliaritySearch(BaseEstimator, ABC):
2828
store_distance_profile : bool, default = False.
2929
Whether to store the computed distance profile in the attribute
3030
"_distance_profile" after calling the predict method.
31+
32+
Attributes
33+
----------
34+
_X : array, shape (n_instances, n_channels, n_timestamps)
35+
The input time series stored during the fit method.
36+
distance_profile_function : function
37+
The function used to compute the distance profile affected
38+
during the fit method based on the distance and normalize
39+
parameters.
3140
"""
3241

3342
_tags = {
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Tests for DummySimilaritySearch."""
2+
3+
__author__ = ["baraline"]
4+
5+
6+
import numpy as np
7+
import pytest
8+
from numpy.testing import assert_array_equal
9+
10+
from aeon.similarity_search._dummy import DummySimilaritySearch
11+
12+
DATATYPES = ["int64", "float64"]
13+
14+
15+
@pytest.mark.parametrize("dtype", DATATYPES)
16+
def test_DummySimilaritySearch(dtype):
17+
X = np.asarray(
18+
[[[1, 2, 3, 4, 5, 6, 7, 8]], [[1, 2, 4, 4, 5, 6, 5, 4]]], dtype=dtype
19+
)
20+
q = np.asarray([[3, 4, 5]], dtype=dtype)
21+
22+
search = DummySimilaritySearch()
23+
search.fit(X)
24+
idx = search.predict(q)
25+
assert_array_equal(idx, [(0, 2)])
26+
27+
search = DummySimilaritySearch(normalize=True)
28+
search.fit(X)
29+
q = np.asarray([[8, 8, 10]], dtype=dtype)
30+
idx = search.predict(q)
31+
assert_array_equal(idx, [(1, 2)])
32+
33+
search = DummySimilaritySearch(normalize=True)
34+
search.fit(X)
35+
idx = search.predict(q, q_index=(1, 2))
36+
assert_array_equal(idx, [(1, 0)])

aeon/similarity_search/tests/test_top_k_similarity.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
"""
2-
Created on Sat Sep 9 14:12:58 2023
3-
4-
@author: antoi
5-
"""
1+
"""Tests for TopKSimilaritySearch."""
62

3+
__author__ = ["baraline"]
74

85
import numpy as np
96
import pytest

aeon/similarity_search/top_k_similarity.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class TopKSimilaritySearch(BaseSimiliaritySearch):
1111
1212
Finds the closest k series to the query series based on a distance function.
1313
14-
Attributes
14+
Parameters
1515
----------
1616
k : int, default=1
1717
The number of nearest matches from Q to return.
@@ -22,6 +22,27 @@ class TopKSimilaritySearch(BaseSimiliaritySearch):
2222
store_distance_profile : bool, default = =False.
2323
Whether to store the computed distance profile in the attribute
2424
"_distance_profile" after calling the predict method.
25+
26+
Attributes
27+
----------
28+
_X : array, shape (n_instances, n_channels, n_timestamps)
29+
The input time series stored during the fit method.
30+
distance_profile_function : function
31+
The function used to compute the distance profile affected
32+
during the fit method based on the distance and normalize
33+
parameters.
34+
35+
Examples
36+
--------
37+
>>> from aeon.similarity_search import TopKSimilaritySearch
38+
>>> from aeon.datasets import load_unit_test
39+
>>> X_train, y_train = load_unit_test(split="train")
40+
>>> X_test, y_test = load_unit_test(split="test")
41+
>>> clf = TopKSimilaritySearch(k=1)
42+
>>> clf.fit(X_train, y_train)
43+
TopKSimilaritySearch(...)
44+
>>> q = X_test[0, :, 5:15]
45+
>>> y_pred = clf.predict(q)
2546
"""
2647

2748
def __init__(

0 commit comments

Comments
 (0)