from collections import OrderedDict
from copy import copy
from math import sqrt
+ from typing import Callable, Iterable

import cloudpickle
import numpy as np
from scipy import interpolate
+ from scipy.interpolate.interpnd import LinearNDInterpolator

from adaptive.learner.base_learner import BaseLearner
from adaptive.learner.triangulation import simplex_volume_in_embedding
from adaptive.notebook_integration import ensure_holoviews
+ from adaptive.types import Bool, Float, Real
from adaptive.utils import (
assign_defaults,
cache_latest,

# Learner2D and helper functions.


- def deviations(ip):
+ def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:
"""Returns the deviation of the linear estimate.

Is useful when defining custom loss functions.
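(Not part of the diff: a minimal sketch of a custom loss built on `deviations` and `areas`, to show how these newly annotated helpers are meant to be used. The function name `deviation_loss` and the weighting are made up, not the library's formula.)

    import adaptive
    import numpy as np
    from adaptive.learner.learner2D import areas, deviations

    def deviation_loss(ip):
        # One value per triangle: summed deviation from the linear estimate,
        # weighted by the triangle size.
        dev = np.sum(deviations(ip), axis=0)
        return dev * np.sqrt(areas(ip))

    learner = adaptive.Learner2D(
        lambda xy: xy[0] * xy[1],
        bounds=[(-1, 1), (-1, 1)],
        loss_per_triangle=deviation_loss,
    )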
@@ -68,7 +71,7 @@ def deviation(p, v, g):
return devs


- def areas(ip):
+ def areas(ip: LinearNDInterpolator) -> np.ndarray:
"""Returns the area per triangle of the triangulation inside
a `LinearNDInterpolator` instance.
@@ -89,7 +92,7 @@ def areas(ip):
return areas


- def uniform_loss(ip):
+ def uniform_loss(ip: LinearNDInterpolator) -> np.ndarray:
"""Loss function that samples the domain uniformly.

Works with `~adaptive.Learner2D` only.
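(Not part of the diff: a short usage sketch for `uniform_loss`; the test function `f` is a stand-in.)

    import adaptive
    from adaptive.learner.learner2D import uniform_loss

    def f(xy):
        x, y = xy
        return x + y**2

    # Sample the domain uniformly instead of refining around features.
    learner = adaptive.Learner2D(
        f, bounds=[(-1, 1), (-1, 1)], loss_per_triangle=uniform_loss
    )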
@@ -120,7 +123,9 @@ def uniform_loss(ip):
return np.sqrt(areas(ip))


- def resolution_loss_function(min_distance=0, max_distance=1):
+ def resolution_loss_function(
+ min_distance: float = 0, max_distance: float = 1
+ ) -> Callable[[LinearNDInterpolator], np.ndarray]:
"""Loss function that is similar to the `default_loss` function, but you
can set the maximum and minimum size of a triangle.
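(Not part of the diff: a sketch of the factory pattern this signature describes, i.e. calling `resolution_loss_function` to get a loss and passing it to the learner. The numeric values are illustrative; how the distances are interpreted is defined by the library's docstring.)

    import adaptive
    from adaptive.learner.learner2D import resolution_loss_function

    # Build a loss that stops refining very small triangles and always
    # splits very large ones.
    loss = resolution_loss_function(min_distance=0.01, max_distance=1)
    learner = adaptive.Learner2D(
        lambda xy: xy[0] * xy[1], bounds=[(-1, 1), (-1, 1)], loss_per_triangle=loss
    )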
@@ -159,7 +164,7 @@ def resolution_loss(ip):
return resolution_loss


- def minimize_triangle_surface_loss(ip):
+ def minimize_triangle_surface_loss(ip: LinearNDInterpolator) -> np.ndarray:
"""Loss function that is similar to the distance loss function in the
`~adaptive.Learner1D`. The loss is the area spanned by the 3D
vectors of the vertices.
@@ -205,7 +210,7 @@ def _get_vectors(points):
return np.linalg.norm(np.cross(a, b) / 2, axis=1)


- def default_loss(ip):
+ def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
"""Loss function that combines `deviations` and `areas` of the triangles.

Works with `~adaptive.Learner2D` only.
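(Not part of the diff: a sketch of inspecting per-triangle losses directly by feeding an interpolator to `default_loss`, assuming the `learner.loss()` behaviour shown further down in this diff.)

    import adaptive
    from adaptive.learner.learner2D import default_loss

    learner = adaptive.Learner2D(lambda xy: xy[0] * xy[1], bounds=[(-1, 1), (-1, 1)])
    adaptive.runner.simple(learner, goal=lambda l: l.npoints >= 50)  # evaluate some points

    ip = learner.interpolator(scaled=True)
    losses = default_loss(ip)   # one loss value per triangle of the interpolation
    print(losses.max())         # the maximum is what learner.loss() reports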
@@ -225,7 +230,7 @@ def default_loss(ip):
return losses


- def choose_point_in_triangle(triangle, max_badness):
+ def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarray:
"""Choose a new point inside a triangle.

If the ratio of the longest edge of the triangle squared
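(Not part of the diff: a sketch of calling this helper on a single triangle, matching the annotated signature above; the `max_badness` value is arbitrary.)

    import numpy as np
    from adaptive.learner.learner2D import choose_point_in_triangle

    triangle = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # three (x, y) vertices
    new_point = choose_point_in_triangle(triangle, max_badness=5)
    print(new_point)  # a suggested refinement point for this triangle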
@@ -364,7 +369,12 @@ class Learner2D(BaseLearner):
over each triangle.
"""

- def __init__(self, function, bounds, loss_per_triangle=None):
+ def __init__(
+ self,
+ function: Callable,
+ bounds: tuple[tuple[Real, Real], tuple[Real, Real]],
+ loss_per_triangle: Callable | None = None,
+ ) -> None:
self.ndim = len(bounds)
self._vdim = None
self.loss_per_triangle = loss_per_triangle or default_loss
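(Not part of the diff: constructing a learner against the newly annotated `__init__`; a minimal sketch, with a made-up test function.)

    import adaptive
    import numpy as np

    def ring(xy):
        x, y = xy
        return np.exp(-((x**2 + y**2 - 0.75**2) ** 2) / 0.01)

    # bounds: one (min, max) pair per dimension, matching the
    # tuple[tuple[Real, Real], tuple[Real, Real]] annotation above.
    learner = adaptive.Learner2D(ring, bounds=[(-1, 1), (-1, 1)])
    # Leaving loss_per_triangle=None falls back to default_loss (see the line above).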
@@ -379,7 +389,7 @@ def __init__(self, function, bounds, loss_per_triangle=None):
self._bounds_points = list(itertools.product(*bounds))
self._stack.update({p: np.inf for p in self._bounds_points})
- self.function = function
+ self.function = function  # type: ignore
self._ip = self._ip_combined = None

self.stack_size = 10
@@ -388,7 +398,7 @@ def new(self) -> Learner2D:
return Learner2D(self.function, self.bounds, self.loss_per_triangle)

@property
- def xy_scale(self):
+ def xy_scale(self) -> np.ndarray:
xy_scale = self._xy_scale
if self.aspect_ratio == 1:
return xy_scale
@@ -486,21 +496,21 @@ def load_dataframe(
self.function, df, function_prefix
)

- def _scale(self, points):
+ def _scale(self, points: list[tuple[float, float]] | np.ndarray) -> np.ndarray:
points = np.asarray(points, dtype=float)
return (points - self.xy_mean) / self.xy_scale

- def _unscale(self, points):
+ def _unscale(self, points: np.ndarray) -> np.ndarray:
points = np.asarray(points, dtype=float)
return points * self.xy_scale + self.xy_mean

@property
- def npoints(self):
+ def npoints(self) -> int:
"""Number of evaluated points."""
return len(self.data)

@property
- def vdim(self):
+ def vdim(self) -> int:
"""Length of the output of ``learner.function``.
If the output is unsized (when it's a scalar)
then `vdim = 1`.
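(Not part of the diff: a sketch of how `vdim` tracks the output length of the learner's function; the vector-valued function here is made up.)

    import adaptive
    import numpy as np

    def vector_field(xy):
        x, y = xy
        return np.array([np.sin(x * y), np.cos(x + y)])  # two outputs per point

    learner = adaptive.Learner2D(vector_field, bounds=[(-1, 1), (-1, 1)])
    learner.tell((0.0, 0.0), vector_field((0.0, 0.0)))
    print(learner.npoints)  # 1
    print(learner.vdim)     # 2 -- a scalar-valued function would give 1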
@@ -516,12 +526,14 @@ def vdim(self):
return self._vdim or 1

@property
- def bounds_are_done(self):
+ def bounds_are_done(self) -> bool:
return not any(
(p in self.pending_points or p in self._stack) for p in self._bounds_points
)

- def interpolated_on_grid(self, n=None):
+ def interpolated_on_grid(
+ self, n: int = None
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Get the interpolated data on a grid.

Parameters
@@ -553,7 +565,7 @@ def interpolated_on_grid(self, n=None):
xs, ys = self._unscale(np.vstack([xs, ys]).T).T
return xs, ys, zs

- def _data_in_bounds(self):
+ def _data_in_bounds(self) -> tuple[np.ndarray, np.ndarray]:
if self.data:
points = np.array(list(self.data.keys()))
values = np.array(list(self.data.values()), dtype=float)
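(Not part of the diff: using `interpolated_on_grid`, whose return type is annotated above as a tuple of three arrays; a minimal sketch with an arbitrary grid size.)

    import adaptive

    learner = adaptive.Learner2D(lambda xy: xy[0] * xy[1], bounds=[(-1, 1), (-1, 1)])
    adaptive.runner.simple(learner, goal=lambda l: l.npoints >= 100)

    # xs and ys are the grid axes, zs the interpolated values on that grid.
    xs, ys, zs = learner.interpolated_on_grid(n=50)
    print(xs.shape, ys.shape, zs.shape)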
@@ -562,7 +574,7 @@ def _data_in_bounds(self):
return points[inds], values[inds].reshape(-1, self.vdim)
return np.zeros((0, 2)), np.zeros((0, self.vdim), dtype=float)

- def _data_interp(self):
+ def _data_interp(self) -> tuple[np.ndarray | list[tuple[float, float]], np.ndarray]:
if self.pending_points:
points = list(self.pending_points)
if self.bounds_are_done:
@@ -575,7 +587,7 @@ def _data_interp(self):
return points, values
return np.zeros((0, 2)), np.zeros((0, self.vdim), dtype=float)

- def _data_combined(self):
+ def _data_combined(self) -> tuple[np.ndarray, np.ndarray]:
points, values = self._data_in_bounds()
if not self.pending_points:
return points, values
@@ -584,7 +596,7 @@ def _data_combined(self):
values_combined = np.vstack([values, values_interp])
return points_combined, values_combined

- def ip(self):
+ def ip(self) -> LinearNDInterpolator:
"""Deprecated, use `self.interpolator(scaled=True)`"""
warnings.warn(
"`learner.ip()` is deprecated, use `learner.interpolator(scaled=True)`."
@@ -593,7 +605,7 @@ def ip(self):
)
return self.interpolator(scaled=True)

- def interpolator(self, *, scaled=False):
+ def interpolator(self, *, scaled: bool = False) -> LinearNDInterpolator:
"""A `scipy.interpolate.LinearNDInterpolator` instance
containing the learner's data.
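(Not part of the diff: evaluating the returned `LinearNDInterpolator` on a mesh; a minimal sketch, grid size chosen arbitrarily.)

    import adaptive
    import numpy as np

    learner = adaptive.Learner2D(lambda xy: xy[0] * xy[1], bounds=[(-1, 1), (-1, 1)])
    adaptive.runner.simple(learner, goal=lambda l: l.npoints >= 100)

    ip = learner.interpolator(scaled=False)  # interpolator in the original x, y coordinates
    xs = ys = np.linspace(-1, 1, 51)
    zs = ip(xs[:, None], ys[None, :])        # evaluate on a mesh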
@@ -624,7 +636,7 @@ def interpolator(self, *, scaled=False):
points, values = self._data_in_bounds()
return interpolate.LinearNDInterpolator(points, values)

- def _interpolator_combined(self):
+ def _interpolator_combined(self) -> LinearNDInterpolator:
"""A `scipy.interpolate.LinearNDInterpolator` instance
containing the learner's data *and* interpolated data of
the `pending_points`."""
@@ -634,12 +646,12 @@ def _interpolator_combined(self):
self._ip_combined = interpolate.LinearNDInterpolator(points, values)
return self._ip_combined

- def inside_bounds(self, xy):
+ def inside_bounds(self, xy: tuple[float, float]) -> Bool:
x, y = xy
(xmin, xmax), (ymin, ymax) = self.bounds
return xmin <= x <= xmax and ymin <= y <= ymax

- def tell(self, point, value):
+ def tell(self, point: tuple[float, float], value: float | Iterable[float]) -> None:
point = tuple(point)
self.data[point] = value
if not self.inside_bounds(point):
@@ -648,15 +660,17 @@ def tell(self, point, value):
self._ip = None
self._stack.pop(point, None)

- def tell_pending(self, point):
+ def tell_pending(self, point: tuple[float, float]) -> None:
point = tuple(point)
if not self.inside_bounds(point):
return
self.pending_points.add(point)
self._ip_combined = None
self._stack.pop(point, None)

- def _fill_stack(self, stack_till=1):
+ def _fill_stack(
+ self, stack_till: int = 1
+ ) -> tuple[list[tuple[float, float]], list[float]]:
if len(self.data) + len(self.pending_points) < self.ndim + 1:
raise ValueError("too few points...")
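(Not part of the diff: feeding the learner by hand with the `tell_pending` and `tell` methods typed above; a minimal sketch with a made-up function.)

    import adaptive

    def f(xy):
        x, y = xy
        return x**2 + y**2

    learner = adaptive.Learner2D(f, bounds=[(-1, 1), (-1, 1)])

    point = (0.5, -0.25)
    learner.tell_pending(point)    # mark the point as "being evaluated"
    learner.tell(point, f(point))  # store the evaluated value
    print(learner.npoints)         # 1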
@@ -695,7 +709,9 @@ def _fill_stack(self, stack_till=1):
return points_new, losses_new

- def ask(self, n, tell_pending=True):
+ def ask(
+ self, n: int, tell_pending: bool = True
+ ) -> tuple[list[tuple[float, float] | np.ndarray], list[float]]:
# Even if tell_pending is False we add the point such that _fill_stack
# will return new points, later we remove these points if needed.
points = list(self._stack.keys())
@@ -726,14 +742,14 @@ def ask(self, n, tell_pending=True):
return points[:n], loss_improvements[:n]

@cache_latest
- def loss(self, real=True):
+ def loss(self, real: bool = True) -> float:
if not self.bounds_are_done:
return np.inf
ip = self.interpolator(scaled=True) if real else self._interpolator_combined()
losses = self.loss_per_triangle(ip)
return losses.max()

- def remove_unfinished(self):
+ def remove_unfinished(self) -> None:
self.pending_points = set()
for p in self._bounds_points:
if p not in self.data:
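(Not part of the diff: driving the learner manually with `ask`, `tell`, and `loss` as annotated above; a minimal sketch, with an arbitrary loss target and batch size.)

    import adaptive

    def f(xy):
        x, y = xy
        return x * y

    learner = adaptive.Learner2D(f, bounds=[(-1, 1), (-1, 1)])

    while learner.loss() > 0.01:   # loss() stays np.inf until the corner points are done
        points, loss_improvements = learner.ask(4)
        for p in points:
            learner.tell(p, f(p))
    print(learner.npoints, learner.loss())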
@@ -807,10 +823,10 @@ def plot(self, n=None, tri_alpha=0):
return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover)

- def _get_data(self):
+ def _get_data(self) -> dict[tuple[float, float], Float | np.ndarray]:
return self.data

- def _set_data(self, data):
+ def _set_data(self, data: dict[tuple[float, float], Float | np.ndarray]) -> None:
self.data = data
# Remove points from stack if they already exist
for point in copy(self._stack):