Commit 03fd809

lingvo-bot authored and copybara-github committed Jan 21, 2025
Make _GetSourceAndQuerySegmentIds return a correctly shaped zero tensor for query_segment_id when inputs.query_vec has a different shape.
PiperOrigin-RevId: 718045538
1 parent a4fbd33 commit 03fd809

1 file changed: +8 -4 lines changed


lingvo/core/attention.py (+8 -4)
@@ -492,8 +492,8 @@ def _GetSourceAndQuerySegmentIds(
     Returns:
       A tuple of 2 elements.
 
-      - The source segment id tensor: [time, batch_size].
-      - The query segment id tensor: [batch_size].
+      - The source segment id tensor: [time, source_batch].
+      - The query segment id tensor: [target_batch].
     """
     p = self.params
     if p.packed_input:
@@ -518,9 +518,13 @@ def _GetSourceAndQuerySegmentIds(
             ' a default all-zero value instead, packed_input will be'
             ' ineffective.'
         )
-      if source_padding is not None:
+      if source_padding is not None and inputs.query_vec is not None:
+        # query_vec.shape could be different from [target_batch, query_dim]
+        # because of potential reshape,e.g. reshaped to
+        # [1, target_batch/source_batch, source_batch, hidden_dims].
+        target_batch = inputs.query_vec.shape.num_elements() // p.hidden_dim
         query_segment_id = tf.zeros(
-            tf.shape(inputs.query_vec)[0], dtype=source_padding.dtype
+            [target_batch], dtype=source_padding.dtype
         )
       else:
         query_segment_id = None
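
The shape arithmetic in the second hunk can be illustrated without TensorFlow. The sketch below is not Lingvo code: `infer_target_batch` and `zero_query_segment_id` are hypothetical helpers, shapes are plain tuples, and a list of zeros stands in for `tf.zeros`. It assumes the trailing dimension of `query_vec` equals `p.hidden_dim`, as in the reshape example given in the diff's comment.

```python
import math


def infer_target_batch(query_vec_shape, hidden_dim):
    """Hypothetical helper mirroring `num_elements() // hidden_dim` on a static shape."""
    if any(dim is None for dim in query_vec_shape):
        # Assumption of this sketch: an unknown dimension makes the count undefined.
        raise ValueError("query_vec shape must be fully defined")
    num_elements = math.prod(query_vec_shape)
    if num_elements % hidden_dim != 0:
        raise ValueError("element count is not a multiple of hidden_dim")
    return num_elements // hidden_dim


def zero_query_segment_id(query_vec_shape, source_padding, hidden_dim):
    """Stand-in for the patched branch: zeros of length target_batch, or None."""
    if source_padding is None or query_vec_shape is None:
        return None
    return [0] * infer_target_batch(query_vec_shape, hidden_dim)


hidden_dim = 8
# Layout from the diff comment:
# [1, target_batch/source_batch, source_batch, hidden_dim]
reshaped = (1, 3, 2, hidden_dim)

# Length the old code would use: the first dimension of query_vec.
old_length = reshaped[0]
# Length the patched code would use: total elements divided by hidden_dim.
new_length = infer_target_batch(reshaped, hidden_dim)
```

For this reshaped layout the first-dimension rule gives a length of 1, while the element-count rule gives 6, which matches the `[target_batch]` shape stated in the updated docstring. For an unreshaped `(6, 8)` input both rules agree. Note that `TensorShape.num_elements()` is documented to return `None` for shapes that are not fully defined, so the patched line appears to rely on `query_vec` having a static shape; the sketch raises an error in that case instead.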
