fix accessors default edit function (Avaiga#1584)
* fix accessors default edit function
add tests

* cannot rely on the name of the class, as pandas and polars both have DataFrame

* F402

* F402

---------

Co-authored-by: Fred Lefévère-Laoide <[email protected]>
FredLL-Avaiga and Fred Lefévère-Laoide authored Jul 26, 2024
1 parent 6c8c854 commit e7908d6
Showing 6 changed files with 142 additions and 23 deletions.
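
The second bullet in the commit message is the core motivation: an accessor registry keyed by class name breaks down as soon as two libraries expose classes with the same short name. A minimal sketch of the collision (illustrative only, not part of the commit; it assumes polars is installed next to pandas):

import pandas
import polars  # assumption: polars is available; only used here to show the name clash

# Both classes share the same short name...
assert pandas.DataFrame.__name__ == polars.DataFrame.__name__ == "DataFrame"

# ...so a registry keyed by __name__ silently resolves both to the same accessor.
by_name = {pandas.DataFrame.__name__: "pandas accessor"}
print(by_name.get(polars.DataFrame.__name__))  # "pandas accessor" -- wrong match

# A registry keyed by the class object itself keeps them apart, which is what
# this commit switches the data accessors to.
by_type = {pandas.DataFrame: "pandas accessor"}
print(by_type.get(polars.DataFrame))  # None -- no accidental match
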
8 changes: 4 additions & 4 deletions taipy/gui/data/array_dict_data_accessor.py
@@ -22,8 +22,8 @@ class _ArrayDictDataAccessor(_PandasDataAccessor):
     __types = (dict, list, tuple, _MapDict)

     @staticmethod
-    def get_supported_classes() -> t.List[str]:
-        return [t.__name__ for t in _ArrayDictDataAccessor.__types]  # type: ignore
+    def get_supported_classes() -> t.List[t.Type]:
+        return list(_ArrayDictDataAccessor.__types)

     def to_pandas(self, value: t.Any) -> t.Union[t.List[pd.DataFrame], pd.DataFrame]:
         if isinstance(value, (list, tuple)):
@@ -54,9 +54,9 @@ def to_pandas(self, value: t.Any) -> t.Union[t.List[pd.DataFrame], pd.DataFrame]

     def _from_pandas(self, value: pd.DataFrame, type: t.Type):
         if type is dict:
-            return value.to_dict()
+            return value.to_dict("list")
         if type is _MapDict:
-            return _MapDict(value.to_dict())
+            return _MapDict(value.to_dict("list"))
         if len(value.columns) == 1:
             if type is list:
                 return value.iloc[:, 0].to_list()
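
The orient change above matters because this accessor handles dict-of-lists data (as in the tests added below): the default DataFrame.to_dict() orient returns index-keyed dicts, so an edited value would not round-trip back to the original structure. A quick illustration (a sketch, not part of the commit):

import pandas as pd

original = {"name": ["A", "B"], "value": [1, 2]}
df = pd.DataFrame(original)

print(df.to_dict())
# {'name': {0: 'A', 1: 'B'}, 'value': {0: 1, 1: 2}}  -- dicts keyed by index, shape lost

print(df.to_dict("list"))
# {'name': ['A', 'B'], 'value': [1, 2]}  -- matches the structure the accessor started from
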
24 changes: 12 additions & 12 deletions taipy/gui/data/data_accessor.py
@@ -29,7 +29,7 @@ def __init__(self, gui: "Gui") -> None:

     @staticmethod
     @abstractmethod
-    def get_supported_classes() -> t.List[str]:
+    def get_supported_classes() -> t.List[t.Type]:
         pass

     @abstractmethod
@@ -65,8 +65,8 @@ def to_csv(self, var_name: str, value: t.Any):

 class _InvalidDataAccessor(_DataAccessor):
     @staticmethod
-    def get_supported_classes() -> t.List[str]:
-        return [type(None).__name__]
+    def get_supported_classes() -> t.List[t.Type]:
+        return []

     def get_data(
         self, var_name: str, value: t.Any, payload: t.Dict[str, t.Any], data_format: _DataFormat
@@ -94,7 +94,7 @@ def to_csv(self, var_name: str, value: t.Any):

 class _DataAccessors(object):
     def __init__(self, gui: "Gui") -> None:
-        self.__access_4_type: t.Dict[str, _DataAccessor] = {}
+        self.__access_4_type: t.Dict[t.Type, _DataAccessor] = {}
         self.__invalid_data_accessor = _InvalidDataAccessor(gui)
         self.__data_format = _DataFormat.JSON
         self.__gui = gui
@@ -112,13 +112,13 @@ def _register(self, cls: t.Type[_DataAccessor]) -> None:
             raise AttributeError("The argument of 'DataAccessors.register' should be a class")
         if not issubclass(cls, _DataAccessor):
             raise TypeError(f"Class {cls.__name__} is not a subclass of DataAccessor")
-        names = cls.get_supported_classes()
-        if not names:
+        classes = cls.get_supported_classes()
+        if not classes:
             raise TypeError(f"method {cls.__name__}.get_supported_classes returned an invalid value")
         # check existence
         inst: t.Optional[_DataAccessor] = None
-        for name in names:
-            inst = self.__access_4_type.get(name)
+        for cl in classes:
+            inst = self.__access_4_type.get(cl)
             if inst:
                 break
         if inst is None:
@@ -127,15 +127,15 @@ def _register(self, cls: t.Type[_DataAccessor]) -> None:
             except Exception as e:
                 raise TypeError(f"Class {cls.__name__} cannot be instantiated") from e
         if inst:
-            for name in names:
-                self.__access_4_type[name] = inst  # type: ignore
+            for cl in classes:
+                self.__access_4_type[cl] = inst  # type: ignore

     def __get_instance(self, value: _TaipyData) -> _DataAccessor:  # type: ignore
         value = value.get() if isinstance(value, _TaipyData) else value
-        access = self.__access_4_type.get(type(value).__name__)
+        access = self.__access_4_type.get(type(value))
         if access is None:
             if value is not None:
-                _warn(f"Can't find Data Accessor for type {type(value).__name__}.")
+                _warn(f"Can't find Data Accessor for type {str(type(value))}.")
             return self.__invalid_data_accessor
         return access
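
To make the new registration and lookup flow easier to follow, here is a simplified standalone sketch of the same idea (FakeAccessor, registry, and register are stand-ins, not the Taipy internals): every class returned by get_supported_classes() maps to one shared accessor instance, and lookup uses type(value) directly.

import typing as t

class FakeAccessor:
    # Stand-in for a _DataAccessor subclass that handles plain Python containers.
    @staticmethod
    def get_supported_classes() -> t.List[t.Type]:
        return [dict, list, tuple]

registry: t.Dict[t.Type, "FakeAccessor"] = {}

def register(accessor_cls: t.Type["FakeAccessor"]) -> None:
    classes = accessor_cls.get_supported_classes()
    # Reuse an instance if one of the supported classes is already registered.
    existing = next((registry[c] for c in classes if c in registry), None)
    instance = existing if existing is not None else accessor_cls()
    for c in classes:
        registry[c] = instance

register(FakeAccessor)
value = {"name": ["A", "B"], "value": [1, 2]}
print(type(registry.get(type(value))).__name__)  # FakeAccessor -- resolved by class object, not by "dict"
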
4 changes: 2 additions & 2 deletions taipy/gui/data/numpy_data_accessor.py
@@ -21,8 +21,8 @@ class _NumpyDataAccessor(_PandasDataAccessor):
     __types = (numpy.ndarray,)

     @staticmethod
-    def get_supported_classes() -> t.List[str]:
-        return [t.__name__ for t in _NumpyDataAccessor.__types]  # type: ignore
+    def get_supported_classes() -> t.List[t.Type]:
+        return list(_NumpyDataAccessor.__types)

     def to_pandas(self, value: t.Any) -> pd.DataFrame:
         return pd.DataFrame(value)
11 changes: 6 additions & 5 deletions taipy/gui/data/pandas_data_accessor.py
@@ -55,8 +55,8 @@ def _from_pandas(self, value: pd.DataFrame, data_type: t.Type):
         return value

     @staticmethod
-    def get_supported_classes() -> t.List[str]:
-        return [t.__name__ for t in _PandasDataAccessor.__types]  # type: ignore
+    def get_supported_classes() -> t.List[t.Type]:
+        return list(_PandasDataAccessor.__types)

     @staticmethod
     def __user_function(
@@ -484,12 +484,13 @@ def on_add(self, value: t.Any, payload: t.Dict[str, t.Any], new_row: t.Optional[
         new_row = [0 if is_numeric_dtype(df[c]) else "" for c in list(col_types)] if new_row is None else new_row
         if index > 0:
             # Column names and value types must match the original DataFrame
-            new_df = pd.DataFrame(new_row, columns=list(col_types))
+            new_df = pd.DataFrame([new_row], columns=list(col_types))
             # Split the DataFrame
-            rows_before = df.loc[: index - 1]
-            rows_after = df.loc[index + 1 :]
+            rows_before = df.iloc[:index]
+            rows_after = df.iloc[index:]
             return self._from_pandas(pd.concat([rows_before, new_df, rows_after], ignore_index=True), type(value))
         else:
+            df = df.copy()
             # Insert as the new first row
             df.loc[-1] = new_row  # Insert the new row
             df.index = df.index + 1  # Shift index
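
The on_add hunk fixes the default-insertion path in two ways: a flat new_row list is now wrapped so pandas reads it as one row rather than one column, and the split around the insertion point uses position-based iloc, whereas the old label-based loc[index + 1:] dropped the row originally at index and assumed a default RangeIndex. The added df.copy() also keeps the caller's DataFrame untouched in the insert-at-top branch. A small sketch of the fixed path on a toy frame (illustrative, not taken from the commit):

import pandas as pd

df = pd.DataFrame({"name": ["A", "B", "C"], "value": [1, 2, 3]})
new_row = ["New", 100]
index = 2

# Wrapping new_row in a list yields a single-row DataFrame with matching columns.
new_df = pd.DataFrame([new_row], columns=list(df.columns))

# Position-based slicing keeps every original row and inserts the new one at `index`.
rows_before = df.iloc[:index]
rows_after = df.iloc[index:]
result = pd.concat([rows_before, new_df, rows_after], ignore_index=True)
print(result["name"].tolist())   # ['A', 'B', 'New', 'C']
print(result["value"].tolist())  # [1, 2, 100, 3]
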
59 changes: 59 additions & 0 deletions tests/gui/data/test_array_dict_data_accessor.py
@@ -9,6 +9,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.

+import os
 from importlib import util

 from taipy.gui import Gui
@@ -161,3 +162,61 @@ def test_array_of_Mapdicts(gui: Gui, helpers, small_dataframe):
     assert len(data) == 2
     assert len(data[0]["temperatures"]) == 5
     assert len(data[1]["seasons"]) == 4
+
+
+def test_edit_dict(gui, small_dataframe):
+    accessor = _ArrayDictDataAccessor(gui)
+    pd = small_dataframe
+    ln = len(pd["name"])
+    assert pd["value"][0] != 10
+    ret_data = accessor.on_edit(pd, {"index": 0, "col": "value", "value": 10})
+    assert isinstance(ret_data, dict)
+    assert len(ret_data["name"]) == ln
+    assert ret_data["value"][0] == 10
+
+
+def test_delete_dict(gui, small_dataframe):
+    accessor = _ArrayDictDataAccessor(gui)
+    pd = small_dataframe
+    ln = len(pd["name"])
+    ret_data = accessor.on_delete(pd, {"index": 0})
+    assert isinstance(ret_data, dict)
+    assert len(ret_data["name"]) == ln - 1
+
+
+def test_add_dict(gui, small_dataframe):
+    accessor = _ArrayDictDataAccessor(gui)
+    pd = small_dataframe
+    ln = len(pd["name"])
+
+    ret_data = accessor.on_add(pd, {"index": 0})
+    assert isinstance(ret_data, dict)
+    assert len(ret_data["name"]) == ln + 1
+    assert ret_data["value"][0] == 0
+    assert ret_data["name"][0] == ""
+
+    ret_data = accessor.on_add(pd, {"index": 2})
+    assert isinstance(ret_data, dict)
+    assert len(ret_data["name"]) == ln + 1
+    assert ret_data["value"][2] == 0
+    assert ret_data["name"][2] == ""
+
+    ret_data = accessor.on_add(pd, {"index": 0}, ["New", 100])
+    assert isinstance(ret_data, dict)
+    assert len(ret_data["name"]) == ln + 1
+    assert ret_data["value"][0] == 100
+    assert ret_data["name"][0] == "New"
+
+    ret_data = accessor.on_add(pd, {"index": 2}, ["New", 100])
+    assert isinstance(ret_data, dict)
+    assert len(ret_data["name"]) == ln + 1
+    assert ret_data["value"][2] == 100
+    assert ret_data["name"][2] == "New"
+
+
+def test_csv(gui, small_dataframe):
+    accessor = _ArrayDictDataAccessor(gui)
+    pd = small_dataframe
+    path = accessor.to_csv("", pd)
+    assert path is not None
+    assert os.path.getsize(path) > 0
59 changes: 59 additions & 0 deletions tests/gui/data/test_pandas_data_accessor.py
@@ -10,6 +10,7 @@
 # specific language governing permissions and limitations under the License.

 import inspect
+import os
 from datetime import datetime
 from importlib import util

@@ -239,3 +240,61 @@ def test_decimator(gui: Gui, helpers, small_dataframe):
     assert value
     data = value["data"]
     assert len(data) == 2
+
+
+def test_edit(gui, small_dataframe):
+    accessor = _PandasDataAccessor(gui)
+    pd = pandas.DataFrame(small_dataframe)
+    ln = len(pd)
+    assert pd["value"].iloc[0] != 10
+    ret_data = accessor.on_edit(pd, {"index": 0, "col": "value", "value": 10})
+    assert isinstance(ret_data, pandas.DataFrame)
+    assert len(ret_data) == ln
+    assert ret_data["value"].iloc[0] == 10
+
+
+def test_delete(gui, small_dataframe):
+    accessor = _PandasDataAccessor(gui)
+    pd = pandas.DataFrame(small_dataframe)
+    ln = len(pd)
+    ret_data = accessor.on_delete(pd, {"index": 0})
+    assert isinstance(ret_data, pandas.DataFrame)
+    assert len(ret_data) == ln - 1
+
+
+def test_add(gui, small_dataframe):
+    accessor = _PandasDataAccessor(gui)
+    pd = pandas.DataFrame(small_dataframe)
+    ln = len(pd)
+
+    ret_data = accessor.on_add(pd, {"index": 0})
+    assert isinstance(ret_data, pandas.DataFrame)
+    assert len(ret_data) == ln + 1
+    assert ret_data["value"].iloc[0] == 0
+    assert ret_data["name"].iloc[0] == ""
+
+    ret_data = accessor.on_add(pd, {"index": 2})
+    assert isinstance(ret_data, pandas.DataFrame)
+    assert len(ret_data) == ln + 1
+    assert ret_data["value"].iloc[2] == 0
+    assert ret_data["name"].iloc[2] == ""
+
+    ret_data = accessor.on_add(pd, {"index": 0}, ["New", 100])
+    assert isinstance(ret_data, pandas.DataFrame)
+    assert len(ret_data) == ln + 1
+    assert ret_data["value"].iloc[0] == 100
+    assert ret_data["name"].iloc[0] == "New"
+
+    ret_data = accessor.on_add(pd, {"index": 2}, ["New", 100])
+    assert isinstance(ret_data, pandas.DataFrame)
+    assert len(ret_data) == ln + 1
+    assert ret_data["value"].iloc[2] == 100
+    assert ret_data["name"].iloc[2] == "New"
+
+
+def test_csv(gui, small_dataframe):
+    accessor = _PandasDataAccessor(gui)
+    pd = pandas.DataFrame(small_dataframe)
+    path = accessor.to_csv("", pd)
+    assert path is not None
+    assert os.path.getsize(path) > 0
