[Relax][PyTorch] Add support for bidirectional LSTM #18516
mshr-h merged 3 commits into apache:main from
Conversation
cc @tlopex

/gemini review
Code Review
This pull request adds support for bidirectional LSTMs in the PyTorch frontend, which is a great enhancement. The implementation involves refactoring the LSTM logic into a new _lstm_cell_unroll helper function and updating the _lstm operator to handle both forward and backward passes. The accompanying tests have also been significantly improved by adding a numerical verification helper and expanding test coverage to include various LSTM configurations. I've identified a critical regression where the fallback for missing LSTM parameters was removed, which could lead to crashes. I've also suggested adding a warning for when a default hidden_size is used, to prevent silent errors. Overall, this is a solid contribution with good testing practices.
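For context on what the refactored `_lstm_cell_unroll` helper computes, the bidirectional unrolling can be sketched in plain NumPy (the actual converter emits Relax ops; shapes and names here are illustrative, following PyTorch's `(4*H, input)` / `(4*H, H)` gate-packed weight layout): run the cell forward over time, run it again over the reversed sequence, then concatenate the two hidden-state streams along the feature axis.

```python
import numpy as np

def lstm_cell(x, h, c, w_ih, w_hh, b):
    # Gates packed as [i, f, g, o] in PyTorch order.
    gates = x @ w_ih.T + h @ w_hh.T + b
    i, f, g, o = np.split(gates, 4, axis=-1)
    sig = lambda z: 1.0 / (1.0 + np.exp(-z))
    c_new = sig(f) * c + sig(i) * np.tanh(g)
    h_new = sig(o) * np.tanh(c_new)
    return h_new, c_new

def lstm_unroll(xs, w_ih, w_hh, b, hidden_size):
    # Unroll one direction over the time axis, collecting hidden states.
    h = np.zeros(hidden_size)
    c = np.zeros(hidden_size)
    outs = []
    for x in xs:
        h, c = lstm_cell(x, h, c, w_ih, w_hh, b)
        outs.append(h)
    return np.stack(outs)

rng = np.random.default_rng(0)
T, I, H = 5, 3, 4
xs = rng.standard_normal((T, I))
params_fwd = (rng.standard_normal((4 * H, I)),
              rng.standard_normal((4 * H, H)),
              rng.standard_normal(4 * H))
params_bwd = (rng.standard_normal((4 * H, I)),
              rng.standard_normal((4 * H, H)),
              rng.standard_normal(4 * H))

out_fwd = lstm_unroll(xs, *params_fwd, H)
# Backward pass: reverse the input in time, then reverse the output back.
out_bwd = lstm_unroll(xs[::-1], *params_bwd, H)[::-1]
out = np.concatenate([out_fwd, out_bwd], axis=-1)  # shape (T, 2*H)
```

This mirrors why a bidirectional LSTM's output feature dimension is `2 * hidden_size`.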
```diff
  weight_ih_fwd = params[0] if params else None
  weight_hh_fwd = params[1] if params and len(params) > 1 else None
  bias_ih_fwd = params[2] if params and has_biases and len(params) > 2 else None
  bias_hh_fwd = params[3] if params and has_biases and len(params) > 3 else None

  if bidirectional and params and len(params) >= params_per_direction * 2:
      weight_ih_bwd = params[params_per_direction]
      weight_hh_bwd = params[params_per_direction + 1]
      bias_ih_bwd = params[params_per_direction + 2] if has_biases else None
      bias_hh_bwd = params[params_per_direction + 3] if has_biases else None
  else:
-     # Fallback: create zero weights
-     weight_ih = self.block_builder.emit(
-         relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
-     )
-     weight_hh = self.block_builder.emit(
-         relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
-     )
-     bias_ih = None
-     bias_hh = None
-     # Initialize hidden and cell states
+     weight_ih_bwd = None
+     weight_hh_bwd = None
+     bias_ih_bwd = None
+     bias_hh_bwd = None
```
This change removes the fallback logic for creating zero-initialized weights when LSTM parameters are not provided. The new implementation assigns None to weight variables, which will cause a crash inside _lstm_cell_unroll when relax.op.permute_dims is called on a None value. This appears to be a regression from the previous behavior.
Please consider restoring the fallback logic to create zero weights for both forward and backward directions if they are not available in params.
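A minimal sketch of the restored fallback might look like the following. The real code lives inside the converter and emits `relax.op.zeros` through `self.block_builder`, so this standalone version substitutes plain NumPy zeros; the helper name `get_lstm_params` and its exact signature are hypothetical, introduced only for illustration.

```python
import numpy as np

def get_lstm_params(params, has_biases, params_per_direction,
                    hidden_size, input_size, bidirectional, dtype="float32"):
    """Extract per-direction LSTM weights, falling back to zeros when absent."""
    def zeros_fallback():
        # Zero weights keep downstream matmuls well-formed instead of
        # crashing on a None weight.
        w_ih = np.zeros((4 * hidden_size, input_size), dtype)
        w_hh = np.zeros((4 * hidden_size, hidden_size), dtype)
        return w_ih, w_hh, None, None

    if params and len(params) >= params_per_direction:
        fwd = (params[0], params[1],
               params[2] if has_biases else None,
               params[3] if has_biases else None)
    else:
        fwd = zeros_fallback()

    if not bidirectional:
        return fwd, None

    if params and len(params) >= params_per_direction * 2:
        off = params_per_direction
        bwd = (params[off], params[off + 1],
               params[off + 2] if has_biases else None,
               params[off + 3] if has_biases else None)
    else:
        # Same fallback for the backward direction, per the review comment.
        bwd = zeros_fallback()
    return fwd, bwd
```

The point is simply that both branches of the `if bidirectional` check, and the forward path as well, must produce usable tensors rather than `None`.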
```python
else:
    # Fallback to a default hidden size
    hidden_size = 16
```
The code falls back to a default hidden_size of 16 when it cannot be inferred from the model parameters. This could lead to unexpected behavior or errors if the actual model has a different hidden size. It would be beneficial to add a warning to notify the user about this fallback, so they are aware of the potential discrepancy.
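The warning the reviewer suggests could be sketched as below. The function name `infer_hidden_size` and the constant are hypothetical stand-ins for the converter's actual inference path; the idea is just to attach a `warnings.warn` to the fallback branch so the discrepancy is visible to the user.

```python
import warnings

DEFAULT_HIDDEN_SIZE = 16  # mirrors the hard-coded fallback in the diff

def infer_hidden_size(weight_hh=None):
    """Infer hidden_size from weight_hh (shape (4*H, H)); warn on fallback."""
    if weight_hh is not None:
        return weight_hh.shape[1]
    warnings.warn(
        "Could not infer hidden_size from LSTM parameters; "
        f"falling back to {DEFAULT_HIDDEN_SIZE}. Results may be incorrect "
        "if the model uses a different hidden size.",
        UserWarning,
    )
    return DEFAULT_HIDDEN_SIZE
```

With this in place the silent mismatch becomes an explicit, filterable `UserWarning`.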
Thanks for your suggestions; I've applied the review feedback and updated the PR.
Thanks!