diff --git a/src/nested_pandas/nestedframe/core.py b/src/nested_pandas/nestedframe/core.py index 5cd0a708..4ed7240c 100644 --- a/src/nested_pandas/nestedframe/core.py +++ b/src/nested_pandas/nestedframe/core.py @@ -97,16 +97,14 @@ def all_columns(self) -> dict: @property def nested_columns(self) -> list: """retrieves the base column names for all nested dataframes""" - nest_cols = [] - for column in self.columns: - if isinstance(self.dtypes[column], NestedDtype): - nest_cols.append(column) - return nest_cols + nested_mask = self.dtypes.apply(lambda dtype: isinstance(dtype, NestedDtype)) + return self.columns[nested_mask].tolist() @property def base_columns(self) -> list[str]: """Returns the list of base (non-nested) column names""" - return [col for col in self.columns if col not in self.nested_columns] + nested_mask = self.dtypes.apply(lambda dtype: not isinstance(dtype, NestedDtype)) + return self.columns[nested_mask].tolist() def _repr_html_(self) -> str | None: """Override html representation"""