
KDA Backward pass optimizations #14

@icavan

Description


Optimize the backward pass kernels for all supported linear attention variants to improve training throughput.

Tasks

  • Profile backward pass performance and identify bottlenecks
  • Implement the CUDA KDA backward subchunk-intra and wy_dqkg kernels
  • Benchmark against the FLA Triton backward pass
  • Validate gradient correctness
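For the gradient-correctness task, one common approach is to compare the analytic backward pass against central finite differences on small random inputs. The sketch below does this for a toy unnormalized linear-attention step, O = (Q Kᵀ) V; the function names and shapes are illustrative, not part of the actual KDA kernels.

```python
import numpy as np

def linear_attn_fwd(q, k, v):
    # Unnormalized linear attention for a single chunk: O = (Q K^T) V
    return (q @ k.T) @ v

def linear_attn_bwd(q, k, v, do):
    # Analytic gradients of O = (Q K^T) V given upstream gradient dO
    dq = do @ v.T @ k          # dL/dQ
    dk = v @ do.T @ q          # dL/dK
    dv = (q @ k.T).T @ do      # dL/dV
    return dq, dk, dv

def finite_diff_grad(f, x, eps=1e-5):
    # Central finite differences of scalar loss f() w.r.t. array x (in place)
    g = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    for _ in it:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + eps
        fp = f()
        x[idx] = orig - eps
        fm = f()
        x[idx] = orig
        g[idx] = (fp - fm) / (2 * eps)
    return g

rng = np.random.default_rng(0)
T, d = 4, 3
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
do = np.ones((T, d))  # upstream gradient for loss L = sum(O)

dq, dk, dv = linear_attn_bwd(q, k, v, do)
loss = lambda: linear_attn_fwd(q, k, v).sum()
assert np.allclose(dq, finite_diff_grad(loss, q), atol=1e-4)
assert np.allclose(dk, finite_diff_grad(loss, k), atol=1e-4)
assert np.allclose(dv, finite_diff_grad(loss, v), atol=1e-4)
```

For the real kernels, the same pattern applies with the FLA Triton backward (or `torch.autograd.gradcheck` on a reference implementation) as the baseline instead of finite differences.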
