|
| 1 | +"""Dummy similarity seach estimator.""" |
| 2 | + |
| 3 | +__author__ = ["baraline"] |
| 4 | +__all__ = ["DummySimilaritySearch"] |
| 5 | + |
| 6 | + |
| 7 | +from aeon.similarity_search.base import BaseSimiliaritySearch |
| 8 | + |
| 9 | + |
| 10 | +class DummySimilaritySearch(BaseSimiliaritySearch): |
| 11 | + """ |
| 12 | + DummySimilaritySearch for testing of the BaseSimiliaritySearch class. |
| 13 | +
|
| 14 | + Parameters |
| 15 | + ---------- |
| 16 | + distance : str, default ="euclidean" |
| 17 | + Name of the distance function to use. |
| 18 | + normalize : bool, default = False |
| 19 | + Whether the distance function should be z-normalized. |
| 20 | + store_distance_profile : bool, default = =False. |
| 21 | + Whether to store the computed distance profile in the attribute |
| 22 | + "_distance_profile" after calling the predict method. |
| 23 | + """ |
| 24 | + |
| 25 | + def __init__( |
| 26 | + self, distance="euclidean", normalize=False, store_distance_profile=False |
| 27 | + ): |
| 28 | + super(DummySimilaritySearch, self).__init__( |
| 29 | + distance=distance, |
| 30 | + normalize=normalize, |
| 31 | + store_distance_profile=store_distance_profile, |
| 32 | + ) |
| 33 | + |
| 34 | + def _fit(self, X, y): |
| 35 | + """ |
| 36 | + Private fit method, does nothing more than the base class. |
| 37 | +
|
| 38 | + Parameters |
| 39 | + ---------- |
| 40 | + X : array, shape (n_instances, n_channels, n_timestamps) |
| 41 | + Input array to used as database for the similarity search |
| 42 | + y : optional |
| 43 | + Not used. |
| 44 | +
|
| 45 | + Returns |
| 46 | + ------- |
| 47 | + self |
| 48 | +
|
| 49 | + """ |
| 50 | + return self |
| 51 | + |
| 52 | + def _predict(self, q, mask): |
| 53 | + """ |
| 54 | + Private predict method for DummySimilaritySearch. |
| 55 | +
|
| 56 | + It compute the distance profiles and then returns the best match |
| 57 | +
|
| 58 | + Parameters |
| 59 | + ---------- |
| 60 | + q : array, shape (n_channels, q_length) |
| 61 | + Input query used for similarity search. |
| 62 | + mask : array, shape (n_instances, n_channels, n_timestamps - (q_length - 1)) |
| 63 | + Boolean mask of the shape of the distance profile indicating for which part |
| 64 | + of it the distance should be computed. |
| 65 | +
|
| 66 | + Returns |
| 67 | + ------- |
| 68 | + array |
| 69 | + An array containing the index of the best match between q and _X. |
| 70 | +
|
| 71 | + """ |
| 72 | + if self.normalize: |
| 73 | + distance_profile = self.distance_profile_function( |
| 74 | + self._X, |
| 75 | + q, |
| 76 | + mask, |
| 77 | + self._X_means, |
| 78 | + self._X_stds, |
| 79 | + self._q_means, |
| 80 | + self._q_stds, |
| 81 | + ) |
| 82 | + else: |
| 83 | + distance_profile = self.distance_profile_function(self._X, q, mask) |
| 84 | + |
| 85 | + if self.store_distance_profile: |
| 86 | + self._distance_profile = distance_profile |
| 87 | + |
| 88 | + # For now, deal with the multidimensional case as "dependent", so we sum. |
| 89 | + search_size = distance_profile.shape[-1] |
| 90 | + distance_profile = distance_profile.sum(axis=1) |
| 91 | + _id_best = distance_profile.argmin(axis=None) |
| 92 | + |
| 93 | + return [(_id_best // search_size, _id_best % search_size)] |
0 commit comments