diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 183264373..4b0f5639d 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -59,8 +59,8 @@ def save_to_file(file_name, state_dict, metadata): def index_sv_cumulative(S, target): original_sum = float(torch.sum(S)) cumulative_sums = torch.cumsum(S, dim=0) / original_sum - index = int(torch.searchsorted(cumulative_sums, target)) + 1 - index = max(1, min(index, len(S) - 1)) + index = int(torch.searchsorted(cumulative_sums, target)) + index = max(0, min(index, len(S) - 1)) return index @@ -69,8 +69,8 @@ def index_sv_fro(S, target): S_squared = S.pow(2) S_fro_sq = float(torch.sum(S_squared)) sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq - index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - index = max(1, min(index, len(S) - 1)) + index = int(torch.searchsorted(sum_S_squared, target**2)) + index = max(0, min(index, len(S) - 1)) return index @@ -79,7 +79,7 @@ def index_sv_ratio(S, target): max_sv = S[0] min_sv = max_sv / target index = int(torch.sum(S > min_sv).item()) - index = max(1, min(index, len(S) - 1)) + index = max(0, min(index, len(S) - 1)) return index