Skip to content

Commit 10b564b

Browse files
Refactor Feature chaining and add test case for Feature chain with lambda (#206)
1 parent 35891bd commit 10b564b

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

deeptrack/features.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def __iter__(self):
508508
def __next__(self):
509509
yield self.update().resolve()
510510

511-
def __rshift__(self, other: "Feature") -> "Feature":
511+
def __rshift__(self, other) -> "Feature":
512512

513513
# Allows chaining of features. For example,
514514
# feature1 >> feature2 >> feature3
@@ -519,12 +519,11 @@ def __rshift__(self, other: "Feature") -> "Feature":
519519
return Chain(self, other)
520520

521521
# Import here to avoid circular import.
522-
from . import models
522+
523523

524524
# If other is a function, call it on the output of the feature.
525525
# For example, feature >> some_function
526-
if isinstance(other, models.KerasModel):
527-
return NotImplemented
526+
528527
if callable(other):
529528
return self >> Lambda(lambda: other)
530529

deeptrack/test/test_features.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,18 @@ def test_Feature_arithmetic(self):
458458
input_2 = [10, 20]
459459
self.assertListEqual(pipeline(input_2), [-input_2[0], -input_2[1]])
460460

461+
def test_Features_chain_lambda(self):
462+
463+
value = features.Value(value=1)
464+
func = lambda x: x + 1
465+
466+
feature = value >> func
467+
feature.store_properties()
468+
469+
feature.update()
470+
output_image = feature()
471+
self.assertEqual(output_image, 2)
472+
461473
def test_Feature_repeat(self):
462474
feature = features.Value(value=0) >> (features.Add(1) ^ iter(range(10)))
463475

0 commit comments

Comments
 (0)