Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions abspy/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,15 +189,14 @@ def get_points(self, row=1):
as_float: (n, 3) float
Point cloud
"""
pc_lines = []

for line in self.vgroup_ascii[row:]:
end_row = len(self.vgroup_ascii)
for i, line in enumerate(self.vgroup_ascii[row:], start=row):
# stop reading when the 'num_colors' keyword is found
if 'num_colors' in line:
end_row = i
break
pc_lines.append(line)

pc = np.fromstring(' '.join(pc_lines), sep=' ')
pc = np.fromstring(' '.join(self.vgroup_ascii[row:end_row]), sep=' ')
return np.reshape(pc, (-1, 3))

def get_primitives(self):
Expand All @@ -217,11 +216,16 @@ def get_primitives(self):
ungrouped_points: (u, 3) float
Points that belong to no group
"""
is_primitive = [line.startswith('group_num_point') for line in self.vgroup_ascii]
is_parameter = [line.startswith('group_parameters') for line in self.vgroup_ascii]
primitive_line_indices = []
parameter_line_indices = []
for i, line in enumerate(self.vgroup_ascii):
if line.startswith('group_num_point'):
primitive_line_indices.append(i + 1)
elif line.startswith('group_parameters'):
parameter_line_indices.append(i)

primitives = [self.vgroup_ascii[line] for line in np.where(is_primitive)[0] + 1] # lines of groups in the file
parameters = [self.vgroup_ascii[line] for line in np.where(is_parameter)[0]]
primitives = [self.vgroup_ascii[i] for i in primitive_line_indices] # lines of groups in the file
parameters = [self.vgroup_ascii[i] for i in parameter_line_indices]

# remove global group if there is one
if self.global_group:
Expand All @@ -245,7 +249,7 @@ def get_primitives(self):
# empty group -> global bounds and no refit
if self.refit:
logger.warning('refit skipped for empty group')
param = np.array([float(j) for j in parameters[i][18:-1].split()])
param = self._parse_group_parameters(parameters[i])
aabb = self._points_bound(self.points)
obb = aabb

Expand All @@ -254,7 +258,7 @@ def get_primitives(self):
if self.refit:
param, obb = self.fit_plane(points, mode='PCA')
else:
param = np.array([float(j) for j in parameters[i][18:-1].split()])
param = self._parse_group_parameters(parameters[i])
_, obb = self.fit_plane(points, mode='PCA')
aabb = self._points_bound(points)

Expand Down Expand Up @@ -288,6 +292,17 @@ def _points_bound(points):
"""
return np.array([np.amin(points, axis=0), np.amax(points, axis=0)])

@staticmethod
def _parse_group_parameters(line):
"""
Parse group_parameters line from vg/bvg with or without trailing newline.
"""
if line.startswith('group_parameters:'):
content = line[len('group_parameters:'):]
else:
content = line
return np.fromstring(content.strip(), sep=' ')

def normalise_from_centroid_and_scale(self, centroid, scale, num=None):
"""
Normalising points.
Expand Down Expand Up @@ -380,6 +395,8 @@ def fit_plane(points, mode='PCA'):
"""
assert mode == 'PCA' or mode == 'LSA'

points = np.asarray(points)

if len(points) < 3:
logger.warning('plane fitting skipped given #points={}'.format(len(points)))
return None
Expand All @@ -396,14 +413,15 @@ def fit_plane(points, mode='PCA'):
pca = PCA(n_components=3)
pca.fit(points)
eig_vec = pca.components_
points_trans = pca.transform(points)
# equivalent to pca.transform(points) but avoids repeated input validation
points_trans = (points - pca.mean_) @ eig_vec.T
point_min = np.amin(points_trans, axis=0)
Comment on lines +416 to 418
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fit_plane now computes points_trans with points @ eig_vec.T ..., which requires points to be a NumPy array supporting @. Previously pca.transform(points) accepted any array-like input (lists, tuples) via scikit-learn validation. Since fit_plane is a documented public API, consider converting points = np.asarray(points) (or falling back to pca.transform) to avoid an unintended API regression.

Copilot uses AI. Check for mistakes.
point_max = np.amax(points_trans, axis=0)
obb = np.array([[point_min[0], point_min[1], 0], [point_min[0], point_max[1], 0],
[point_max[0], point_max[1], 0], [point_max[0], point_min[1], 0]])
obb = pca.inverse_transform(obb)

logger.debug('explained_variance_ratio: {}'.format(pca.explained_variance_ratio_))
logger.debug('explained_variance_ratio: %s', pca.explained_variance_ratio_)

# normal vector of minimum variance
normal = eig_vec[2, :] # (a, b, c) normalized
Expand Down
37 changes: 37 additions & 0 deletions tests/test_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,22 @@ def test_fit_plane_with_lsa():
assert abs(param[3] + 1) < 0.1 # d should be close to -1


def test_fit_plane_accepts_array_like_input():
"""Test fit_plane with plain Python list input."""
points = [
[0, 0, 1],
[1, 0, 1],
[0, 1, 1],
[1, 1, 1]
]

param, obb = VertexGroup.fit_plane(points)
assert param is not None
assert obb is not None
assert abs(param[2]) > 0.9
assert abs(param[3] + 1) < 0.1


def test_fit_plane_with_few_points():
"""Test fit_plane method with too few points."""
# Only 2 points - not enough to fit a plane
Expand All @@ -224,6 +240,27 @@ def test_fit_plane_with_few_points():
assert result is None or result[0] is None


def test_parse_group_parameters_without_trailing_newline():
"""Test parsing group parameters when line has no trailing newline."""
vg = VertexGroup.__new__(VertexGroup)
vg.vgroup_ascii = [
'group_parameters: 1.0 0.0 0.0 0.123456',
'group_num_point: 3',
'0 1 2'
]
vg.points = np.array([
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0]
])
vg.refit = False
vg.global_group = False

planes, _, _, _, _ = vg.get_primitives()
assert planes.shape == (1, 4)
np.testing.assert_allclose(planes[0], [1.0, 0.0, 0.0, 0.123456])


def test_get_points_with_custom_row():
"""Test the get_points method with a custom row parameter."""
# Create a temporary file with points data in non-default location
Expand Down
Loading