File tree 1 file changed +8
-4
lines changed
1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -492,8 +492,8 @@ def _GetSourceAndQuerySegmentIds(
492
492
Returns:
493
493
A tuple of 2 elements.
494
494
495
- - The source segment id tensor: [time, batch_size ].
496
- - The query segment id tensor: [batch_size ].
495
+ - The source segment id tensor: [time, source_batch ].
496
+ - The query segment id tensor: [target_batch ].
497
497
"""
498
498
p = self .params
499
499
if p .packed_input :
@@ -518,9 +518,13 @@ def _GetSourceAndQuerySegmentIds(
518
518
' a default all-zero value instead, packed_input will be'
519
519
' ineffective.'
520
520
)
521
- if source_padding is not None :
521
+ if source_padding is not None and inputs .query_vec is not None :
522
+ # query_vec.shape could be different from [target_batch, query_dim]
523
+ # because of potential reshape,e.g. reshaped to
524
+ # [1, target_batch/source_batch, source_batch, hidden_dims].
525
+ target_batch = inputs .query_vec .shape .num_elements () // p .hidden_dim
522
526
query_segment_id = tf .zeros (
523
- tf . shape ( inputs . query_vec )[ 0 ], dtype = source_padding .dtype
527
+ [ target_batch ], dtype = source_padding .dtype
524
528
)
525
529
else :
526
530
query_segment_id = None
You can’t perform that action at this time.
0 commit comments