diff --git a/tiatoolbox/tools/stainaugment.py b/tiatoolbox/tools/stainaugment.py index 154304841..6b310f952 100644 --- a/tiatoolbox/tools/stainaugment.py +++ b/tiatoolbox/tools/stainaugment.py @@ -158,7 +158,10 @@ def fit(self, img, threshold=0.85): img, self.stain_matrix ) self.n_stains = self.source_concentrations.shape[1] - self.tissue_mask = get_luminosity_tissue_mask(img, threshold=threshold).ravel() + if not self.augment_background: + self.tissue_mask = get_luminosity_tissue_mask( + img, threshold=threshold + ).ravel() self.img_shape = img.shape def augment(self):