1616
1717from typing import Optional , Tuple
1818from flax import nnx
19+ import jax
1920import jax .numpy as jnp
2021from ... import common_types
2122from ..attention_flax import NNXAttentionOp
@@ -347,6 +348,7 @@ def __init__(
347348 attention_kernel : str = "flash" ,
348349 rope_type : str = "interleaved" ,
349350 flash_block_sizes : BlockSizes = None ,
351+ flash_min_seq_length : int = 4096 ,
350352 ):
351353 self .heads = heads
352354 self .rope_type = rope_type
@@ -434,6 +436,7 @@ def __init__(
434436 axis_names_q = (common_types .BATCH , common_types .SELF_ATTN_HEAD , common_types .SELF_ATTN_Q_LENGTH , common_types .D_KV ),
435437 axis_names_kv = (common_types .BATCH , common_types .SELF_ATTN_HEAD , common_types .SELF_ATTN_KV_LENGTH , common_types .D_KV ),
436438 flash_block_sizes = flash_block_sizes ,
439+ flash_min_seq_length = flash_min_seq_length ,
437440 )
438441
439442 def __call__ (
@@ -447,46 +450,49 @@ def __call__(
447450 # Determine context (Self or Cross)
448451 context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
449452
450- # 1. Project
451- query = self .to_q (hidden_states )
452- key = self .to_k (context )
453- value = self .to_v (context )
453+ # 1. Project and Norm
454+ with jax .named_scope ("QKV Projection" ):
455+ query = self .to_q (hidden_states )
456+ key = self .to_k (context )
457+ value = self .to_v (context )
454458
455- # 2. Norm (Full Inner Dimension)
456- query = self .norm_q (query )
457- key = self .norm_k (key )
459+ with jax . named_scope ( "QKV Norm" ):
460+ query = self .norm_q (query )
461+ key = self .norm_k (key )
458462
459463 # 3. Apply RoPE to tensors of shape [B, S, InnerDim]
460464 # Frequencies are shape [B, S, InnerDim]
461465 # 3. Apply RoPE
462- if rotary_emb is not None :
463- if hasattr (self , "rope_type" ) and self .rope_type == "split" :
464- # Split RoPE: passing full freqs [B, H, S, D//2]
465- # apply_split_rotary_emb handles reshaping query/key
466-
467- query = apply_split_rotary_emb (query , rotary_emb )
468-
469- if k_rotary_emb is not None :
470- key = apply_split_rotary_emb (key , k_rotary_emb )
471- elif encoder_hidden_states is None :
472- key = apply_split_rotary_emb (key , rotary_emb )
473-
474- else :
475- # Interleaved (Default)
476- query = apply_rotary_emb (query , rotary_emb )
477- if k_rotary_emb is not None :
478- key = apply_rotary_emb (key , k_rotary_emb )
479- elif encoder_hidden_states is None :
480- key = apply_rotary_emb (key , rotary_emb )
481-
482- # 4. Attention
483- # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
484- attn_output = self .attention_op .apply_attention (query = query , key = key , value = value , attention_mask = attention_mask )
485-
486- # 7. Output Projection
487- hidden_states = self .to_out (attn_output )
488-
489- if self .dropout_layer is not None :
490- hidden_states = self .dropout_layer (hidden_states )
466+ with jax .named_scope ("Apply RoPE" ):
467+ if rotary_emb is not None :
468+ if hasattr (self , "rope_type" ) and self .rope_type == "split" :
469+ # Split RoPE: passing full freqs [B, H, S, D//2]
470+ # apply_split_rotary_emb handles reshaping query/key
471+
472+ query = apply_split_rotary_emb (query , rotary_emb )
473+
474+ if k_rotary_emb is not None :
475+ key = apply_split_rotary_emb (key , k_rotary_emb )
476+ elif encoder_hidden_states is None :
477+ key = apply_split_rotary_emb (key , rotary_emb )
478+
479+ else :
480+ # Interleaved (Default)
481+ query = apply_rotary_emb (query , rotary_emb )
482+ if k_rotary_emb is not None :
483+ key = apply_rotary_emb (key , k_rotary_emb )
484+ elif encoder_hidden_states is None :
485+ key = apply_rotary_emb (key , rotary_emb )
486+
487+ with jax .named_scope ("Attention and Output Project" ):
488+ # 4. Attention
489+ # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
490+ attn_output = self .attention_op .apply_attention (query = query , key = key , value = value , attention_mask = attention_mask )
491+
492+ # 7. Output Projection
493+ hidden_states = self .to_out (attn_output )
494+
495+ if self .dropout_layer is not None :
496+ hidden_states = self .dropout_layer (hidden_states )
491497
492498 return hidden_states
0 commit comments