
[Relax] Fix wrong memory planning when only lower bound was provided#18663

Merged
tlopex merged 1 commit into apache:main from mshr-h:fix-memory-planning-issue
Jan 16, 2026

Conversation

@mshr-h (Contributor) commented on Jan 15, 2026

This PR fixes an issue in StaticPlanBlockMemory where shapes involving dynamic TIR variables were incorrectly planned as static memory when only a lower bound was provided for those variables.

Repro:

repro_dynamic_memory_plan.py
import tvm
from tvm import relax, testing
from tvm.relax.frontend.torch import from_exported_program
from torch.export import Dim, export
import torch


class SimpleConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


def main():
    model = SimpleConv().eval()

    example = torch.randn(2, 3, 32, 32)
    batch = Dim("batch")  # No max= specified, so upper bound is unknown
    exported = export(model, (example,), dynamic_shapes={"x": {0: batch}})

    mod = from_exported_program(exported)
    mod = relax.transform.DecomposeOpsForInference()(mod)

    target = tvm.target.Target("llvm")
    exe = tvm.compile(mod, target=target)

    vm = relax.VirtualMachine(exe, tvm.cpu())
    inp = tvm.runtime.from_dlpack(example)
    out = vm["main"](inp)

    expected = model(example).detach().numpy()
    actual = out[0].numpy()
    testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    main()

This will fail with the following error.

output
$ uv run python repro_dynamic_memory_plan.py 
/home/ubuntu/data/project/tvm-example/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:182: UserWarning: CUDA initialization: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:119.)
  return torch._C._cuda_getDeviceCount() > 0
Traceback (most recent call last):
  File "/home/ubuntu/data/project/tvm-example/frontend/repro_dynamic_memory_plan.py", line 40, in <module>
    main()
  File "/home/ubuntu/data/project/tvm-example/frontend/repro_dynamic_memory_plan.py", line 32, in main
    out = vm["main"](inp)
          ^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 549, in tvm::runtime::vm::VirtualMachineImpl::InvokeClosurePacked(tvm::ffi::ObjectRef const&, tvm::ffi::PackedArgs, tvm::ffi::Any*)
    clo->impl.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv);

  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 622, in operator()
    *rv = static_cast<VirtualMachineImpl*>(ctx_ptr)->InvokeBytecode(gf_idx, inputs);

  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 693, in tvm::runtime::vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::ffi::Any, std::allocator<tvm::ffi::Any> > const&)
    RunLoop();
  
  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 816, in tvm::runtime::vm::VirtualMachineImpl::RunLoop()
    this->RunInstrCall(curr_frame, instr);

  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 767, in tvm::runtime::vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::vm::VMFrame*, tvm::runtime::vm::Instruction)
    this->InvokeClosurePacked(func_pool_[instr.func_idx].cast<ObjectRef>(), args, &ret);

  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/builtin.cc", line 405, in operator()
    *rv = sobj->AllocTensor(offset, shape, dtype);

  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/memory/memory_manager.cc", line 98, in tvm::runtime::memory::StorageObj::AllocTensor(long, tvm::ffi::Shape, DLDataType)
    ICHECK(offset + needed_size <= this->buffer.size)
  
  File "/home/ubuntu/data/project/tvm-example/tvm/include/tvm/runtime/logging.h", line 321, in tvm::runtime::detail::LogFatal::~LogFatal()
    GetEntry().Finalize();

  File "/home/ubuntu/data/project/tvm-example/tvm/include/tvm/runtime/logging.h", line 337, in tvm::runtime::detail::LogFatal::Entry::Finalize()
    InternalError error(file_, lineno_, stream_.str());

tvm.error.InternalError: Check failed: (offset + needed_size <= this->buffer.size) is false: storage allocation failure, attempted to allocate 524288 at offset 0 in region that is 262144bytes
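The sizes in the check failure are consistent with the planner having bound `batch` to its lower bound of 1: the conv output has shape `(batch, 64, 32, 32)` in float32, so batch=1 gives the 262144-byte region that was planned, while the actual batch=2 call needs 524288 bytes. A small illustrative calculation (my own interpretation of the numbers, not code from the PR):

```python
# Illustrative arithmetic only: the 262144 / 524288 figures in the error
# match a float32 NCHW tensor of shape (batch, 64, 32, 32).
def conv_output_bytes(batch, channels=64, height=32, width=32, itemsize=4):
    """Bytes needed for a float32 output tensor of shape (batch, C, H, W)."""
    return batch * channels * height * width * itemsize

# Planned region: batch mistakenly treated as its lower bound (1).
planned = conv_output_bytes(1)   # 262144 bytes
# Runtime request with the actual batch of 2.
needed = conv_output_bytes(2)    # 524288 bytes
print(planned, needed)
```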

mshr-h force-pushed the fix-memory-planning-issue branch from 7666f8b to 31898c6 on January 15, 2026 at 12:22
@gemini-code-assist bot commented

Summary of Changes

Hello @mshr-h, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request resolves an issue in the StaticPlanBlockMemory transformation where dynamic tensor shapes, characterized by TIR variables with only a lower bound, were erroneously treated as having a static memory footprint. The fix refines the logic to ensure that static memory allocation is only attempted when a definitive upper bound for a TIR variable is known, thereby preventing potential memory miscalculations and allowing for proper dynamic memory management in other scenarios.

Highlights

  • Corrected Memory Planning Logic: The StaticPlanBlockMemory pass now correctly handles dynamic shapes by only binding TIR variables to a static range if an explicit upper bound is provided. Previously, the presence of only a lower bound could lead to incorrect static memory planning.
  • Improved TIR Variable Handling: When only a lower bound (and it's non-negative) is available for a TIR variable, the variable is now marked as non-negative instead of being bound to an unbounded range. This ensures that memory planning remains dynamic for such cases.
  • Enhanced Test Coverage: New test cases have been added to specifically verify the behavior when only a lower bound is provided (ensuring dynamic planning) and when both upper and lower bounds are present (ensuring static planning).
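The decision rule described in the highlights can be sketched in plain Python. This is a hypothetical illustration of the logic, not the actual C++ implementation inside StaticPlanBlockMemory; the helper name `plan_tir_var` and the string labels are my own:

```python
from typing import Optional, Tuple

def plan_tir_var(lower: Optional[int], upper: Optional[int]) -> Tuple[str, Optional[int]]:
    """Sketch of the fixed bound-handling rule (illustrative only).

    Returns ("static", bound) only when an upper bound is known; otherwise
    the variable stays dynamic, optionally marked non-negative.
    """
    if upper is not None:
        # An upper bound is available: safe to bind the variable to a range
        # and plan the allocation statically against that bound.
        return ("static", upper)
    if lower is not None and lower >= 0:
        # Only a non-negative lower bound: do NOT bind to an unbounded range.
        # Mark the variable non-negative and keep the allocation dynamic.
        return ("dynamic-nonnegative", None)
    # No usable bound information: fully dynamic.
    return ("dynamic", None)
```

Before the fix, the lower-bound-only case effectively fell into the "static" branch, which is what produced the undersized allocation in the repro above.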


@gemini-code-assist bot left a comment

Code Review

This pull request correctly fixes an issue in StaticPlanBlockMemory where dynamic shapes were incorrectly planned as static memory when only a lower bound was provided. The change in SetTIRVarRangeConstraints ensures that memory is planned statically only when an upper bound is available. The accompanying tests are comprehensive, covering cases with only a lower bound, both upper and lower bounds, and invalid bounds. I have one suggestion to improve the maintainability of the newly added tests by reducing code duplication.

mshr-h marked this pull request as ready for review on January 15, 2026 at 12:36
mshr-h force-pushed the fix-memory-planning-issue branch from 31898c6 to 6bf0dca on January 15, 2026 at 12:39
mshr-h changed the title from "[Relax] Fix wrong memory planning when only lower bound provided" to "[Relax] Fix wrong memory planning when only lower bound was provided" on Jan 15, 2026
@guan404ming (Member) left a comment

LGTM, thanks!

tlopex merged commit c866abc into apache:main on Jan 16, 2026
15 checks passed
mshr-h deleted the fix-memory-planning-issue branch on January 25, 2026