diff --git a/textless/data/cpc_feature_reader.py b/textless/data/cpc_feature_reader.py index d19abea..e39483b 100644 --- a/textless/data/cpc_feature_reader.py +++ b/textless/data/cpc_feature_reader.py @@ -47,7 +47,8 @@ def get_features(self, x: torch.Tensor) -> torch.Tensor: start += self.max_chunk if start < size: - x_chunk = x[:, -self.max_chunk :] + x_chunk = x[..., -self.max_chunk :] # dimension wrong, \ + # compare with above feat_chunk = self.model.extract_features( source=x_chunk, get_encoded=self.use_encoder_layer,