Skip to content

Commit ceca471

Browse files
Merge pull request #373 from AI-Hypercomputer:prisha/ltx2_fixes
PiperOrigin-RevId: 897831150
2 parents f04046c + c5bb862 commit ceca471

File tree

7 files changed

+261
-179
lines changed

7 files changed

+261
-179
lines changed

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
hardware: 'tpu'
33
skip_jax_distributed_system: False
44
attention: 'flash'
5+
a2v_attention_kernel: 'flash'
6+
v2a_attention_kernel: 'dot_product'
57
attention_sharding_uniform: True
68
precision: 'bf16'
79
scan_layers: True
@@ -68,6 +70,7 @@ flash_block_sizes: {
6870
block_kv_dkv_compute: 2048,
6971
use_fused_bwd_kernel: True,
7072
}
73+
flash_min_seq_length: 4096
7174
dcn_context_parallelism: 1
7275
dcn_tensor_parallelism: 1
7376
ici_data_parallelism: 1

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 42 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Optional, Tuple
1818
from flax import nnx
19+
import jax
1920
import jax.numpy as jnp
2021
from ... import common_types
2122
from ..attention_flax import NNXAttentionOp
@@ -347,6 +348,7 @@ def __init__(
347348
attention_kernel: str = "flash",
348349
rope_type: str = "interleaved",
349350
flash_block_sizes: BlockSizes = None,
351+
flash_min_seq_length: int = 4096,
350352
):
351353
self.heads = heads
352354
self.rope_type = rope_type
@@ -434,6 +436,7 @@ def __init__(
434436
axis_names_q=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV),
435437
axis_names_kv=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV),
436438
flash_block_sizes=flash_block_sizes,
439+
flash_min_seq_length=flash_min_seq_length,
437440
)
438441

439442
def __call__(
@@ -447,46 +450,49 @@ def __call__(
447450
# Determine context (Self or Cross)
448451
context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
449452

450-
# 1. Project
451-
query = self.to_q(hidden_states)
452-
key = self.to_k(context)
453-
value = self.to_v(context)
453+
# 1. Project and Norm
454+
with jax.named_scope("QKV Projection"):
455+
query = self.to_q(hidden_states)
456+
key = self.to_k(context)
457+
value = self.to_v(context)
454458

455-
# 2. Norm (Full Inner Dimension)
456-
query = self.norm_q(query)
457-
key = self.norm_k(key)
459+
with jax.named_scope("QKV Norm"):
460+
query = self.norm_q(query)
461+
key = self.norm_k(key)
458462

459463
# 3. Apply RoPE to tensors of shape [B, S, InnerDim]
460464
# Frequencies are shape [B, S, InnerDim]
461465
# 3. Apply RoPE
462-
if rotary_emb is not None:
463-
if hasattr(self, "rope_type") and self.rope_type == "split":
464-
# Split RoPE: passing full freqs [B, H, S, D//2]
465-
# apply_split_rotary_emb handles reshaping query/key
466-
467-
query = apply_split_rotary_emb(query, rotary_emb)
468-
469-
if k_rotary_emb is not None:
470-
key = apply_split_rotary_emb(key, k_rotary_emb)
471-
elif encoder_hidden_states is None:
472-
key = apply_split_rotary_emb(key, rotary_emb)
473-
474-
else:
475-
# Interleaved (Default)
476-
query = apply_rotary_emb(query, rotary_emb)
477-
if k_rotary_emb is not None:
478-
key = apply_rotary_emb(key, k_rotary_emb)
479-
elif encoder_hidden_states is None:
480-
key = apply_rotary_emb(key, rotary_emb)
481-
482-
# 4. Attention
483-
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
484-
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
485-
486-
# 7. Output Projection
487-
hidden_states = self.to_out(attn_output)
488-
489-
if self.dropout_layer is not None:
490-
hidden_states = self.dropout_layer(hidden_states)
466+
with jax.named_scope("Apply RoPE"):
467+
if rotary_emb is not None:
468+
if hasattr(self, "rope_type") and self.rope_type == "split":
469+
# Split RoPE: passing full freqs [B, H, S, D//2]
470+
# apply_split_rotary_emb handles reshaping query/key
471+
472+
query = apply_split_rotary_emb(query, rotary_emb)
473+
474+
if k_rotary_emb is not None:
475+
key = apply_split_rotary_emb(key, k_rotary_emb)
476+
elif encoder_hidden_states is None:
477+
key = apply_split_rotary_emb(key, rotary_emb)
478+
479+
else:
480+
# Interleaved (Default)
481+
query = apply_rotary_emb(query, rotary_emb)
482+
if k_rotary_emb is not None:
483+
key = apply_rotary_emb(key, k_rotary_emb)
484+
elif encoder_hidden_states is None:
485+
key = apply_rotary_emb(key, rotary_emb)
486+
487+
with jax.named_scope("Attention and Output Project"):
488+
# 4. Attention
489+
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
490+
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
491+
492+
# 7. Output Projection
493+
hidden_states = self.to_out(attn_output)
494+
495+
if self.dropout_layer is not None:
496+
hidden_states = self.dropout_layer(hidden_states)
491497

492498
return hidden_states

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -108,11 +108,12 @@ def __call__(
108108
Returns:
109109
(video_embeds, audio_embeds, new_attention_mask)
110110
"""
111-
# 1. Shared Feature Extraction
112-
features = self.feature_extractor(hidden_states, attention_mask)
111+
with jax.named_scope("Text Encoder Forward"):
112+
# 1. Shared Feature Extraction
113+
features = self.feature_extractor(hidden_states, attention_mask)
113114

114-
# 2. Parallel Connection
115-
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
116-
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
115+
# 2. Parallel Connection
116+
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
117+
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
117118

118-
return video_embeds, audio_embeds, new_attention_mask
119+
return video_embeds, audio_embeds, new_attention_mask

0 commit comments

Comments
 (0)