From 0860de48cb9897c84f7399c7155437f3ce07dcd5 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Tue, 14 Oct 2025 12:51:41 +0000 Subject: [PATCH 01/21] First working version Signed-off-by: simondanielsson --- .../models/language/generation/test_hybrid.py | 10 +- vllm/config/model.py | 2 +- vllm/model_executor/layers/fla/ops/chunk.py | 38 ++- vllm/model_executor/models/config.py | 1 + vllm/model_executor/models/qwen3_next.py | 308 ++++++++++++++++- vllm/v1/attention/backends/gdn_attn.py | 311 ++++++++++++++++++ 6 files changed, 645 insertions(+), 25 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index abedd15b0d7e..235138ec6b36 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -53,7 +53,8 @@ MAX_NUM_SEQS = 4 -@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) +# @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) +@pytest.mark.parametrize("model", ["tiny-random/qwen3-next-moe"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models( @@ -77,7 +78,9 @@ def test_models( example_prompts, max_tokens, num_logprobs ) - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + with vllm_runner( + model, max_num_seqs=MAX_NUM_SEQS, kv_cache_memory_bytes=1_000_000_000 + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs ) @@ -528,7 +531,8 @@ def test_apc_single_prompt_block_align_alignment( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +# @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", ["tiny-random/qwen3-next-moe"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version diff --git a/vllm/config/model.py b/vllm/config/model.py index d0c027e47675..905a28666f01 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1488,7 +1488,7 @@ def get_mamba_chunk_size(self) -> Optional[int]: if chunk_size is None: # used by e.g. Mamba2, NemotronH, Zamba chunk_size = getattr(self.hf_text_config, "chunk_size", None) - return chunk_size + return chunk_size or 64 def get_multimodal_config(self) -> MultiModalConfig: """ diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index d65c87aba11c..232fa17223c2 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -23,6 +23,23 @@ from .wy_fast import recompute_w_u_fwd +def _reshape_intermediate_states( + states: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> torch.Tensor: + """Return a chunk-major view of the kernel's intermediate states.""" + + if cu_seqlens is None: + # Equal-length batches keep their batch dimension; flatten it together + # with the chunk axis so callers receive a contiguous chunk stream. + return states.reshape(-1, *states.shape[-3:]) + + # Variable-length inputs collapse the batch dimension during preprocessing, + # so the kernel already emits a linearised chunk stream in ``states[:, i]``. + # Flattening mirrors the metadata builder's chunk enumeration order. 
+ return states.reshape(-1, *states.shape[-3:]) + + def chunk_gated_delta_rule_fwd( q: torch.Tensor, k: torch.Tensor, @@ -66,6 +83,7 @@ def chunk_gated_delta_rule_fwd( scale=scale, cu_seqlens=cu_seqlens, ) + # TODO: perhaps bypass this if return_intermediate_states, if so always return h if SUPPRESS_LEVEL < 3: return g, o, A, final_state, None, None, None elif SUPPRESS_LEVEL >= 3: @@ -88,6 +106,7 @@ def forward( output_final_state: bool, cu_seqlens: Optional[torch.LongTensor] = None, use_qk_l2norm_in_kernel: bool = False, + return_intermediate_states: bool = False, ): if use_qk_l2norm_in_kernel: q = l2norm_fwd(q) @@ -106,7 +125,10 @@ def forward( ) ctx.scale = scale ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel - return o.to(q.dtype), final_state + intermediate_states = None + if return_intermediate_states: + intermediate_states = _reshape_intermediate_states(h, cu_seqlens) + return o.to(q.dtype), final_state, intermediate_states @torch.compiler.disable @@ -122,6 +144,7 @@ def chunk_gated_delta_rule( cu_seqlens: Optional[torch.LongTensor] = None, head_first: bool = False, use_qk_l2norm_in_kernel: bool = False, + return_intermediate_states: bool = False, ): r""" Args: @@ -156,6 +179,10 @@ def chunk_gated_delta_rule( Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + intermediate_states (Optional[torch.Tensor]): + When ``return_intermediate_states`` is ``True`` a tensor containing + the per-chunk state snapshots shaped ``[num_chunks_total, H, K, V]``. + Otherwise ``None``. Examples:: >>> import torch @@ -170,7 +197,7 @@ def chunk_gated_delta_rule( >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') - >>> o, ht = chunk_gated_delta_rule( + >>> o, ht, _ = chunk_gated_delta_rule( q, k, v, g, beta, initial_state=h0, output_final_state=True @@ -179,7 +206,7 @@ def chunk_gated_delta_rule( >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = chunk_gated_delta_rule( + >>> o_var, ht_var, _ = chunk_gated_delta_rule( q, k, v, g, beta, initial_state=h0, output_final_state=True, @@ -224,7 +251,7 @@ def chunk_gated_delta_rule( ) if scale is None: scale = k.shape[-1] ** -0.5 - o, final_state = ChunkGatedDeltaRuleFunction.apply( + o, final_state, intermediate_states = ChunkGatedDeltaRuleFunction.apply( q, k, v, @@ -235,7 +262,8 @@ def chunk_gated_delta_rule( output_final_state, cu_seqlens, use_qk_l2norm_in_kernel, + return_intermediate_states, ) if head_first: o = rearrange(o, "b t h ... 
-> b h t ...") - return o, final_state + return o, final_state, intermediate_states diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index ee6a3ba773bb..2fbd36afc726 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -294,6 +294,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "Mamba2ForCausalLM", "NemotronHForCausalLM", "Zamba2ForCausalLM", + "Qwen3NextForCausalLM", ] if cache_config.enable_prefix_caching: if model_config.architecture in MAMBA2_MODELS: diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 50629bb2e4a2..487c7561cc93 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -459,6 +459,19 @@ def _forward( spec_token_masks = attn_metadata.spec_token_masks spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + state_indices_tensor_d = attn_metadata.state_indices_tensor_d + state_indices_tensor_p = attn_metadata.state_indices_tensor_p + block_idx_last_computed_token_d = attn_metadata.block_idx_last_computed_token_d + block_idx_last_scheduled_token_d = ( + attn_metadata.block_idx_last_scheduled_token_d + ) + block_idx_first_scheduled_token_p = ( + attn_metadata.block_idx_first_scheduled_token_p + ) + block_idx_last_computed_token_p = attn_metadata.block_idx_last_computed_token_p + block_idx_last_scheduled_token_p = ( + attn_metadata.block_idx_last_scheduled_token_p + ) self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] @@ -467,6 +480,165 @@ def _forward( if spec_token_masks is not None: spec_token_masks = spec_token_masks[:num_actual_tokens] + prefix_caching_enabled = bool( + ( + state_indices_tensor_d is not None + and block_idx_last_scheduled_token_d is not None + ) + or ( + state_indices_tensor_p is not None + and block_idx_last_scheduled_token_p is not None + ) + ) + non_spec_state_indices_runtime = non_spec_state_indices_tensor + state_indices_decode: Optional[torch.Tensor] = None + state_indices_prefill: Optional[torch.Tensor] = None + + start_non_spec_prefill = attn_metadata.num_decodes + end_non_spec_prefill = start_non_spec_prefill + attn_metadata.num_prefills + + if ( + prefix_caching_enabled + and non_spec_state_indices_tensor is not None + and non_spec_state_indices_tensor.numel() > 0 + ): + non_spec_state_indices_runtime = non_spec_state_indices_tensor.clone() + + num_decodes = attn_metadata.num_decodes + if ( + num_decodes > 0 + and state_indices_tensor_d is not None + and block_idx_last_computed_token_d is not None + and block_idx_last_scheduled_token_d is not None + ): + decode_slice = slice(0, num_decodes) + base_decode_slots = non_spec_state_indices_tensor[decode_slice] + gathered_last_computed = ( + block_idx_last_computed_token_d[:num_decodes] + .clamp(min=0) + .to(torch.long) + ) + gathered_last_scheduled = ( + block_idx_last_scheduled_token_d[:num_decodes] + .clamp(min=0) + .to(torch.long) + ) + slot_in = state_indices_tensor_d.gather( + 1, gathered_last_computed.unsqueeze(1) + ).squeeze(1) + slot_out = state_indices_tensor_d.gather( + 1, gathered_last_scheduled.unsqueeze(1) + ).squeeze(1) + valid_in = (block_idx_last_computed_token_d[:num_decodes] >= 0) & ( + slot_in >= 0 + ) + valid_out = (block_idx_last_scheduled_token_d[:num_decodes] >= 0) & ( + 
slot_out >= 0 + ) + diff_mask = (slot_in != slot_out) & valid_in & valid_out + diff_positions = torch.nonzero(diff_mask, as_tuple=False).squeeze(-1) + if diff_positions.numel() > 0: + slot_out_masked = slot_out.index_select(0, diff_positions).to( + device=conv_state.device, dtype=torch.long + ) + slot_in_masked = slot_in.index_select(0, diff_positions).to( + device=conv_state.device, dtype=torch.long + ) + conv_state.index_copy_( + 0, + slot_out_masked, + conv_state.index_select(0, slot_in_masked), + ) + ssm_state.index_copy_( + 0, + slot_out_masked, + ssm_state.index_select(0, slot_in_masked), + ) + updated_decode_slots = torch.where( + valid_out, + slot_out, + base_decode_slots, + ) + non_spec_state_indices_runtime[decode_slice] = updated_decode_slots + state_indices_decode = updated_decode_slots + + num_prefills = attn_metadata.num_prefills + if ( + num_prefills > 0 + and state_indices_tensor_p is not None + and block_idx_last_computed_token_p is not None + and block_idx_last_scheduled_token_p is not None + ): + start = attn_metadata.num_decodes + end = start + num_prefills + base_prefill_slots = non_spec_state_indices_tensor[start:end] + gathered_last_computed = ( + block_idx_last_computed_token_p[:num_prefills] + .clamp(min=0) + .to(torch.long) + ) + gathered_last_scheduled = ( + block_idx_last_scheduled_token_p[:num_prefills] + .clamp(min=0) + .to(torch.long) + ) + slot_in = state_indices_tensor_p.gather( + 1, gathered_last_computed.unsqueeze(1) + ).squeeze(1) + slot_out = state_indices_tensor_p.gather( + 1, gathered_last_scheduled.unsqueeze(1) + ).squeeze(1) + valid_in = (block_idx_last_computed_token_p[:num_prefills] >= 0) & ( + slot_in >= 0 + ) + valid_out = (block_idx_last_scheduled_token_p[:num_prefills] >= 0) & ( + slot_out >= 0 + ) + diff_mask = (slot_in != slot_out) & valid_in & valid_out + diff_positions = torch.nonzero(diff_mask, as_tuple=False).squeeze(-1) + if diff_positions.numel() > 0: + slot_out_masked = slot_out.index_select(0, diff_positions).to( + device=conv_state.device, dtype=torch.long + ) + slot_in_masked = slot_in.index_select(0, diff_positions).to( + device=conv_state.device, dtype=torch.long + ) + conv_state.index_copy_( + 0, + slot_out_masked, + conv_state.index_select(0, slot_in_masked), + ) + ssm_state.index_copy_( + 0, + slot_out_masked, + ssm_state.index_select(0, slot_in_masked), + ) + + updated_prefill_slots = torch.where( + valid_out, + slot_out, + base_prefill_slots, + ) + non_spec_state_indices_runtime[start:end] = updated_prefill_slots + state_indices_prefill = updated_prefill_slots + + if state_indices_decode is None and non_spec_state_indices_tensor is not None: + state_indices_decode = non_spec_state_indices_tensor[ + : attn_metadata.num_decodes + ] + + if state_indices_prefill is None and non_spec_state_indices_tensor is not None: + state_indices_prefill = non_spec_state_indices_tensor[ + start_non_spec_prefill:end_non_spec_prefill + ] + + if attn_metadata.num_decodes > 0: + assert state_indices_decode is not None + + if attn_metadata.num_prefills > 0: + assert state_indices_prefill is not None + assert non_spec_state_indices_runtime is not None + # 1. 
Set up dimensions for reshapes later projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens]) projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens]) @@ -523,20 +695,19 @@ def _forward( activation=self.activation, conv_states=conv_state, has_initial_state=has_initial_state, - cache_indices=non_spec_state_indices_tensor, + cache_indices=non_spec_state_indices_runtime, query_start_loc=non_spec_query_start_loc, metadata=attn_metadata, ).transpose(0, 1) elif attn_metadata.num_decodes > 0: + assert state_indices_decode is not None mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv_non_spec, conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=non_spec_state_indices_tensor[ - : attn_metadata.num_decodes - ], + conv_state_indices=state_indices_decode, validate_data=True, ) else: @@ -571,7 +742,7 @@ def _forward( # 3. Recurrent attention - # 3.1: process the mutlti-query part + # 3.1: process the multi-query part if spec_sequence_masks is not None: core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( q=query_spec, @@ -591,11 +762,33 @@ def _forward( # 3.2: process the remaining part if attn_metadata.num_prefills > 0: - initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() - initial_state[~has_initial_state, ...] = 0 + chunk_state_indices = non_spec_state_indices_runtime[:end_non_spec_prefill] + initial_state = ssm_state.new_zeros( + (chunk_state_indices.shape[0], *ssm_state.shape[1:]) + ) + if chunk_state_indices.numel() > 0: + valid_chunk_slots = chunk_state_indices >= 0 + valid_chunk_positions = torch.nonzero( + valid_chunk_slots, as_tuple=False + ).squeeze(-1) + if valid_chunk_positions.numel() > 0: + initial_state.index_copy_( + 0, + valid_chunk_positions, + ssm_state.index_select( + 0, + chunk_state_indices.index_select( + 0, valid_chunk_positions + ).to(device=ssm_state.device, dtype=torch.long), + ), + ) + if has_initial_state is not None: + chunk_has_initial_state = has_initial_state[:end_non_spec_prefill] + initial_state[~chunk_has_initial_state, ...] 
= 0 ( core_attn_out_non_spec, last_recurrent_state, + block_state_history, ) = chunk_gated_delta_rule( q=query_non_spec, k=key_non_spec, @@ -604,15 +797,101 @@ def _forward( beta=beta_non_spec, initial_state=initial_state, output_final_state=True, - cu_seqlens=non_spec_query_start_loc, + cu_seqlens=non_spec_query_start_loc[: end_non_spec_prefill + 1], head_first=False, use_qk_l2norm_in_kernel=True, + return_intermediate_states=prefix_caching_enabled, ) - # Init cache - ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( - ssm_state.dtype - ) + if chunk_state_indices.numel() > 0: + valid_chunk_slots = chunk_state_indices >= 0 + valid_chunk_positions = torch.nonzero( + valid_chunk_slots, as_tuple=False + ).squeeze(-1) + if valid_chunk_positions.numel() > 0: + dest_slots = chunk_state_indices.index_select( + 0, valid_chunk_positions + ).to(device=ssm_state.device, dtype=torch.long) + ssm_state.index_copy_( + 0, + dest_slots, + last_recurrent_state.index_select(0, valid_chunk_positions).to( + ssm_state.dtype + ), + ) + if prefix_caching_enabled: + if ( + block_state_history is not None + and block_state_history.numel() > 0 + and block_idx_first_scheduled_token_p is not None + and block_idx_last_scheduled_token_p is not None + and state_indices_tensor_p is not None + and attn_metadata.last_chunk_indices_p is not None + and attn_metadata.num_computed_tokens_p is not None + and attn_metadata.chunk_size is not None + and attn_metadata.block_size is not None + ): + block_history = block_state_history.to(ssm_state.dtype) + chunk_size = attn_metadata.chunk_size + block_size = attn_metadata.block_size + chunk_stride = block_size // chunk_size + last_chunk_indices = attn_metadata.last_chunk_indices_p + last_chunk_indices_long = last_chunk_indices.to(torch.long) + num_computed_tokens_p = attn_metadata.num_computed_tokens_p + + for seq_idx in range(attn_metadata.num_prefills): + block_first = int( + block_idx_first_scheduled_token_p[seq_idx].item() + ) + block_last = int( + block_idx_last_scheduled_token_p[seq_idx].item() + ) + n_blocks_to_fill = block_last - block_first + if n_blocks_to_fill <= 0: + continue + + cache_blocks = state_indices_tensor_p[ + seq_idx, block_first:block_last + ].to(torch.long) + + first_chunk = ( + 0 + if seq_idx == 0 + else int(last_chunk_indices[seq_idx - 1].item()) + 1 + ) + first_aligned_chunk = first_chunk + chunk_stride - 1 + num_unaligned_tokens = int( + num_computed_tokens_p[seq_idx].item() % block_size + ) + if num_unaligned_tokens > 0: + first_aligned_chunk -= num_unaligned_tokens // chunk_size + chunk_stop = ( + first_aligned_chunk + n_blocks_to_fill * chunk_stride + ) + cached_states = block_history[ + first_aligned_chunk:chunk_stop:chunk_stride + ] + ssm_state[cache_blocks] = cached_states + + final_slots = state_indices_tensor_p.gather( + 1, block_idx_last_scheduled_token_p.unsqueeze(1) + ).squeeze(1) + valid_final = final_slots >= 0 + valid_final_positions = torch.nonzero( + valid_final, as_tuple=False + ).squeeze(-1) + if valid_final_positions.numel() > 0: + final_slot_ids = final_slots.index_select( + 0, valid_final_positions + ).to(device=ssm_state.device, dtype=torch.long) + final_states = block_history.index_select( + 0, + last_chunk_indices_long.index_select( + 0, valid_final_positions + ), + ) + ssm_state.index_copy_(0, final_slot_ids, final_states) elif attn_metadata.num_decodes > 0: + assert state_indices_decode is not None core_attn_out_non_spec, last_recurrent_state = ( fused_recurrent_gated_delta_rule( q=query_non_spec, @@ -625,7 
+904,7 @@ def _forward( cu_seqlens=non_spec_query_start_loc[ : attn_metadata.num_decodes + 1 ], - ssm_state_indices=non_spec_state_indices_tensor, + ssm_state_indices=state_indices_decode, use_qk_l2norm_in_kernel=True, ) ) @@ -1109,9 +1388,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Qwen3Next currently does not support prefix caching" - ) self.quant_config = vllm_config.quant_config super().__init__() diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 21fc2ab72768..a6f554e8284a 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -37,6 +38,8 @@ class GDNAttentionMetadata: num_actual_tokens: int has_initial_state: Optional[torch.Tensor] = None + block_size: Optional[int] = None + chunk_size: Optional[int] = None spec_query_start_loc: Optional[torch.Tensor] = ( None # shape: [num_spec_decodes + 1,] @@ -55,6 +58,21 @@ class GDNAttentionMetadata: ) num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] + # Decode-side APC metadata + state_indices_tensor_d: Optional[torch.Tensor] = None + state_indices_tensor_p: Optional[torch.Tensor] = None + block_idx_last_computed_token_d: Optional[torch.Tensor] = None + block_idx_last_scheduled_token_d: Optional[torch.Tensor] = None + + # Prefill-side APC metadata + block_idx_first_scheduled_token_p: Optional[torch.Tensor] = None + block_idx_last_computed_token_p: Optional[torch.Tensor] = None + block_idx_last_scheduled_token_p: Optional[torch.Tensor] = None + seq_idx_p: Optional[torch.Tensor] = None + cu_chunk_seqlen_p: Optional[torch.Tensor] = None + last_chunk_indices_p: Optional[torch.Tensor] = None + num_computed_tokens_p: Optional[torch.Tensor] = None + # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None batch_ptr: Optional[torch.Tensor] = None @@ -78,6 +96,7 @@ def __init__( self.compilation_config = vllm_config.compilation_config self.speculative_config = vllm_config.speculative_config self.kv_cache_spec = kv_cache_spec + self.device = device if self.speculative_config: self.num_spec = self.speculative_config.num_speculative_tokens else: @@ -85,6 +104,14 @@ def __init__( self.use_spec_decode = self.num_spec > 0 self._init_reorder_batch_threshold(1, self.use_spec_decode) + self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() or 64 + if self.vllm_config.cache_config.enable_prefix_caching: + if kv_cache_spec.block_size % self.chunk_size != 0: + raise ValueError( + "GDN prefix caching requires the mamba block size to be a " + "multiple of the kernel chunk size." 
+ ) + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) @@ -93,6 +120,10 @@ def __init__( self.compilation_config.max_capture_size, ) + self._max_cached_blocks = cdiv( + vllm_config.model_config.max_model_len, kv_cache_spec.block_size + ) + self.spec_state_indices_tensor = torch.empty( (self.decode_cudagraph_max_bs, self.num_spec + 1), dtype=torch.int32, @@ -129,6 +160,79 @@ def __init__( device=device, ) + if self.vllm_config.cache_config.enable_prefix_caching: + self.state_indices_tensor_d_buf = torch.empty( + (self.decode_cudagraph_max_bs, self._max_cached_blocks), + dtype=torch.int32, + device=device, + ) + self.state_indices_tensor_p_buf = torch.empty( + (self.decode_cudagraph_max_bs, self._max_cached_blocks), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_computed_token_d_buf = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_scheduled_token_d_buf = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + + max_num_prefill_chunks = ( + vllm_config.model_config.max_model_len // kv_cache_spec.block_size + ) * self.decode_cudagraph_max_bs + self.seq_idx_p_buf = torch.empty( + (max_num_prefill_chunks,), + dtype=torch.int32, + device=device, + ) + self.cu_chunk_seqlen_p_buf = torch.empty( + (max_num_prefill_chunks + 1,), + dtype=torch.int32, + device=device, + ) + self.last_chunk_indices_p_buf = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.num_computed_tokens_p_buf = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.block_idx_first_scheduled_token_p_buf = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_computed_token_p_buf = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_scheduled_token_p_buf = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + else: + self.state_indices_tensor_d_buf = None + self.block_idx_last_computed_token_d_buf = None + self.block_idx_last_scheduled_token_d_buf = None + self.state_indices_tensor_p_buf = None + self.seq_idx_p_buf = None + self.cu_chunk_seqlen_p_buf = None + self.last_chunk_indices_p_buf = None + self.num_computed_tokens_p_buf = None + self.block_idx_first_scheduled_token_p_buf = None + self.block_idx_last_computed_token_p_buf = None + self.block_idx_last_scheduled_token_p_buf = None + def build( # type: ignore[override] self, common_prefix_len: int, @@ -144,6 +248,25 @@ def build( # type: ignore[override] context_lens_tensor = context_lens.to(query_start_loc.device) nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + enable_apc = self.vllm_config.cache_config.enable_prefix_caching + block_size_value: Optional[int] = None + chunk_size_value: Optional[int] = None + if enable_apc: + block_size_value = self.kv_cache_spec.block_size + chunk_size_value = self.chunk_size + state_indices_tensor_d: Optional[torch.Tensor] = None + state_indices_tensor_p: Optional[torch.Tensor] = None + block_idx_last_computed_token_d: Optional[torch.Tensor] = None + block_idx_last_scheduled_token_d: Optional[torch.Tensor] = None + block_idx_first_scheduled_token_p: Optional[torch.Tensor] = None + block_idx_last_computed_token_p: Optional[torch.Tensor] = None + block_idx_last_scheduled_token_p: Optional[torch.Tensor] = None + 
num_computed_tokens_p: Optional[torch.Tensor] = None + seq_idx_p: Optional[torch.Tensor] = None + cu_chunk_seqlen_p: Optional[torch.Tensor] = None + last_chunk_indices_p: Optional[torch.Tensor] = None + non_spec_query_start_loc_cpu: Optional[torch.Tensor] = None + if ( not self.use_spec_decode or num_decode_draft_tokens_cpu is None @@ -174,6 +297,7 @@ def build( # type: ignore[override] non_spec_state_indices_tensor = m.block_table_tensor[:, 0] spec_query_start_loc = None non_spec_query_start_loc = query_start_loc + non_spec_query_start_loc_cpu = m.query_start_loc_cpu num_accepted_tokens = None else: query_lens = query_start_loc[1:] - query_start_loc[:-1] @@ -228,6 +352,16 @@ def build( # type: ignore[override] dim=0, out=non_spec_query_start_loc[1:], ) + query_lens_cpu = m.query_start_loc_cpu[1:] - m.query_start_loc_cpu[:-1] + non_spec_query_start_loc_cpu = torch.zeros( + query_lens_cpu.size(0) - num_spec_decodes + 1, + dtype=torch.int32, + ) + torch.cumsum( + query_lens_cpu[~spec_sequence_masks.cpu()], + dim=0, + out=non_spec_query_start_loc_cpu[1:], + ) num_spec_decode_tokens = ( query_lens.sum().item() - num_prefill_tokens - num_decode_tokens @@ -235,6 +369,141 @@ def build( # type: ignore[override] assert num_accepted_tokens is not None num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] + if enable_apc: + block_table_tensor_full = m.block_table_tensor + block_size = self.kv_cache_spec.block_size + num_computed_tokens_device = m.num_computed_tokens_cpu.to( + self.device, dtype=torch.int32 + ) + seq_lens_device = m.seq_lens.to(self.device, dtype=torch.int32) + + block_idx_last_computed_all = ( + (cdiv(num_computed_tokens_device, block_size) - 1) + .clamp(min=0) + .to(torch.int32) + ) + block_idx_first_scheduled_all = ( + cdiv(num_computed_tokens_device + 1, block_size) - 1 + ).to(torch.int32) + block_idx_last_scheduled_all = (cdiv(seq_lens_device, block_size) - 1).to( + torch.int32 + ) + + if spec_sequence_masks is not None: + non_spec_mask = ~spec_sequence_masks + non_spec_block_table = block_table_tensor_full[non_spec_mask] + block_idx_last_computed_non_spec = block_idx_last_computed_all[ + non_spec_mask + ] + block_idx_last_scheduled_non_spec = block_idx_last_scheduled_all[ + non_spec_mask + ] + block_idx_first_scheduled_non_spec = block_idx_first_scheduled_all[ + non_spec_mask + ] + num_computed_tokens_non_spec = num_computed_tokens_device[non_spec_mask] + spec_sequence_masks_cpu = spec_sequence_masks.cpu() + non_spec_mask_cpu = ~spec_sequence_masks_cpu + num_computed_tokens_cpu_non_spec = m.num_computed_tokens_cpu[ + non_spec_mask_cpu + ] + else: + non_spec_block_table = block_table_tensor_full + block_idx_last_computed_non_spec = block_idx_last_computed_all + block_idx_last_scheduled_non_spec = block_idx_last_scheduled_all + block_idx_first_scheduled_non_spec = block_idx_first_scheduled_all + num_computed_tokens_non_spec = num_computed_tokens_device + num_computed_tokens_cpu_non_spec = m.num_computed_tokens_cpu + + if num_decodes > 0: + state_indices_tensor_d = non_spec_block_table[:num_decodes].contiguous() + block_idx_last_computed_token_d = block_idx_last_computed_non_spec[ + :num_decodes + ].contiguous() + block_idx_last_scheduled_token_d = block_idx_last_scheduled_non_spec[ + :num_decodes + ].contiguous() + + if num_prefills > 0: + start = num_decodes + end = start + num_prefills + state_indices_tensor_p = non_spec_block_table[start:end].contiguous() + block_idx_first_scheduled_token_p = block_idx_first_scheduled_non_spec[ + start:end + ].contiguous() + 
block_idx_last_computed_token_p = block_idx_last_computed_non_spec[ + start:end + ].contiguous() + block_idx_last_scheduled_token_p = block_idx_last_scheduled_non_spec[ + start:end + ].contiguous() + num_computed_tokens_p = num_computed_tokens_non_spec[ + start:end + ].contiguous() + + if spec_sequence_masks is None: + num_computed_tokens_p_cpu = m.num_computed_tokens_cpu[ + m.num_reqs - num_prefills : + ] + query_start_loc_p_cpu = ( + m.query_start_loc_cpu[-num_prefills - 1 :] - num_decode_tokens + ) + else: + num_computed_tokens_p_cpu = num_computed_tokens_cpu_non_spec[ + num_decodes: + ] + query_start_loc_p_cpu = ( + non_spec_query_start_loc_cpu[-num_prefills - 1 :] + - num_decode_tokens + ) + + cu_chunk_seqlen: list[int] = [] + seq_idx_list: list[int] = [] + last_chunk_indices_list: list[int] = [] + seqlen_pos = 0 + + for req_idx in range(num_prefills): + this_num_computed = int(num_computed_tokens_p_cpu[req_idx].item()) + this_new_tokens = int( + query_start_loc_p_cpu[req_idx + 1].item() + - query_start_loc_p_cpu[req_idx].item() + ) + + if this_num_computed % self.chunk_size != 0: + seq_idx_list.append(req_idx) + cu_chunk_seqlen.append(seqlen_pos) + chunk_len = ( + cdiv(this_num_computed, self.chunk_size) * self.chunk_size + - this_num_computed + ) + chunk_len = min(chunk_len, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + n_chunks = cdiv(this_new_tokens, self.chunk_size) + for _ in range(n_chunks): + seq_idx_list.append(req_idx) + cu_chunk_seqlen.append(seqlen_pos) + chunk_len = min(self.chunk_size, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + assert this_new_tokens == 0 + last_chunk_indices_list.append(len(cu_chunk_seqlen) - 1) + + cu_chunk_seqlen.append(seqlen_pos) + + device = query_start_loc.device + seq_idx_p = torch.as_tensor( + seq_idx_list, device=device, dtype=torch.int32 + ) + cu_chunk_seqlen_p = torch.as_tensor( + cu_chunk_seqlen, device=device, dtype=torch.int32 + ) + last_chunk_indices_p = torch.as_tensor( + last_chunk_indices_list, device=device, dtype=torch.int32 + ) + if num_prefills > 0: has_initial_state = context_lens_tensor > 0 if spec_sequence_masks is not None: @@ -321,6 +590,35 @@ def build( # type: ignore[override] non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1] non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens) + if enable_apc and num_decodes > 0: + assert state_indices_tensor_d is not None + num_blocks = state_indices_tensor_d.shape[1] + self.state_indices_tensor_d_buf[:num_decodes, :num_blocks].copy_( + state_indices_tensor_d, non_blocking=True + ) + state_indices_tensor_d = self.state_indices_tensor_d_buf[ + :batch_size, :num_blocks + ] + state_indices_tensor_d[num_decodes:, :].fill_(PAD_SLOT_ID) + + assert block_idx_last_scheduled_token_d is not None + self.block_idx_last_scheduled_token_d_buf[:num_decodes].copy_( + block_idx_last_scheduled_token_d, non_blocking=True + ) + block_idx_last_scheduled_token_d = ( + self.block_idx_last_scheduled_token_d_buf[:batch_size] + ) + block_idx_last_scheduled_token_d[num_decodes:] = 0 + + assert block_idx_last_computed_token_d is not None + self.block_idx_last_computed_token_d_buf[:num_decodes].copy_( + block_idx_last_computed_token_d, non_blocking=True + ) + block_idx_last_computed_token_d = ( + self.block_idx_last_computed_token_d_buf[:batch_size] + ) + block_idx_last_computed_token_d[num_decodes:] = 0 + attn_metadata = GDNAttentionMetadata( num_prefills=num_prefills, 
num_prefill_tokens=num_prefill_tokens, @@ -330,6 +628,8 @@ def build( # type: ignore[override] num_spec_decode_tokens=num_spec_decode_tokens, num_actual_tokens=num_actual_tokens, has_initial_state=has_initial_state, + block_size=block_size_value, + chunk_size=chunk_size_value, spec_query_start_loc=spec_query_start_loc, non_spec_query_start_loc=non_spec_query_start_loc, spec_state_indices_tensor=spec_state_indices_tensor, @@ -337,6 +637,17 @@ def build( # type: ignore[override] spec_sequence_masks=spec_sequence_masks, spec_token_masks=spec_token_masks, num_accepted_tokens=num_accepted_tokens, + state_indices_tensor_d=state_indices_tensor_d, + state_indices_tensor_p=state_indices_tensor_p, + block_idx_last_computed_token_d=block_idx_last_computed_token_d, + block_idx_last_scheduled_token_d=block_idx_last_scheduled_token_d, + block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, + block_idx_last_computed_token_p=block_idx_last_computed_token_p, + block_idx_last_scheduled_token_p=block_idx_last_scheduled_token_p, + seq_idx_p=seq_idx_p, + cu_chunk_seqlen_p=cu_chunk_seqlen_p, + last_chunk_indices_p=last_chunk_indices_p, + num_computed_tokens_p=num_computed_tokens_p, nums_dict=nums_dict, batch_ptr=batch_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr, From 538c9a039df440059b8d89499b2bbcd8c9d8ed2d Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Tue, 14 Oct 2025 13:21:03 +0000 Subject: [PATCH 02/21] Update type hints in gdn_attn Signed-off-by: simondanielsson --- vllm/v1/attention/backends/gdn_attn.py | 70 +++++++++++++------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 66a9912e0f42..fd5b6b34960f 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -36,10 +36,9 @@ class GDNAttentionMetadata: num_spec_decode_tokens: int num_actual_tokens: int - has_initial_state: Optional[torch.Tensor] = None - block_size: Optional[int] = None - chunk_size: Optional[int] = None has_initial_state: torch.Tensor | None = None + block_size: int | None = None + chunk_size: int | None = None spec_query_start_loc: torch.Tensor | None = None # shape: [num_spec_decodes + 1,] non_spec_query_start_loc: torch.Tensor | None = ( @@ -57,19 +56,19 @@ class GDNAttentionMetadata: num_accepted_tokens: torch.Tensor | None = None # shape: [batch,] # Decode-side APC metadata - state_indices_tensor_d: Optional[torch.Tensor] = None - state_indices_tensor_p: Optional[torch.Tensor] = None - block_idx_last_computed_token_d: Optional[torch.Tensor] = None - block_idx_last_scheduled_token_d: Optional[torch.Tensor] = None + state_indices_tensor_d: torch.Tensor | None = None + state_indices_tensor_p: torch.Tensor | None = None + block_idx_last_computed_token_d: torch.Tensor | None = None + block_idx_last_scheduled_token_d: torch.Tensor | None = None # Prefill-side APC metadata - block_idx_first_scheduled_token_p: Optional[torch.Tensor] = None - block_idx_last_computed_token_p: Optional[torch.Tensor] = None - block_idx_last_scheduled_token_p: Optional[torch.Tensor] = None - seq_idx_p: Optional[torch.Tensor] = None - cu_chunk_seqlen_p: Optional[torch.Tensor] = None - last_chunk_indices_p: Optional[torch.Tensor] = None - num_computed_tokens_p: Optional[torch.Tensor] = None + block_idx_first_scheduled_token_p: torch.Tensor | None = None + block_idx_last_computed_token_p: torch.Tensor | None = None + block_idx_last_scheduled_token_p: torch.Tensor | None = None + seq_idx_p: 
torch.Tensor | None = None + cu_chunk_seqlen_p: torch.Tensor | None = None + last_chunk_indices_p: torch.Tensor | None = None + num_computed_tokens_p: torch.Tensor | None = None # The following attributes are for triton implementation of causal_conv1d nums_dict: dict | None = None @@ -103,12 +102,13 @@ def __init__( self._init_reorder_batch_threshold(1, self.use_spec_decode) self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() or 64 - if self.vllm_config.cache_config.enable_prefix_caching: - if kv_cache_spec.block_size % self.chunk_size != 0: - raise ValueError( - "GDN prefix caching requires the mamba block size to be a " - "multiple of the kernel chunk size." - ) + if self.vllm_config.cache_config.enable_prefix_caching and ( + kv_cache_spec.block_size % self.chunk_size != 0 + ): + raise ValueError( + "GDN prefix caching requires the mamba block size to be a " + "multiple of the kernel chunk size." + ) self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() @@ -247,23 +247,23 @@ def build( # type: ignore[override] nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None enable_apc = self.vllm_config.cache_config.enable_prefix_caching - block_size_value: Optional[int] = None - chunk_size_value: Optional[int] = None + block_size_value: int | None = None + chunk_size_value: int | None = None if enable_apc: block_size_value = self.kv_cache_spec.block_size chunk_size_value = self.chunk_size - state_indices_tensor_d: Optional[torch.Tensor] = None - state_indices_tensor_p: Optional[torch.Tensor] = None - block_idx_last_computed_token_d: Optional[torch.Tensor] = None - block_idx_last_scheduled_token_d: Optional[torch.Tensor] = None - block_idx_first_scheduled_token_p: Optional[torch.Tensor] = None - block_idx_last_computed_token_p: Optional[torch.Tensor] = None - block_idx_last_scheduled_token_p: Optional[torch.Tensor] = None - num_computed_tokens_p: Optional[torch.Tensor] = None - seq_idx_p: Optional[torch.Tensor] = None - cu_chunk_seqlen_p: Optional[torch.Tensor] = None - last_chunk_indices_p: Optional[torch.Tensor] = None - non_spec_query_start_loc_cpu: Optional[torch.Tensor] = None + state_indices_tensor_d: torch.Tensor | None = None + state_indices_tensor_p: torch.Tensor | None = None + block_idx_last_computed_token_d: torch.Tensor | None = None + block_idx_last_scheduled_token_d: torch.Tensor | None = None + block_idx_first_scheduled_token_p: torch.Tensor | None = None + block_idx_last_computed_token_p: torch.Tensor | None = None + block_idx_last_scheduled_token_p: torch.Tensor | None = None + num_computed_tokens_p: torch.Tensor | None = None + seq_idx_p: torch.Tensor | None = None + cu_chunk_seqlen_p: torch.Tensor | None = None + last_chunk_indices_p: torch.Tensor | None = None + non_spec_query_start_loc_cpu: torch.Tensor | None = None if ( not self.use_spec_decode @@ -277,7 +277,7 @@ def build( # type: ignore[override] num_spec_decodes = 0 else: spec_sequence_masks = num_decode_draft_tokens_cpu >= 0 - num_spec_decodes = spec_sequence_masks.sum().item() + num_spec_decodes = int(spec_sequence_masks.sum().item()) if num_spec_decodes == 0: spec_sequence_masks = None else: From 3fffae04cfd609da534a7d3b08606e953fde0c32 Mon Sep 17 00:00:00 2001 From: Jaya Yuan Date: Tue, 14 Oct 2025 21:07:50 +0800 Subject: [PATCH 03/21] [DCP] Support Decode Context Parallel (DCP) for GQA with FlashAttention (#24864) Signed-off-by: yuanyongjie.yyj Signed-off-by: FENP <32334296+FENP@users.noreply.github.com> Signed-off-by: Jaya Yuan --- 
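The LSE plumbing added below lets each decode-context-parallel rank attend
over its own KV shard and then combine the partial results exactly:
`cp_lse_ag_out_rs` gathers the per-rank log-sum-exp values and
`merge_attn_states` applies the correction. Numerically this reduces to the
standard log-sum-exp merge of partial softmax-attention outputs. A minimal
PyTorch sketch for two shards; the function name and the `[B, H, D]`/`[B, H]`
layouts are illustrative assumptions, not the kernels' actual signatures:

    import torch

    def merge_partial_attn(out_a, lse_a, out_b, lse_b):
        # out_*: [B, H, D] partial attention outputs over disjoint KV shards.
        # lse_*: [B, H] log-sum-exp of the attention logits on each shard.
        max_lse = torch.maximum(lse_a, lse_b)
        w_a = torch.exp(lse_a - max_lse)  # numerically stable shard weights
        w_b = torch.exp(lse_b - max_lse)
        denom = w_a + w_b
        merged = (
            out_a * (w_a / denom).unsqueeze(-1)
            + out_b * (w_b / denom).unsqueeze(-1)
        )
        merged_lse = max_lse + torch.log(denom)  # LSE over both shards combined
        return merged, merged_lse
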
tests/distributed/test_context_parallel.py | 6 +- tests/models/registry.py | 5 +- vllm/attention/ops/common.py | 10 +- vllm/config/model.py | 17 ++ vllm/v1/attention/backends/flash_attn.py | 202 ++++++++++++++++++--- vllm/v1/attention/backends/utils.py | 1 + vllm/v1/worker/gpu_model_runner.py | 1 + 7 files changed, 209 insertions(+), 33 deletions(-) diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 149b502a85a7..5495640af07e 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -204,17 +204,21 @@ def _compare_cp_with_tp( CP_TEXT_GENERATION_MODELS = { - # [MLA attention only] "deepseek-ai/DeepSeek-V2-Lite-Chat": [ CPTestSettings.detailed(), CPTestSettings.detailed(tp_base=2), ], + "bigcode/gpt_bigcode-santacoder": [ + CPTestSettings.detailed(), + CPTestSettings.detailed(tp_base=2), + ], } CP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] "deepseek-ai/DeepSeek-V2-Lite-Chat", + "bigcode/gpt_bigcode-santacoder", ] diff --git a/tests/models/registry.py b/tests/models/registry.py index c6dbae3a5347..617dc30691aa 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -262,7 +262,10 @@ def check_available_online( "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}), "GPTBigCodeForCausalLM": _HfExamplesInfo( "bigcode/starcoder", - extras={"tiny": "bigcode/tiny_starcoder_py"}, + extras={ + "tiny": "bigcode/tiny_starcoder_py", + "santacoder": "bigcode/gpt_bigcode-santacoder", + }, min_transformers_version="4.55.1", transformers_version_reason="HF model broken in 4.55.0", ), diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 1234e1b2e46a..b6b7ecd2552a 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -173,6 +173,7 @@ def cp_lse_ag_out_rs( cp_attn_lse: torch.Tensor, cp_group: GroupCoordinator, ctx: CPTritonContext = None, + return_lse=False, ): """ cp_attn_out: [ B, H, D ] @@ -192,8 +193,15 @@ def cp_lse_ag_out_rs( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) - out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + assert out.is_contiguous() out = cp_group.reduce_scatter(out, dim=1) + + if return_lse: + cp_num_heads = lse.shape[1] // cp_group.world_size + cp_rank = cp_group.rank_in_group + lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)] + return out, lse return out diff --git a/vllm/config/model.py b/vllm/config/model.py index ea5de7ca3e88..4356924e4232 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1202,6 +1202,23 @@ def verify_with_parallel_config( "Supported models implement the `SupportsPP` interface." 
            )
 
+        decode_context_parallel_size = parallel_config.decode_context_parallel_size
+        if decode_context_parallel_size > 1 and not self.use_mla:
+            total_num_kv_heads = self.get_total_num_kv_heads()
+            assert tensor_parallel_size > total_num_kv_heads, (
+                f"tensor parallel size {tensor_parallel_size} must be greater "
+                f"than total num kv heads {total_num_kv_heads} when enabling "
+                f"decode context parallel for GQA/MQA"
+            )
+
+            max_dcp_size = tensor_parallel_size // total_num_kv_heads
+            assert decode_context_parallel_size <= max_dcp_size, (
+                f"decode context parallel size must be less than or equal to "
+                f"(tensor parallel size {tensor_parallel_size} // total "
+                f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, "
+                f"but got {decode_context_parallel_size}"
+            )
+
     def get_sliding_window(self) -> int | None:
         """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index fb5ff499de2c..fa4e34536135 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -17,6 +17,7 @@
     is_quantized_kv_cache,
 )
 from vllm.attention.layer import Attention
+from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (
     flash_attn_supports_fp8,
@@ -32,6 +33,7 @@
 )
 
 from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
@@ -147,6 +149,10 @@ class FlashAttentionMetadata:
     prefix_kv_lens: torch.Tensor | None
     suffix_kv_lens: torch.Tensor | None
 
+    # For GQA DCP
+    max_dcp_context_kv_len: int | None = None
+    dcp_context_kv_lens: torch.Tensor | None = None
+
     # Optional aot scheduling
     scheduler_metadata: torch.Tensor | None = None
     prefix_scheduler_metadata: torch.Tensor | None = None
@@ -216,6 +222,16 @@ def __init__(
         self.max_num_splits = 0  # No upper bound on the number of splits. 
self.aot_schedule = get_flash_attn_version() == 3 + try: + from vllm.distributed.parallel_state import get_dcp_group + + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) @@ -306,7 +322,7 @@ def schedule( batch_size=batch_size, max_seqlen_q=max_query_len, max_seqlen_k=max_seq_len, - num_heads_q=self.num_heads_q, + num_heads_q=self.num_heads_q * self.dcp_world_size, num_heads_kv=self.num_heads_kv, headdim=self.headdim, cache_seqlens=seqlens, @@ -320,8 +336,35 @@ def schedule( return None use_cascade = common_prefix_len > 0 + max_dcp_context_kv_len = 0 + dcp_context_kv_lens = None + + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + if self.dcp_world_size > 1: + query_kv_lens_cpu = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu + dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + ( + self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size + ) + dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) + max_dcp_context_kv_len = dcp_context_kv_lens.max().item() - if use_cascade: + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=dcp_context_kv_lens, + max_seq_len=max_dcp_context_kv_len, + causal=False, + ) + elif use_cascade: cu_prefix_query_lens = torch.tensor( [0, num_actual_tokens], dtype=torch.int32, device=self.device ) @@ -348,10 +391,6 @@ def schedule( causal=True, ) else: - cu_prefix_query_lens = None - prefix_kv_lens = None - suffix_kv_lens = None - prefix_scheduler_metadata = None scheduler_metadata = schedule( batch_size=num_reqs, cu_query_lens=query_start_loc, @@ -379,6 +418,8 @@ def schedule( seq_lens=seq_lens, block_table=block_table_tensor, slot_mapping=slot_mapping, + max_dcp_context_kv_len=max_dcp_context_kv_len, + dcp_context_kv_lens=dcp_context_kv_lens, use_cascade=use_cascade, common_prefix_len=common_prefix_len, scheduler_metadata=scheduler_metadata, @@ -396,6 +437,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class FlashAttentionImpl(AttentionImpl): + can_return_lse_for_decode: bool = True + def __init__( self, num_heads: int, @@ -562,30 +605,45 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) - flash_attn_varlen_func( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=attn_metadata.causal, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - scheduler_metadata=scheduler_metadata, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - num_splits=attn_metadata.max_num_splits, - s_aux=self.sinks, - ) - return output + if self.dcp_world_size > 1: + self._forward_with_dcp( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + 
value_cache, + output[:num_actual_tokens], + attn_metadata, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return output + else: + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=attn_metadata.max_num_splits, + s_aux=self.sinks, + ) + return output # Cascade attention (rare case). cascade_attention( @@ -615,6 +673,86 @@ def forward( ) return output + def _forward_with_dcp( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + ) -> torch.Tensor: + cu_seqlens_q = attn_metadata.query_start_loc + max_seqlen_q = attn_metadata.max_query_len + block_table = attn_metadata.block_table + + query = query.contiguous() + query_across_dcp = get_dcp_group().all_gather(query, dim=1) + context_attn_out, context_lse = flash_attn_varlen_func( + q=query_across_dcp, + k=key_cache, + v=value_cache, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=attn_metadata.dcp_context_kv_lens, + max_seqlen_k=attn_metadata.max_dcp_context_kv_len, + softmax_scale=self.scale, + causal=False, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=attn_metadata.scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] + context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( + context_attn_out, + context_lse.transpose(0, 1), + get_dcp_group(), + return_lse=True, + ) + context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() + + query_attn_out, query_lse = flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_k=max_seqlen_q, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + assert context_attn_out_cor.shape == query_attn_out.shape + assert context_lse_cor.shape == query_lse.shape + merge_attn_states( + output, + context_attn_out_cor, + context_lse_cor, + query_attn_out, + query_lse, + ) + def _forward_encoder_attention( self, query: torch.Tensor, @@ -684,6 +822,7 @@ def use_cascade_attention( use_sliding_window: bool, use_local_attention: bool, num_sms: int, + dcp_world_size: int, ) -> 
bool: """Decide whether to use cascade attention. @@ -705,6 +844,9 @@ def use_cascade_attention( num_reqs = len(query_lens) if num_reqs < 8: return False + # disable cascade attention for DCP + if dcp_world_size > 1: + return False # Heuristics to decide whether using cascade attention is beneficial. # 1. When FlashDecoding is not used for normal attention, cascade attention diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index beb267f196fb..cb5855548098 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -345,6 +345,7 @@ def use_cascade_attention( use_sliding_window: bool, use_local_attention: bool, num_sms: int, + dcp_world_size: int, ) -> bool: return False diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5c2893bd0926..ce174664710b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1523,6 +1523,7 @@ def _compute_cascade_attn_prefix_len( use_sliding_window=use_sliding_window, use_local_attention=use_local_attention, num_sms=self.num_sms, + dcp_world_size=self.dcp_world_size, ) return common_prefix_len if use_cascade else 0 From 76ac0fa581f5f586d84571bae3ee79ff6e27c41b Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Tue, 14 Oct 2025 13:38:25 +0000 Subject: [PATCH 04/21] Enable cudagraphs support [skip ci] Signed-off-by: simondanielsson --- vllm/model_executor/layers/fla/ops/chunk.py | 2 +- vllm/model_executor/models/qwen3_next.py | 208 ++++++++++---------- 2 files changed, 108 insertions(+), 102 deletions(-) diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index 598878133e30..0e1b485772de 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -24,7 +24,7 @@ def _reshape_intermediate_states( states: torch.Tensor, - cu_seqlens: Optional[torch.LongTensor], + cu_seqlens: torch.LongTensor | None, ) -> torch.Tensor: """Return a chunk-major view of the kernel's intermediate states.""" diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index a1282cdc02e3..cf371c9c0c28 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -490,8 +490,8 @@ def _forward( ) ) non_spec_state_indices_runtime = non_spec_state_indices_tensor - state_indices_decode: Optional[torch.Tensor] = None - state_indices_prefill: Optional[torch.Tensor] = None + state_indices_decode: torch.Tensor | None = None + state_indices_prefill: torch.Tensor | None = None start_non_spec_prefill = attn_metadata.num_decodes end_non_spec_prefill = start_non_spec_prefill + attn_metadata.num_prefills @@ -534,24 +534,32 @@ def _forward( valid_out = (block_idx_last_scheduled_token_d[:num_decodes] >= 0) & ( slot_out >= 0 ) - diff_mask = (slot_in != slot_out) & valid_in & valid_out - diff_positions = torch.nonzero(diff_mask, as_tuple=False).squeeze(-1) - if diff_positions.numel() > 0: - slot_out_masked = slot_out.index_select(0, diff_positions).to( - device=conv_state.device, dtype=torch.long - ) - slot_in_masked = slot_in.index_select(0, diff_positions).to( - device=conv_state.device, dtype=torch.long - ) + slot_out_safe = torch.where( + valid_out, + slot_out, + base_decode_slots, + ) + slot_in_safe = torch.where( + valid_in, + slot_in, + slot_out_safe, + ) + slot_out_copy = slot_out_safe.clamp(min=0).to( + device=conv_state.device, dtype=torch.long + ) + slot_in_copy = 
slot_in_safe.clamp(min=0).to( + device=conv_state.device, dtype=torch.long + ) + if slot_out_copy.numel() > 0: conv_state.index_copy_( 0, - slot_out_masked, - conv_state.index_select(0, slot_in_masked), + slot_out_copy, + conv_state.index_select(0, slot_in_copy), ) ssm_state.index_copy_( 0, - slot_out_masked, - ssm_state.index_select(0, slot_in_masked), + slot_out_copy, + ssm_state.index_select(0, slot_in_copy), ) updated_decode_slots = torch.where( valid_out, @@ -593,24 +601,32 @@ def _forward( valid_out = (block_idx_last_scheduled_token_p[:num_prefills] >= 0) & ( slot_out >= 0 ) - diff_mask = (slot_in != slot_out) & valid_in & valid_out - diff_positions = torch.nonzero(diff_mask, as_tuple=False).squeeze(-1) - if diff_positions.numel() > 0: - slot_out_masked = slot_out.index_select(0, diff_positions).to( - device=conv_state.device, dtype=torch.long - ) - slot_in_masked = slot_in.index_select(0, diff_positions).to( - device=conv_state.device, dtype=torch.long - ) + slot_out_safe = torch.where( + valid_out, + slot_out, + base_prefill_slots, + ) + slot_in_safe = torch.where( + valid_in, + slot_in, + slot_out_safe, + ) + slot_out_copy = slot_out_safe.clamp(min=0).to( + device=conv_state.device, dtype=torch.long + ) + slot_in_copy = slot_in_safe.clamp(min=0).to( + device=conv_state.device, dtype=torch.long + ) + if slot_out_copy.numel() > 0: conv_state.index_copy_( 0, - slot_out_masked, - conv_state.index_select(0, slot_in_masked), + slot_out_copy, + conv_state.index_select(0, slot_in_copy), ) ssm_state.index_copy_( 0, - slot_out_masked, - ssm_state.index_select(0, slot_in_masked), + slot_out_copy, + ssm_state.index_select(0, slot_in_copy), ) updated_prefill_slots = torch.where( @@ -817,78 +833,69 @@ def _forward( ssm_state.dtype ), ) - if prefix_caching_enabled: - if ( - block_state_history is not None - and block_state_history.numel() > 0 - and block_idx_first_scheduled_token_p is not None - and block_idx_last_scheduled_token_p is not None - and state_indices_tensor_p is not None - and attn_metadata.last_chunk_indices_p is not None - and attn_metadata.num_computed_tokens_p is not None - and attn_metadata.chunk_size is not None - and attn_metadata.block_size is not None - ): - block_history = block_state_history.to(ssm_state.dtype) - chunk_size = attn_metadata.chunk_size - block_size = attn_metadata.block_size - chunk_stride = block_size // chunk_size - last_chunk_indices = attn_metadata.last_chunk_indices_p - last_chunk_indices_long = last_chunk_indices.to(torch.long) - num_computed_tokens_p = attn_metadata.num_computed_tokens_p - - for seq_idx in range(attn_metadata.num_prefills): - block_first = int( - block_idx_first_scheduled_token_p[seq_idx].item() - ) - block_last = int( - block_idx_last_scheduled_token_p[seq_idx].item() - ) - n_blocks_to_fill = block_last - block_first - if n_blocks_to_fill <= 0: - continue - - cache_blocks = state_indices_tensor_p[ - seq_idx, block_first:block_last - ].to(torch.long) - - first_chunk = ( - 0 - if seq_idx == 0 - else int(last_chunk_indices[seq_idx - 1].item()) + 1 - ) - first_aligned_chunk = first_chunk + chunk_stride - 1 - num_unaligned_tokens = int( - num_computed_tokens_p[seq_idx].item() % block_size - ) - if num_unaligned_tokens > 0: - first_aligned_chunk -= num_unaligned_tokens // chunk_size - chunk_stop = ( - first_aligned_chunk + n_blocks_to_fill * chunk_stride - ) - cached_states = block_history[ - first_aligned_chunk:chunk_stop:chunk_stride - ] - ssm_state[cache_blocks] = cached_states - - final_slots = state_indices_tensor_p.gather( - 1, 
block_idx_last_scheduled_token_p.unsqueeze(1) - ).squeeze(1) - valid_final = final_slots >= 0 - valid_final_positions = torch.nonzero( - valid_final, as_tuple=False - ).squeeze(-1) - if valid_final_positions.numel() > 0: - final_slot_ids = final_slots.index_select( - 0, valid_final_positions - ).to(device=ssm_state.device, dtype=torch.long) - final_states = block_history.index_select( - 0, - last_chunk_indices_long.index_select( - 0, valid_final_positions - ), - ) - ssm_state.index_copy_(0, final_slot_ids, final_states) + if prefix_caching_enabled and ( + block_state_history is not None + and block_state_history.numel() > 0 + and block_idx_first_scheduled_token_p is not None + and block_idx_last_scheduled_token_p is not None + and state_indices_tensor_p is not None + and attn_metadata.last_chunk_indices_p is not None + and attn_metadata.num_computed_tokens_p is not None + and attn_metadata.chunk_size is not None + and attn_metadata.block_size is not None + ): + block_history = block_state_history.to(ssm_state.dtype) + chunk_size = attn_metadata.chunk_size + block_size = attn_metadata.block_size + chunk_stride = block_size // chunk_size + last_chunk_indices = attn_metadata.last_chunk_indices_p + last_chunk_indices_long = last_chunk_indices.to(torch.long()) + num_computed_tokens_p = attn_metadata.num_computed_tokens_p + + for seq_idx in range(attn_metadata.num_prefills): + block_first = int(block_idx_first_scheduled_token_p[seq_idx].item()) + block_last = int(block_idx_last_scheduled_token_p[seq_idx].item()) + n_blocks_to_fill = block_last - block_first + if n_blocks_to_fill <= 0: + continue + + cache_blocks = state_indices_tensor_p[ + seq_idx, block_first:block_last + ].to(torch.long) + + first_chunk = ( + 0 + if seq_idx == 0 + else int(last_chunk_indices[seq_idx - 1].item()) + 1 + ) + first_aligned_chunk = first_chunk + chunk_stride - 1 + num_unaligned_tokens = int( + num_computed_tokens_p[seq_idx].item() % block_size + ) + if num_unaligned_tokens > 0: + first_aligned_chunk -= num_unaligned_tokens // chunk_size + chunk_stop = first_aligned_chunk + n_blocks_to_fill * chunk_stride + cached_states = block_history[ + first_aligned_chunk:chunk_stop:chunk_stride + ] + ssm_state[cache_blocks] = cached_states + + final_slots = state_indices_tensor_p.gather( + 1, block_idx_last_scheduled_token_p.unsqueeze(1) + ).squeeze(1) + valid_final = final_slots >= 0 + valid_final_positions = torch.nonzero( + valid_final, as_tuple=False + ).squeeze(-1) + if valid_final_positions.numel() > 0: + final_slot_ids = final_slots.index_select( + 0, valid_final_positions + ).to(device=ssm_state.device, dtype=torch.long) + final_states = block_history.index_select( + 0, + last_chunk_indices_long.index_select(0, valid_final_positions), + ) + ssm_state.index_copy_(0, final_slot_ids, final_states) elif attn_metadata.num_decodes > 0: assert state_indices_decode is not None core_attn_out_non_spec, last_recurrent_state = ( @@ -1384,7 +1391,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config From 795ed51e8aab0c46ec0452432121e951fa7ff556 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Tue, 14 Oct 2025 13:44:28 +0000 Subject: [PATCH 05/21] Fix long() -> long [skip ci] Signed-off-by: simondanielsson --- 
vllm/model_executor/models/qwen3_next.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index cf371c9c0c28..f0ef40d9d8a5 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -849,7 +849,7 @@ def _forward( block_size = attn_metadata.block_size chunk_stride = block_size // chunk_size last_chunk_indices = attn_metadata.last_chunk_indices_p - last_chunk_indices_long = last_chunk_indices.to(torch.long()) + last_chunk_indices_long = last_chunk_indices.to(torch.long) num_computed_tokens_p = attn_metadata.num_computed_tokens_p for seq_idx in range(attn_metadata.num_prefills): From 044990c17768a57b6c958e62f4d7d8cab7a870a7 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Tue, 14 Oct 2025 18:50:25 +0000 Subject: [PATCH 06/21] Add defensive programming asserts Signed-off-by: simondanielsson --- vllm/model_executor/layers/fla/ops/chunk.py | 8 +++++--- vllm/model_executor/models/qwen3_next.py | 14 +++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index 0e1b485772de..65d90406fc9d 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -124,9 +124,11 @@ def forward( ) ctx.scale = scale ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel - intermediate_states = None - if return_intermediate_states: - intermediate_states = _reshape_intermediate_states(h, cu_seqlens) + intermediate_states = ( + _reshape_intermediate_states(h, cu_seqlens) + if return_intermediate_states + else None + ) return o.to(q.dtype), final_state, intermediate_states diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index f0ef40d9d8a5..f26e11d5fab6 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -550,6 +550,7 @@ def _forward( slot_in_copy = slot_in_safe.clamp(min=0).to( device=conv_state.device, dtype=torch.long ) + breakpoint() if slot_out_copy.numel() > 0: conv_state.index_copy_( 0, @@ -566,6 +567,7 @@ def _forward( slot_out, base_decode_slots, ) + breakpoint() non_spec_state_indices_runtime[decode_slice] = updated_decode_slots state_indices_decode = updated_decode_slots @@ -617,6 +619,7 @@ def _forward( slot_in_copy = slot_in_safe.clamp(min=0).to( device=conv_state.device, dtype=torch.long ) + breakpoint() if slot_out_copy.numel() > 0: conv_state.index_copy_( 0, @@ -634,9 +637,11 @@ def _forward( slot_out, base_prefill_slots, ) + breakpoint() non_spec_state_indices_runtime[start:end] = updated_prefill_slots state_indices_prefill = updated_prefill_slots + breakpoint() if state_indices_decode is None and non_spec_state_indices_tensor is not None: state_indices_decode = non_spec_state_indices_tensor[ : attn_metadata.num_decodes @@ -800,6 +805,13 @@ def _forward( if has_initial_state is not None: chunk_has_initial_state = has_initial_state[:end_non_spec_prefill] initial_state[~chunk_has_initial_state, ...] 
= 0 + + assert query_non_spec is not None + assert key_non_spec is not None + assert value_non_spec is not None + assert g_non_spec is not None + assert beta_non_spec is not None + cu_seqlens = non_spec_query_start_loc[: end_non_spec_prefill + 1] ( core_attn_out_non_spec, last_recurrent_state, @@ -812,7 +824,7 @@ def _forward( beta=beta_non_spec, initial_state=initial_state, output_final_state=True, - cu_seqlens=non_spec_query_start_loc[: end_non_spec_prefill + 1], + cu_seqlens=cu_seqlens, head_first=False, use_qk_l2norm_in_kernel=True, return_intermediate_states=prefix_caching_enabled, From 68ca70f8e0582ec0eeb1fbdb0ea24bde46f188f2 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 10:05:05 +0000 Subject: [PATCH 07/21] Allocate metadata buffer by chunk count rather than block count, and make sure prefill block-history indexing captures decode chunks Signed-off-by: simondanielsson --- vllm/model_executor/models/qwen3_next.py | 113 +++++++++++++---------- vllm/v1/attention/backends/gdn_attn.py | 6 +- 2 files changed, 66 insertions(+), 53 deletions(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index f26e11d5fab6..ee64ee6e3743 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -857,57 +857,68 @@ def _forward( and attn_metadata.block_size is not None ): block_history = block_state_history.to(ssm_state.dtype) - chunk_size = attn_metadata.chunk_size - block_size = attn_metadata.block_size - chunk_stride = block_size // chunk_size - last_chunk_indices = attn_metadata.last_chunk_indices_p - last_chunk_indices_long = last_chunk_indices.to(torch.long) - num_computed_tokens_p = attn_metadata.num_computed_tokens_p - - for seq_idx in range(attn_metadata.num_prefills): - block_first = int(block_idx_first_scheduled_token_p[seq_idx].item()) - block_last = int(block_idx_last_scheduled_token_p[seq_idx].item()) - n_blocks_to_fill = block_last - block_first - if n_blocks_to_fill <= 0: - continue - - cache_blocks = state_indices_tensor_p[ - seq_idx, block_first:block_last - ].to(torch.long) - - first_chunk = ( - 0 - if seq_idx == 0 - else int(last_chunk_indices[seq_idx - 1].item()) + 1 - ) - first_aligned_chunk = first_chunk + chunk_stride - 1 - num_unaligned_tokens = int( - num_computed_tokens_p[seq_idx].item() % block_size - ) - if num_unaligned_tokens > 0: - first_aligned_chunk -= num_unaligned_tokens // chunk_size - chunk_stop = first_aligned_chunk + n_blocks_to_fill * chunk_stride - cached_states = block_history[ - first_aligned_chunk:chunk_stop:chunk_stride - ] - ssm_state[cache_blocks] = cached_states - - final_slots = state_indices_tensor_p.gather( - 1, block_idx_last_scheduled_token_p.unsqueeze(1) - ).squeeze(1) - valid_final = final_slots >= 0 - valid_final_positions = torch.nonzero( - valid_final, as_tuple=False - ).squeeze(-1) - if valid_final_positions.numel() > 0: - final_slot_ids = final_slots.index_select( - 0, valid_final_positions - ).to(device=ssm_state.device, dtype=torch.long) - final_states = block_history.index_select( - 0, - last_chunk_indices_long.index_select(0, valid_final_positions), - ) - ssm_state.index_copy_(0, final_slot_ids, final_states) + chunk_offset = int(attn_metadata.num_decodes) + block_history_prefill = block_history[chunk_offset:] + if block_history_prefill.shape[0] > 0: + chunk_size = attn_metadata.chunk_size + block_size = attn_metadata.block_size + chunk_stride = block_size // chunk_size + last_chunk_indices = 
attn_metadata.last_chunk_indices_p + last_chunk_indices_long = last_chunk_indices.to(torch.long) + num_computed_tokens_p = attn_metadata.num_computed_tokens_p + + for seq_idx in range(attn_metadata.num_prefills): + block_first = int( + block_idx_first_scheduled_token_p[seq_idx].item() + ) + block_last = int( + block_idx_last_scheduled_token_p[seq_idx].item() + ) + n_blocks_to_fill = block_last - block_first + if n_blocks_to_fill <= 0: + continue + + cache_blocks = state_indices_tensor_p[ + seq_idx, block_first:block_last + ].to(torch.long) + + first_chunk = ( + 0 + if seq_idx == 0 + else int(last_chunk_indices[seq_idx - 1].item()) + 1 + ) + first_aligned_chunk = first_chunk + chunk_stride - 1 + num_unaligned_tokens = int( + num_computed_tokens_p[seq_idx].item() % block_size + ) + if num_unaligned_tokens > 0: + first_aligned_chunk -= num_unaligned_tokens // chunk_size + chunk_stop = ( + first_aligned_chunk + n_blocks_to_fill * chunk_stride + ) + cached_states = block_history_prefill[ + first_aligned_chunk:chunk_stop:chunk_stride + ] + ssm_state[cache_blocks] = cached_states + + final_slots = state_indices_tensor_p.gather( + 1, block_idx_last_scheduled_token_p.unsqueeze(1) + ).squeeze(1) + valid_final = final_slots >= 0 + valid_final_positions = torch.nonzero( + valid_final, as_tuple=False + ).squeeze(-1) + if valid_final_positions.numel() > 0: + final_slot_ids = final_slots.index_select( + 0, valid_final_positions + ).to(device=ssm_state.device, dtype=torch.long) + final_states = block_history_prefill.index_select( + 0, + last_chunk_indices_long.index_select( + 0, valid_final_positions + ), + ) + ssm_state.index_copy_(0, final_slot_ids, final_states) elif attn_metadata.num_decodes > 0: assert state_indices_decode is not None core_attn_out_non_spec, last_recurrent_state = ( diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index fd5b6b34960f..1fe7c98680d1 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -181,8 +181,9 @@ def __init__( ) max_num_prefill_chunks = ( - vllm_config.model_config.max_model_len // kv_cache_spec.block_size - ) * self.decode_cudagraph_max_bs + cdiv(vllm_config.model_config.max_model_len, self.chunk_size) + * self.decode_cudagraph_max_bs + ) self.seq_idx_p_buf = torch.empty( (max_num_prefill_chunks,), dtype=torch.int32, @@ -450,6 +451,7 @@ def build( # type: ignore[override] num_computed_tokens_p_cpu = num_computed_tokens_cpu_non_spec[ num_decodes: ] + assert non_spec_query_start_loc_cpu is not None query_start_loc_p_cpu = ( non_spec_query_start_loc_cpu[-num_prefills - 1 :] - num_decode_tokens From fe8f0b7e7180f626a8b25115bc879cf913bd1def Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 13:12:21 +0000 Subject: [PATCH 08/21] Return hidden state when return_intermediate_states is passed, ignoring GDN_RECOMPUTE_SUPPRESS_LEVEL Signed-off-by: simondanielsson --- vllm/model_executor/layers/fla/ops/chunk.py | 22 ++++++++++++++------- vllm/model_executor/models/qwen3_next.py | 5 ----- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index 65d90406fc9d..1904e4722c49 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -49,6 +49,7 @@ def chunk_gated_delta_rule_fwd( initial_state: torch.Tensor, output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, + return_intermediate_states: bool = False, ): 
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) # obtain WY representation. u is actually the new v. @@ -82,9 +83,16 @@ def chunk_gated_delta_rule_fwd( scale=scale, cu_seqlens=cu_seqlens, ) - # TODO: perhaps bypass this if return_intermediate_states, if so always return h if SUPPRESS_LEVEL < 3: - return g, o, A, final_state, None, None, None + return ( + g, + o, + A, + final_state, + None, + h if return_intermediate_states else None, + None, + ) elif SUPPRESS_LEVEL >= 3: return g, o, A, final_state, w, h, v_new @@ -121,14 +129,14 @@ def forward( initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, + return_intermediate_states=return_intermediate_states, ) ctx.scale = scale ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel - intermediate_states = ( - _reshape_intermediate_states(h, cu_seqlens) - if return_intermediate_states - else None - ) + intermediate_states = None + if return_intermediate_states: + assert h is not None + intermediate_states = _reshape_intermediate_states(h, cu_seqlens) return o.to(q.dtype), final_state, intermediate_states diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index ee64ee6e3743..59aabe023595 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -550,7 +550,6 @@ def _forward( slot_in_copy = slot_in_safe.clamp(min=0).to( device=conv_state.device, dtype=torch.long ) - breakpoint() if slot_out_copy.numel() > 0: conv_state.index_copy_( 0, @@ -567,7 +566,6 @@ def _forward( slot_out, base_decode_slots, ) - breakpoint() non_spec_state_indices_runtime[decode_slice] = updated_decode_slots state_indices_decode = updated_decode_slots @@ -619,7 +617,6 @@ def _forward( slot_in_copy = slot_in_safe.clamp(min=0).to( device=conv_state.device, dtype=torch.long ) - breakpoint() if slot_out_copy.numel() > 0: conv_state.index_copy_( 0, @@ -637,11 +634,9 @@ def _forward( slot_out, base_prefill_slots, ) - breakpoint() non_spec_state_indices_runtime[start:end] = updated_prefill_slots state_indices_prefill = updated_prefill_slots - breakpoint() if state_indices_decode is None and non_spec_state_indices_tensor is not None: state_indices_decode = non_spec_state_indices_tensor[ : attn_metadata.num_decodes From ac226e84b1be17fae78130974a766684bc600c4c Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 13:18:21 +0000 Subject: [PATCH 09/21] Inline _reshape_intermediate_states in the fla chunk kernel wrapper Signed-off-by: simondanielsson --- vllm/model_executor/layers/fla/ops/chunk.py | 25 ++++++--------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index 1904e4722c49..1a312affcf14 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -22,23 +22,6 @@ from .wy_fast import recompute_w_u_fwd -def _reshape_intermediate_states( - states: torch.Tensor, - cu_seqlens: torch.LongTensor | None, -) -> torch.Tensor: - """Return a chunk-major view of the kernel's intermediate states.""" - - if cu_seqlens is None: - # Equal-length batches keep their batch dimension; flatten it together - # with the chunk axis so callers receive a contiguous chunk stream. - return states.reshape(-1, *states.shape[-3:]) - - # Variable-length inputs collapse the batch dimension during preprocessing, - # so the kernel already emits a linearised chunk stream in ``states[:, i]``. 
- # Flattening mirrors the metadata builder's chunk enumeration order. - return states.reshape(-1, *states.shape[-3:]) - - def chunk_gated_delta_rule_fwd( q: torch.Tensor, k: torch.Tensor, @@ -136,7 +119,13 @@ def forward( intermediate_states = None if return_intermediate_states: assert h is not None - intermediate_states = _reshape_intermediate_states(h, cu_seqlens) + # Convert intermediate states into "chunk-major" form + # Equal-length batches keep their batch dimension; flatten it together + # with the chunk axis so callers receive a contiguous chunk stream. + # Variable-length inputs collapse the batch dimension during preprocessing, + # so the kernel already emits a linearised chunk stream in ``states[:, i]``. + # Flattening mirrors the metadata builder's chunk enumeration order. + intermediate_states = h.reshape(-1, *h.shape[-3:]) return o.to(q.dtype), final_state, intermediate_states From f9752605d9895988556cc6891e003bcf34fe5781 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 13:38:32 +0000 Subject: [PATCH 10/21] Add more explanatory comments in FLA's chunk.py Signed-off-by: simondanielsson --- vllm/model_executor/layers/fla/ops/chunk.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index 1a312affcf14..21893e1ab5c1 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -125,6 +125,7 @@ def forward( # Variable-length inputs collapse the batch dimension during preprocessing, # so the kernel already emits a linearised chunk stream in ``states[:, i]``. # Flattening mirrors the metadata builder's chunk enumeration order. + # Last three axes of h are [H, K, V], producing [num_chunks_total, H, K, V] intermediate_states = h.reshape(-1, *h.shape[-3:]) return o.to(q.dtype), final_state, intermediate_states From e74f67d4c089bb05c26f6035ea7ad5e1365adc8a Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 14:36:11 +0000 Subject: [PATCH 11/21] Improve logging Signed-off-by: simondanielsson --- vllm/model_executor/models/config.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 537e45f63024..8aa4c3409336 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -304,14 +304,20 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "Mamba2ForCausalLM", "NemotronHForCausalLM", "Zamba2ForCausalLM", + ] + GDN_MODELS = [ "Qwen3NextForCausalLM", ] if cache_config.enable_prefix_caching: - if model_config.architecture in MAMBA2_MODELS: + if model_config.architecture in MAMBA2_MODELS + GDN_MODELS: + layer_type = ( + "Mamba2" if model_config.architecture in MAMBA2_MODELS else "GDN" + ) logger.info( "Warning: Prefix caching is currently enabled. " - "Its support for Mamba2 layers is experimental. " - "Please report any issues you may observe." + "Its support for %s layers is experimental. 
" + "Please report any issues you may observe.", + layer_type, ) else: logger.info( From f177a1f86d572316a1ef9e1a14085b8d7293b38e Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 14:44:38 +0000 Subject: [PATCH 12/21] Add GDN model to APC tests Signed-off-by: simondanielsson --- .../models/language/generation/test_hybrid.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index c0e732b8b739..1d4a119f1ff1 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -27,6 +27,8 @@ # "yujiepan/mamba2-codestral-v0.1-tiny-random", ] +GDN_MODELS = ["tiny-random/qwen3-next-moe"] + HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", "pfnet/plamo-2-1b", @@ -35,8 +37,7 @@ "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", "LiquidAI/LFM2-1.2B", - "tiny-random/qwen3-next-moe", -] +] + GDN_MODELS FULL_CUDA_GRAPH_MODELS = [ "ai21labs/Jamba-tiny-dev", @@ -53,8 +54,7 @@ MAX_NUM_SEQS = 4 -# @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) -@pytest.mark.parametrize("model", ["tiny-random/qwen3-next-moe"]) +@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models( @@ -383,7 +383,7 @@ def _get_vLLM_output( return outs, vllm_model -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -449,7 +449,7 @@ def test_apc_single_prompt( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -531,8 +531,7 @@ def test_apc_single_prompt_block_align_alignment( ) -# @pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) -@pytest.mark.parametrize("model", ["tiny-random/qwen3-next-moe"]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -599,7 +598,7 @@ def test_apc_multiple_prompts_all_cached_outputs( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -683,7 +682,7 @@ def test_apc_multiple_prompts_block_align_alignment( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version From 552ba6f43f38007739ab5216c68de627dd14104d Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 14:45:45 +0000 Subject: [PATCH 13/21] Add helpful comments in hard-to-understand areas Signed-off-by: simondanielsson --- vllm/model_executor/models/qwen3_next.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git 
a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 59aabe023595..ce9b7936fd53 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -501,6 +501,8 @@ def _forward( and non_spec_state_indices_tensor is not None and non_spec_state_indices_tensor.numel() > 0 ): + # Work on a copy so that updates to the runtime view don't leak back + # into the attention metadata shared across microbatches. non_spec_state_indices_runtime = non_spec_state_indices_tensor.clone() num_decodes = attn_metadata.num_decodes @@ -551,6 +553,8 @@ def _forward( device=conv_state.device, dtype=torch.long ) if slot_out_copy.numel() > 0: + # Recycle the previously computed state into the newly + # scheduled slot so we can skip recomputing the prefix. conv_state.index_copy_( 0, slot_out_copy, @@ -618,6 +622,8 @@ def _forward( device=conv_state.device, dtype=torch.long ) if slot_out_copy.numel() > 0: + # Mirror the decode path: move cached prefix states into + # the slots assigned to this prefill chunk. conv_state.index_copy_( 0, slot_out_copy, @@ -797,6 +803,7 @@ def _forward( ).to(device=ssm_state.device, dtype=torch.long), ), ) + if has_initial_state is not None: chunk_has_initial_state = has_initial_state[:end_non_spec_prefill] initial_state[~chunk_has_initial_state, ...] = 0 @@ -855,6 +862,9 @@ def _forward( chunk_offset = int(attn_metadata.num_decodes) block_history_prefill = block_history[chunk_offset:] if block_history_prefill.shape[0] > 0: + # The block history contains recurrent states per chunk; we + # replay it into the persistent cache blocks owned by each + # sequence so future steps can hit the prefix cache. chunk_size = attn_metadata.chunk_size block_size = attn_metadata.block_size chunk_stride = block_size // chunk_size From 2ab062d1107d758e4d7eb5b8f7023c38ec61296d Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 14:58:48 +0000 Subject: [PATCH 14/21] Improve way to set chunk_size=64 for GDN Signed-off-by: simondanielsson --- vllm/config/model.py | 7 ++++++- vllm/v1/attention/backends/gdn_attn.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index 4c59aeef947e..ccb44b411a53 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1505,7 +1505,12 @@ def get_mamba_chunk_size(self) -> int | None: if chunk_size is None: # used by e.g. Mamba2, NemotronH, Zamba chunk_size = getattr(self.hf_text_config, "chunk_size", None) - return chunk_size or 64 + if chunk_size is None and self.hf_text_config.model_type == "qwen3_next": + # Fallback for Qwen3-Next. 64 is a hardcoded value in the GDN kernel. 
+ # https://github.com/fla-org/flash-linear-attention/blob/2e7336262c11f8bc6cd6a94b1eb5ee353ae8b4cd/fla/ops/common/chunk_delta_h.py#L439 + return 64 + + return chunk_size def get_multimodal_config(self) -> MultiModalConfig: """ diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 8f8aad9e8896..91977f2997bf 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -101,7 +101,7 @@ def __init__( self.use_spec_decode = self.num_spec > 0 self._init_reorder_batch_threshold(1, self.use_spec_decode) - self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() or 64 + self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() if self.vllm_config.cache_config.enable_prefix_caching and ( kv_cache_spec.block_size % self.chunk_size != 0 ): From 4837a11fff423b9887a2eab78b9ca6d4b0044986 Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 14:59:48 +0000 Subject: [PATCH 15/21] Revert KV cache memory limit in test Signed-off-by: simondanielsson --- tests/models/language/generation/test_hybrid.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 1d4a119f1ff1..e616442a1440 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -78,9 +78,7 @@ def test_models( example_prompts, max_tokens, num_logprobs ) - with vllm_runner( - model, max_num_seqs=MAX_NUM_SEQS, kv_cache_memory_bytes=1_000_000_000 - ) as vllm_model: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs ) From b58362a4d182f81c6476b7c6f18dd56d9c5a274c Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Thu, 16 Oct 2025 16:08:15 +0000 Subject: [PATCH 16/21] Add dynamic counting of decode chunks, rather than static value Signed-off-by: simondanielsson --- vllm/model_executor/models/qwen3_next.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 76a12521bc1a..89b2a9316430 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -858,8 +858,18 @@ def _forward( and attn_metadata.block_size is not None ): block_history = block_state_history.to(ssm_state.dtype) - chunk_offset = int(attn_metadata.num_decodes) - block_history_prefill = block_history[chunk_offset:] + total_chunks = block_history.shape[0] + last_chunk_indices = attn_metadata.last_chunk_indices_p + prefill_chunk_count = ( + int(last_chunk_indices[-1].item()) + 1 + if last_chunk_indices is not None and last_chunk_indices.numel() > 0 + else 0 + ) + decode_chunk_count = max(total_chunks - prefill_chunk_count, 0) + # Prefill chunks trail the decode chunks; skip the actual number of + # decode chunk completions so partial decodes (no chunk output) do + # not offset the history. 
+            block_history_prefill = block_history[decode_chunk_count:]

From ccda04e809ec3616dcde15030c7511372b9908dc Mon Sep 17 00:00:00 2001
From: simondanielsson
Date: Fri, 17 Oct 2025 09:57:15 +0000
Subject: [PATCH 17/21] Add plot

Signed-off-by: simondanielsson
---
 latency_plot.png | Bin 0 -> 61629 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 latency_plot.png

diff --git a/latency_plot.png b/latency_plot.png
new file mode 100644
index 0000000000000000000000000000000000000000..b77e138b2839e600c04f3c8376937aa609aa4e6e
GIT binary patch
literal 61629
[base85 binary payload omitted: 61629-byte PNG (latency_plot.png), not human-readable]
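Aside: to make the chunk bookkeeping from PATCH 16 concrete, below is a minimal standalone sketch of the decode/prefill split over a chunk-major state history. This is not vLLM code: the seven-chunk history, the head/key/value sizes, and the `last_chunk_indices` values are invented for illustration; only the arithmetic mirrors the patch.

```python
# Standalone sketch (assumptions: shapes and index values are made up).
# The kernel emits one [H, K, V] state per completed chunk; decode chunks
# come first, prefill chunks trail them, and last_chunk_indices holds the
# index of each prefill request's final chunk within the prefill portion.
import torch

H, K, V = 2, 4, 4
block_history = torch.randn(7, H, K, V)   # 7 chunk states, chunk-major

last_chunk_indices = torch.tensor([1, 4])  # two prefill requests

total_chunks = block_history.shape[0]
# Last prefill chunk index + 1 gives the prefill chunk count ...
prefill_chunk_count = (
    int(last_chunk_indices[-1].item()) + 1
    if last_chunk_indices.numel() > 0
    else 0
)
# ... and whatever precedes the prefill chunks belongs to decodes.
# Clamping at zero guards against decode steps that completed no chunk.
decode_chunk_count = max(total_chunks - prefill_chunk_count, 0)

block_history_prefill = block_history[decode_chunk_count:]
assert block_history_prefill.shape[0] == prefill_chunk_count
print(decode_chunk_count, prefill_chunk_count)  # -> 2 5
```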
z9M5tSG12f~yQD9Eq+p~|uOGRqZXN3OP3eWcBF*uts^(y(hYTJZk5*(x+MgZCVD9MXc(tcTv@K(l)#DfyUtiyb>Wo-r6_qL2jvG+f zlwb-!1;07bzz;kWi{u`xn)(@Rq#1zWU+LDAH(Ta8Z~N;{ML_q0h86N$BBIUOV>Vmua=aQWN2h`kI>1#%SV4R;uYkSPT+bK z$Bj!t-jhZTGI$mCTNL7zwKzm5S4=&2Y$O}u!f9+w@_=g{x^=q=8YcL6HR%I$ZyI{l zMC6)WGzG}>+L0nj{CMs)3Wy=)`eC~PFW*`u0!mHEz13bX1x$kc$r_+U7xZv?{&B#8 zSQRz3shB^gL8RpNje4lnHtZzGbK zua2WVwCCsB?GCjS*$F34R(Z&XK3LQRcY3{G0(RNS*DG`1biwo80@-^Dc1z z^aJj*EblffQ{#-%?$4dK@i${c3JXwCfMab(b)*y4rZA-Hd(lzd0-sNpjTSA6_@RIQ z8@IKRgk7&K2JRt;3^>Hc! z%U%VAhsWeQdd|&`sRq58jOpsSFb@_u6`&9hv|hEB5@~qU-Sx%0`AA8xm&5xX)fdxx zbj9}!V-jd+m0z^y>mWH)me+FR<4Kl?;o~X zZl-O7xxIZFo_Q5HZ`)k|tLW;dhhuQ_a$mn|(3?abH~?h&v!1T5l8(-^M~Z!Bhs+V% zJC`VN5I+5>u9yLoV)pGk} z-%kF@n5`DDIUZs&&1}%kMj#(3xVOzSu7sb*khE2L`uZx9Cog8T<`obeXNT1d(!yu{ z(r@n^F6o|+g*x2}M~xbV;K{S&yW5xu;(^i~LsQdgBGjwj^f@LyCa`D}F}i z*DMS{okPP)7IMBk!=Ype?ll_`zcZxzgziE?Pb{Co8siYsJQ^7-~A+K4yCjG(PiEc5g0_Ewyf%uw@5VncNy(K-m)<}?u^>; z+<>+GD_qgGN_oPBN@PdgSNMN;K4&J{di4MPdovunMT-_yT;8-6-6zc;_q>OybeXM- z7cZ8SmacE}zFn1zh_f6#&jvIAP(skb;1XC*FvxH2FGT#BYGe#sF&u=D0_#^VQH!>B z?{VaTE`HoCY|DkqmmfbgOI);V9B#j8?Wi3h(bI|uuOjzsWo>O+Sb4S2)u}eM!+sJy z*Q&d##Vsl)yHd&QbF4;I)jWLK?W1~e+>R#wyBPGA3ekI|8cTZpn`-)_A8 z1Udd2)~Zt8D4Vk_U1DZtmQh-hsIIOas=)qd{ZNmCnCa9;d@DPDXS26OBSjy6^m=^- zq0#SN|dNh#Y}dS*txL(4|29{ zzFH^}RmlI>bG`Jvm}r)omX^)S$D{p^AbU;%5o!aCfbK0%kfKXYPv6*(i=ct06zMhu z6R7wXf1-|DjO)O6*2)+NPsQz)!7NaZ})#ds(R5t zViOgO)EMah6obN^&$3LhKzv&UZPanZm5Xd_Z2XWR+<;29`4Pwlo`A$uNl9rw^-?&c zLxv1dFhGVqW<4B{-;u=B)CJ}3hvGWEA27h%t)*r9(79l8(mrX5ww(<5*Rk6_7k={| z!(HU>UPFezq06ZO;Dsb+>+7<)s34C*A&*z_D=hF&TDoH7T;@0rxG|lkE z9%ecP)}Cbh8gQ?RH*P$Kj`{VkJwnSU8So3Y_ujF-?ZfgJAWL#hjBPxmB6>h}Rp@1Y9AtwDz*q$`>ke0uRKQimfF~J` zuLtJ^c(!48hkgn4obHGlm6ept%*`wGI@dnP(lgqpx>w=tvs{OWDELLZpJ9Z=;K9-R z_xlvLibM_b{&sNiQWfRi)L~$x+e!M=!VRcSG{7hjmMK)6G#>F8%_DN^Y(=7>|GezC zu{T7bEsN&N`E_qyR$m~&`}eIUIVrP=&>U-5fn;Ong;o+zKOO-CJw zPu8F;&;55qaWvq7;0ySTK%(A8BaMbB{Y1x>7aeb*?^RWQJwyF_<6rZKIjMPGRqDlC zfB%0*Uc=jeCXx3(yv~y#>A(E>(C7a=)Vm=6?Mn1~w@6d*f!_XaeE1{-JrmvZUzTnE FzW~1M8 Date: Fri, 17 Oct 2025 12:11:44 +0200 Subject: [PATCH 18/21] Remove plot Signed-off-by: simondanielsson --- latency_plot.png | Bin 61629 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 latency_plot.png diff --git a/latency_plot.png b/latency_plot.png deleted file mode 100644 index b77e138b2839e600c04f3c8376937aa609aa4e6e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 61629 zcmeFad0fqT|2}>&#tehI8`;Yck(NnB+H5gtwQDz&P((%hGGnYEPL$S@rLwfBR9cj= zwUv%KH0<@vgv*L6Lg=Zw1QhVf%( zjAbww;}w5jx0%5hYsX-W+Wymz_!sfO*U9+LN^^ys=37kmnp^x~YQRwW!~C$ZiMjFq z-E)o^n40Z3IV2$}{+sA>;W_)v%@3Qc6cal*^bAoGQ$sQ3(0N<&E;X?H22E)u^&#KFR))=wf+M+@tZ0??>;-@FSSj3Zg2K` zx>Iv<@BNKV3pH7r_2zEfH2>ubk6iWUz%^&~ZM}2m2Oj2WB?0}I_Vx|-(JcR;&$!I% z_VG&G6BO$ao>|<`yE0#=JvzHJ%_e`Tq@*O>+Q7qcZ%|~_h{2y2jGasT)P{a?9P8Wi z^UzP~BQDDgKU2stTX5)Ct46DL{V?>?EyhOH&<~6$tEwjr{p34xq20)#pSCjo-}#z3I*aw|j5RzJ2?WKD>J(-|F@HP^?If zzlxUXq$_H6c6J#n%FcNmc|LvM=GC7KjF#iss9lK}^XuX|yG#1*uQhWX|D7%sqE`+$V^!r;$_!XXeg!j6OGut+K+E2kX)YZl5=O< z-Ip(4mS2#r?|5mZCS7+QKk3&+KaAvm&{_uCkLAjC)qq0z5;mUHf&ST*ofF`hjuiGv4Z~8yhsE4ta0ep|cEkbx`zSOLm&OgyE%{AIHzQ z9FeDax4Ko+DCi2_*FOX!D`J{nTNrMZJl{<}wk}f1O;jg4=4ee|eAoLzvrJk1uiduv z+_*g3@=F_**mbq^)EdewmBibHXK!3=kh^V1`k~l5S3@7Gy=Pdh@!j>QQSq%oE%k4t z7YE(nywalHYrJH^oJ2kwV;}pjYX4^Z(uAfmwl*v^AhjpV9`%>xUM zDexxdXsW7o-DQF3N)IEs#<=ItpKIIHMJO^GY}wMO29bW^mLJ~IkU2Z_?Ae=kH8iPv z9AB?}(yyn4*~yhQy{-ApCH=i6(N))-n$jdP%X`~P)MAg-Ejjq=zD!%e9;vRX8{tnT zNSEJmmeZ_?$?%pAzcf|KsuB#&M0$JNw!bSR44P1h7Dq 
From: simondanielsson
Date: Sat, 1 Nov 2025 20:01:15 +0100
Subject: [PATCH 19/21] Remove extra trailing comma

Signed-off-by: simondanielsson
---
 vllm/model_executor/models/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 8272ee9ae118..7150977e9266 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -300,7 +300,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             logger.info(
                 "Warning: Prefix caching is currently enabled. "
                 "Its support for Mamba2 layers is experimental. "
-                "Please report any issues you may observe.",
+                "Please report any issues you may observe."
             )
         else:
             logger.info(

From dbb4fe36e69f6b334d4eca492ad06db207b92c1e Mon Sep 17 00:00:00 2001
From: simondanielsson
Date: Sat, 1 Nov 2025 20:05:33 +0100
Subject: [PATCH 20/21] Move hardcoded chunk size to GDN attn metadata builder

Signed-off-by: simondanielsson
---
 vllm/config/model.py                   | 4 ----
 vllm/v1/attention/backends/gdn_attn.py | 4 +++-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/vllm/config/model.py b/vllm/config/model.py
index 31a670164324..9b6df80e529f 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1483,10 +1483,6 @@ def get_mamba_chunk_size(self) -> int | None:
         if chunk_size is None:
             # used by e.g. Mamba2, NemotronH, Zamba
             chunk_size = getattr(self.hf_text_config, "chunk_size", None)
-        if chunk_size is None and self.hf_text_config.model_type == "qwen3_next":
-            # Fallback for Qwen3-Next. 64 is a hardcoded value in the GDN kernel.
-            # https://github.com/fla-org/flash-linear-attention/blob/2e7336262c11f8bc6cd6a94b1eb5ee353ae8b4cd/fla/ops/common/chunk_delta_h.py#L439
-            return 64
 
         return chunk_size
 
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 3acb1088b367..754ef3828c80 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -101,7 +101,9 @@ def __init__(
         self.use_spec_decode = self.num_spec > 0
         self._init_reorder_batch_threshold(1, self.use_spec_decode)
 
-        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+        # 64 is a hardcoded value in the FLA GDN kernel.
+        # https://github.com/fla-org/flash-linear-attention/blob/2e7336262c11f8bc6cd6a94b1eb5ee353ae8b4cd/fla/ops/common/chunk_delta_h.py#L439
+        self.chunk_size = 64
         if self.vllm_config.cache_config.enable_prefix_caching and (
             kv_cache_spec.block_size % self.chunk_size != 0
         ):

From efd451b32e1b7a11ea726d3f9f2ccdf9b14bee6d Mon Sep 17 00:00:00 2001
From: simondanielsson
Date: Sat, 1 Nov 2025 20:06:29 +0100
Subject: [PATCH 21/21] Remove extra newline

Signed-off-by: simondanielsson
---
 vllm/config/model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/config/model.py b/vllm/config/model.py
index 9b6df80e529f..082f90653f5a 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1483,7 +1483,6 @@ def get_mamba_chunk_size(self) -> int | None:
         if chunk_size is None:
             # used by e.g. Mamba2, NemotronH, Zamba
             chunk_size = getattr(self.hf_text_config, "chunk_size", None)
-
         return chunk_size
 
     def get_multimodal_config(self) -> MultiModalConfig:
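Editor's note on PATCH 19, not part of the series: the removed comma is purely stylistic. Python's parser fuses adjacent string literals into a single object, so the logger.info(...) call takes exactly one argument with or without the trailing comma; dropping it only quiets style checkers. A minimal standalone illustration using the stdlib logging module (no vLLM imports assumed):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Adjacent string literals concatenate into ONE string at parse time, so this
# is a single-argument call; a trailing comma after the last literal would not
# change behavior, only the style-checker verdict.
logger.info(
    "Warning: Prefix caching is currently enabled. "
    "Its support for Mamba2 layers is experimental. "
    "Please report any issues you may observe."
)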
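Editor's note on PATCH 20, not part of the series: pinning self.chunk_size = 64 in the metadata builder records the chunk length the FLA GDN kernel is written for, and the guard visible in the hunk means prefix caching is only usable when the KV-cache block size divides evenly into those chunks. A rough sketch of that constraint as standalone Python; validate_block_size and its ValueError are illustrative stand-ins, not the actual vLLM code path:

GDN_CHUNK_SIZE = 64  # fixed by the FLA chunked delta-rule kernel


def validate_block_size(block_size: int, enable_prefix_caching: bool) -> None:
    """Illustrative stand-in for the divisibility guard in the GDN builder."""
    if enable_prefix_caching and block_size % GDN_CHUNK_SIZE != 0:
        raise ValueError(
            f"block_size={block_size} must be a multiple of {GDN_CHUNK_SIZE} "
            "when prefix caching is enabled"
        )


validate_block_size(128, enable_prefix_caching=True)   # passes: 128 % 64 == 0
validate_block_size(112, enable_prefix_caching=False)  # passes: guard skipped

Keeping the constant in the builder, rather than as a qwen3_next special case in ModelConfig.get_mamba_chunk_size(), lets the config method keep returning None for models that genuinely have no chunk size instead of leaking a kernel detail into the config layer.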