diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 2bfbcd74..55ad5a2b 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -49,6 +49,7 @@ def parameters(self) -> List[ConfigParameterDef]: ConfigParameterDef( name="rescale", required=False, default_value=self.default_rescale ), + ConfigParameterDef(name="adjusted", required=False, default_value=False), ] def tensor_parameters(self) -> List[ConfigParameterDef]: @@ -73,6 +74,7 @@ def make_task( int8_mask=parameters["int8_mask"], normalize=parameters["normalize"], rescale=parameters["rescale"], + adjusted=parameters["adjusted"], out_tensor_name=output_weight.name, ) @@ -86,6 +88,7 @@ class GTATask(Task[torch.Tensor]): int8_mask: bool normalize: bool rescale: bool + adjusted: bool def uses_accelerator(self) -> bool: return True @@ -116,6 +119,7 @@ def execute( density=tv_info["density"], method=self.method.sparsification_method, rescale=self.rescale, + adjusted=self.adjusted, ) deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 69c923ac..e8b106be 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -59,7 +59,7 @@ def magnitude(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens return res -def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: +def bernoulli(tensor: torch.Tensor, density: float, rescale: bool, adjusted: bool) -> torch.Tensor: if density >= 1: return tensor @@ -69,6 +69,12 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens # torch.bernoulli not implemented for float16 on CPU, upcast to float32 work_dtype = torch.float32 + if adjusted: + s = (tensor.count_nonzero() / tensor.numel()).item() + if density >= s: + return tensor + density /= s + mask = torch.bernoulli( torch.full_like(input=tensor, fill_value=density, dtype=work_dtype) ) @@ -83,10 +89,11 @@ def sparsify( density: float, method: SparsificationMethod, rescale: bool = False, + adjusted: bool = False, ) -> torch.Tensor: if method == SparsificationMethod.magnitude: return magnitude(tensor, density=density, rescale=rescale) elif method == SparsificationMethod.random: - return bernoulli(tensor, density=density, rescale=rescale) + return bernoulli(tensor, density=density, rescale=rescale, adjusted=adjusted) else: raise NotImplementedError(method)