
[Relax][PyTorch] Add support for bidirectional LSTM #18516

Merged
mshr-h merged 3 commits into apache:main from guan404ming:bidirectional-lstm
Nov 28, 2025

Conversation


@guan404ming (Member) commented Nov 27, 2025

How

  • Implement bidirectional LSTM

@guan404ming marked this pull request as ready for review November 27, 2025 08:37
@guan404ming (Member Author)

cc @tlopex


@mshr-h (Contributor) commented Nov 27, 2025

/gemini review


@gemini-code-assist (bot) left a comment


Code Review

This pull request adds support for bidirectional LSTMs in the PyTorch frontend, which is a great enhancement. The implementation involves refactoring the LSTM logic into a new _lstm_cell_unroll helper function and updating the _lstm operator to handle both forward and backward passes. The accompanying tests have also been significantly improved by adding a numerical verification helper and expanding test coverage to include various LSTM configurations. I've identified a critical regression where the fallback for missing LSTM parameters was removed, which could lead to crashes. I've also suggested adding a warning for when a default hidden_size is used, to prevent silent errors. Overall, this is a solid contribution with good testing practices.
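The refactor described above boils down to running one cell unroll over the sequence in order, a second unroll over the reversed sequence, and concatenating the per-step hidden states. A plain-Python sketch of that idea (a simplified illustration only, not the actual Relax implementation; `bidirectional_unroll` and the toy cell below are hypothetical names):

```python
def bidirectional_unroll(xs, cell_fwd, cell_bwd, h0, c0):
    """Unroll a recurrent cell in both directions over a sequence.

    Each cell maps (x, h, c) -> (h', c'); hidden states are plain lists,
    so concatenation is list addition.
    """
    def unroll(cell, seq, h, c):
        outs = []
        for x in seq:
            h, c = cell(x, h, c)
            outs.append(h)
        return outs

    fwd = unroll(cell_fwd, xs, h0, c0)
    bwd = unroll(cell_bwd, list(reversed(xs)), h0, c0)
    bwd.reverse()  # re-align backward outputs with forward time steps
    # Per-step concatenation of forward and backward hidden states
    return [f + b for f, b in zip(fwd, bwd)]
```

With an identity-like toy cell, `bidirectional_unroll([1, 2, 3], cell, cell, [0], 0)` pairs each step's forward output with the time-aligned backward output.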

Comment on lines +467 to +481
weight_ih_fwd = params[0] if params else None
weight_hh_fwd = params[1] if params and len(params) > 1 else None
bias_ih_fwd = params[2] if params and has_biases and len(params) > 2 else None
bias_hh_fwd = params[3] if params and has_biases and len(params) > 3 else None

if bidirectional and params and len(params) >= params_per_direction * 2:
    weight_ih_bwd = params[params_per_direction]
    weight_hh_bwd = params[params_per_direction + 1]
    bias_ih_bwd = params[params_per_direction + 2] if has_biases else None
    bias_hh_bwd = params[params_per_direction + 3] if has_biases else None
else:
    # Fallback: create zero weights
    weight_ih = self.block_builder.emit(
        relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
    )
    weight_hh = self.block_builder.emit(
        relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
    )
    bias_ih = None
    bias_hh = None
    # Initialize hidden and cell states
    weight_ih_bwd = None
    weight_hh_bwd = None
    bias_ih_bwd = None
    bias_hh_bwd = None

Severity: critical

This change removes the fallback logic for creating zero-initialized weights when LSTM parameters are not provided. The new implementation assigns None to weight variables, which will cause a crash inside _lstm_cell_unroll when relax.op.permute_dims is called on a None value. This appears to be a regression from the previous behavior.

Please consider restoring the fallback logic to create zero weights for both forward and backward directions if they are not available in params.
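The suggested fallback can be sketched in plain Python (a hypothetical helper for illustration; `unpack_lstm_params` is not the PR's actual function, and `zeros` stands in for the real `relax.op.zeros` call):

```python
def unpack_lstm_params(params, has_biases, bidirectional, hidden_size, input_size):
    """Unpack LSTM weights per direction, falling back to zero weights
    when they are missing so downstream matmuls never receive None."""
    params_per_direction = 4 if has_biases else 2

    def zeros(shape):
        # Stand-in for relax.op.zeros: a nested list of zeros.
        rows, cols = shape
        return [[0.0] * cols for _ in range(rows)]

    def direction(offset):
        if params and len(params) >= offset + params_per_direction:
            w_ih = params[offset]
            w_hh = params[offset + 1]
            b_ih = params[offset + 2] if has_biases else None
            b_hh = params[offset + 3] if has_biases else None
        else:
            # Fallback: zero weights with the expected LSTM shapes,
            # (4 * hidden_size, input_size) and (4 * hidden_size, hidden_size).
            w_ih = zeros((4 * hidden_size, input_size))
            w_hh = zeros((4 * hidden_size, hidden_size))
            b_ih = b_hh = None
        return w_ih, w_hh, b_ih, b_hh

    fwd = direction(0)
    bwd = direction(params_per_direction) if bidirectional else (None,) * 4
    return fwd, bwd
```

The key point is that both the forward and backward branches go through the same fallback, so a missing `params` list yields correctly shaped zero tensors rather than `None`.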

Comment on lines 453 to 455
else:
    # Fallback to a default hidden size
    hidden_size = 16

Severity: medium

The code falls back to a default hidden_size of 16 when it cannot be inferred from the model parameters. This could lead to unexpected behavior or errors if the actual model has a different hidden size. It would be beneficial to add a warning to notify the user about this fallback, so they are aware of the potential discrepancy.
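One way to surface such a fallback is a logged warning, sketched below (a hypothetical `infer_hidden_size` helper operating on plain nested lists for illustration; the real frontend would read shapes from Relax expressions):

```python
import logging

logger = logging.getLogger(__name__)

def infer_hidden_size(params, default=16):
    """Infer hidden_size from weight_ih, warning when the default is used.

    weight_ih (params[0]) has shape (4 * hidden_size, input_size), so
    hidden_size is recoverable from its row count.
    """
    if params:
        return len(params[0]) // 4
    logger.warning(
        "Could not infer LSTM hidden_size from model parameters; "
        "falling back to default of %d", default
    )
    return default
```

The warning makes a silent shape mismatch visible when the actual model's hidden size differs from the default.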

@guan404ming requested a review from mshr-h November 28, 2025 02:49

@guan404ming (Member Author)

Thanks for your suggestions! I've applied the review feedback and updated the PR.


@mshr-h left a comment


LGTM. Thanks!

@guan404ming (Member Author)

Thanks!

@mshr-h merged commit 1c77db7 into apache:main Nov 28, 2025
13 checks passed
@guan404ming deleted the bidirectional-lstm branch November 28, 2025 10:01
