diff --git a/roboreg/util/transform.py b/roboreg/util/transform.py index 1384fa9..4362c5c 100644 --- a/roboreg/util/transform.py +++ b/roboreg/util/transform.py @@ -64,7 +64,7 @@ def generate_ht_optical( ) -> torch.Tensor: ht_optical = torch.zeros(4, 4, dtype=dtype, device=device) if batch_size is not None: - ht_optical = ht_optical.unsqueeze(0).expand(batch_size, -1, -1) + ht_optical = ht_optical.unsqueeze(0).repeat(batch_size, 1, 1) ht_optical[..., 0, 2] = ( 1.0 # OpenCV-oriented optical frame, in quaternions: [0.5, -0.5, 0.5, -0.5] (w, x, y, z) )