diff --git a/conda/conda-reqs-pip.txt b/conda/conda-reqs-pip.txt
index 19f1ce5d8..4b8eaab2e 100644
--- a/conda/conda-reqs-pip.txt
+++ b/conda/conda-reqs-pip.txt
@@ -1,8 +1,10 @@
 azure-mgmt-resourcegraph>=8.0.0
 azure-monitor-query>=1.0.0, <=2.0.0
+dataclasses-json>=0.5.7
 # KqlmagicCustom[jupyter-basic,auth_code_clipboard]>=0.1.114.post22
 mo-sql-parsing>=8, <9.0.0
+nbformat>=5.9.2
 nest_asyncio>=1.4.0
 passivetotal>=2.5.3
-sumologic-sdk>=0.1.11
 splunk-sdk>=1.6.0
+sumologic-sdk>=0.1.11
diff --git a/msticpy/common/cache/__init__.py b/msticpy/common/cache/__init__.py
new file mode 100644
index 000000000..57b21bd94
--- /dev/null
+++ b/msticpy/common/cache/__init__.py
@@ -0,0 +1,86 @@
+"""Common methods to handle caching."""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from ...datamodel.result import QueryResult
+from ..utility.ipython import is_ipython
+from . import cell
+from . import file as cache_file
+from .codec import compute_digest
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def write_cache(  # noqa: PLR0913
+    data: pd.DataFrame,
+    search_params: dict[str, Any],
+    query: str,
+    name: str,
+    cache_path: str | None = None,
+    *,
+    display: bool = False,
+) -> None:
+    """Cache query result in a notebook cell or a file."""
+    cache_digest: str = compute_digest(search_params)
+    cache: QueryResult = QueryResult(
+        name=name,
+        query=query,
+        raw_results=data,
+        arguments=search_params,
+    )
+    if is_ipython() and display:
+        cell.write_cache(
+            cache,
+            name,
+            cache_digest,
+        )
+    if cache_path:
+        LOGGER.info("Writing cache to %s", cache_path)
+        cache_file.write_cache(
+            data=cache,
+            file_name=f"{name}_{cache_digest}",
+            export_folder=cache_path,
+        )
+
+
+def read_cache(
+    search_params: dict[str, Any],
+    cache_path: str | None,
+    name: str | None = None,
+) -> QueryResult:
+    """Retrieve a result from the cache in a cell or an archive file."""
+    if not cache_path:
+        error_msg: str = "Cache not provided."
+        raise ValueError(error_msg)
+    cache_digest: str = compute_digest(search_params)
+    if is_ipython():
+        try:
+            return cell.read_cache(
+                name or cache_digest,
+                cache_digest,
+                cache_path,
+            )
+        except ValueError:
+            pass
+    try:
+        cache: QueryResult = cache_file.read_cache(
+            f"{name}_{cache_digest}",
+            cache_path,
+        )
+    except FileNotFoundError as exc:
+        error_msg = "Could not read from cache."
+        raise ValueError(error_msg) from exc
+    if is_ipython():
+        # Write the cache back to the cell since it was not found there.
+        cell.write_cache(
+            cache,
+            name or cache_digest,
+            cache_digest,
+        )
+    return cache
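For orientation (not part of the diff): a minimal round trip through this new top-level API, assuming a plain Python session (no IPython, so only the file cache is exercised) and illustrative parameter values:

```python
import pandas as pd

from msticpy.common.cache import read_cache, write_cache

# Illustrative search parameters and result frame.
params = {"host_name": "myhost", "start": "2024-01-01"}
df = pd.DataFrame([{"key": "A", "value": 42}])

# Persists the result to ./artifacts, keyed by a digest of the parameters.
write_cache(df, search_params=params, query="fake query", name="demo", cache_path="./artifacts")

# A later call with identical parameters yields the cached QueryResult.
result = read_cache(params, "./artifacts", name="demo")
print(result.raw_results)
```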
diff --git a/msticpy/common/cache/cell.py b/msticpy/common/cache/cell.py
new file mode 100644
index 000000000..243f17b8d
--- /dev/null
+++ b/msticpy/common/cache/cell.py
@@ -0,0 +1,108 @@
+"""Handle caching in notebook cells."""
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Any
+
+import nbformat
+from IPython.display import display
+
+from ...datamodel.result import QueryResult
+from .codec import decode_base64_as_pickle, encode_as_base64_pickle
+
+LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def write_cache(
+    data: QueryResult,
+    name: str,
+    digest: str,
+) -> None:
+    """Cache content in cell."""
+    cache: str = encode_as_base64_pickle(data)
+    metadata: dict[str, Any] = {
+        "data": cache,
+        "hash": digest,
+    }
+    if isinstance(data, QueryResult):
+        metadata.update(
+            {
+                "name": name,
+                "query": data.query,
+                "arguments": data.arguments,
+                "timestamp": data.timestamp,
+            },
+        )
+    LOGGER.debug("Data %s written to Notebook cache", name)
+    display(
+        data.raw_results,
+        metadata=metadata,
+        exclude=["text/plain"],
+    )
+
+
+def get_cache_item(path: Path, name: str, digest: str) -> dict[str, Any]:
+    """
+    Get named object from cache.
+
+    Parameters
+    ----------
+    path : Path
+        Path to the notebook.
+    name : str
+        Name of the cached object to search for.
+    digest : str
+        Hash of the cached object to search for.
+
+    Returns
+    -------
+    dict[str, Any]
+        Cached object.
+    """
+    if not path.exists():
+        error_msg: str = "Notebook not found"
+        raise FileNotFoundError(error_msg)
+
+    notebook: nbformat.NotebookNode = nbformat.reads(
+        path.read_text(encoding="utf-8"),
+        as_version=nbformat.current_nbformat,
+    )
+
+    try:
+        cache: dict[str, Any] = next(
+            iter(
+                [
+                    (output.get("metadata", {}) or {})
+                    for cell in (notebook.cells or [])
+                    for output in (cell.get("outputs", []) or [])
+                    if output.get("metadata", {}).get("hash") == digest
+                    and output.get("metadata", {}).get("name") == name
+                ],
+            ),
+        )
+    except StopIteration:
+        LOGGER.debug("%s not found in %s cache...", digest, path)
+        cache = {}
+
+    return cache
+
+
+def read_cache(name: str, digest: str, nb_path: str) -> QueryResult:
+    """Read cache content from a notebook file."""
+    if not nb_path:
+        error_msg: str = "Argument nb_path must be defined."
+        raise ValueError(error_msg)
+
+    notebook_fp: Path = Path(nb_path).absolute()
+
+    if not notebook_fp.exists():
+        error_msg = "Notebook not found"
+        raise FileNotFoundError(error_msg)
+
+    cache: dict[str, Any] = get_cache_item(path=notebook_fp, name=name, digest=digest)
+    if cache and (data := cache.get("data")):
+        LOGGER.debug("Digest %s found in cache...", digest)
+        return decode_base64_as_pickle(data)
+    error_msg = f"Cache {digest} not found"
+    raise ValueError(error_msg)
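The cell cache piggybacks on IPython's `display()` output metadata, so a saved notebook can be inspected with plain `nbformat`. A sketch of seeding and querying that metadata by hand, using only standard nbformat calls; the notebook path and metadata values are made up:

```python
from pathlib import Path

import nbformat

from msticpy.common.cache import cell

# Build a minimal notebook whose single output carries cache metadata,
# mimicking what cell.write_cache attaches via IPython display().
output = nbformat.v4.new_output(
    "display_data",
    data={"text/html": "<table></table>"},
    metadata={"hash": "digest", "name": "demo", "data": "<base64 payload>"},
)
nb = nbformat.v4.new_notebook(cells=[nbformat.v4.new_code_cell(outputs=[output])])
nbformat.write(nb, "cached.ipynb")  # illustrative path

# get_cache_item returns the metadata dict matching both name and digest.
item = cell.get_cache_item(Path("cached.ipynb"), name="demo", digest="digest")
print(item["hash"], item["name"])
```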
diff --git a/msticpy/common/cache/codec.py b/msticpy/common/cache/codec.py
new file mode 100644
index 000000000..f0d3164e1
--- /dev/null
+++ b/msticpy/common/cache/codec.py
@@ -0,0 +1,40 @@
+"""Functions to encode/decode cached objects."""
+
+import base64
+import json
+import logging
+from collections.abc import MutableMapping
+from hashlib import sha256
+from io import BytesIO
+
+import compress_pickle  # type: ignore[import-untyped]
+
+from ...datamodel.result import QueryResult
+
+from ..._version import VERSION
+
+__version__ = VERSION
+__author__ = "Florian Bracq"
+
+LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def encode_as_base64_pickle(data: QueryResult) -> str:
+    """Encode data as a Base64 pickle to be written to the cache."""
+    with BytesIO() as bytes_io:
+        compress_pickle.dump(data, bytes_io, compression="lzma")
+        return base64.b64encode(bytes_io.getvalue()).decode()
+
+
+def decode_base64_as_pickle(b64_string: str) -> QueryResult:
+    """Decode a Base64 pickle from the cache back into a QueryResult."""
+    return compress_pickle.loads(base64.b64decode(b64_string), compression="lzma")
+
+
+def compute_digest(obj: MutableMapping) -> str:
+    """Compute the digest from the parameters."""
+    str_params: str = json.dumps(obj, sort_keys=True, default=str)
+    LOGGER.debug("Received: %s", str_params)
+    digest: str = sha256(bytes(str_params, "utf-8")).hexdigest()
+    LOGGER.debug("Generated digest: %s", digest)
+    return digest
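A quick sanity check of the codec's invariants, with illustrative values: equal parameter dicts yield equal digests regardless of key order (thanks to `sort_keys=True`), and decoding inverts encoding:

```python
import pandas as pd

from msticpy.common.cache.codec import (
    compute_digest,
    decode_base64_as_pickle,
    encode_as_base64_pickle,
)
from msticpy.datamodel.result import QueryResult

# sort_keys=True in json.dumps makes the digest order-insensitive.
assert compute_digest({"b": 2, "a": 1}) == compute_digest({"a": 1, "b": 2})

result = QueryResult(name="demo", query="q", raw_results=pd.DataFrame({"x": [1]}))
encoded = encode_as_base64_pickle(result)  # lzma-pickled, then base64 text
assert decode_base64_as_pickle(encoded) == result  # QueryResult defines __eq__
```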
diff --git a/msticpy/common/cache/file.py b/msticpy/common/cache/file.py
new file mode 100644
index 000000000..b2c0f99e0
--- /dev/null
+++ b/msticpy/common/cache/file.py
@@ -0,0 +1,48 @@
+"""Handle caching in files."""
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from .codec import decode_base64_as_pickle, encode_as_base64_pickle
+
+if TYPE_CHECKING:
+    from ...datamodel.result import QueryResult
+
+
+LOGGER: logging.Logger = logging.getLogger(__name__)
+CACHE_FOLDER_NAME = "artifacts"
+
+
+def write_cache(
+    data: QueryResult,
+    file_name: str,
+    export_folder: str = CACHE_FOLDER_NAME,
+) -> None:
+    """Cache content in a file."""
+    export_path: Path = Path(export_folder)
+    if export_path.is_file():
+        export_path = export_path.parent / CACHE_FOLDER_NAME
+    if not export_path.exists():
+        export_path.mkdir(exist_ok=True, parents=True)
+    export_file: Path = export_path / file_name
+    encoded_text: str = encode_as_base64_pickle(data)
+    export_file.write_text(encoded_text)
+    LOGGER.debug("Data written to file %s", export_file)
+
+
+def read_cache(
+    file_name: str,
+    export_folder: str = CACHE_FOLDER_NAME,
+) -> QueryResult:
+    """Read cache content from a file."""
+    export_path: Path = Path(export_folder)
+    if export_path.is_file():
+        export_path = export_path.parent / CACHE_FOLDER_NAME
+    export_file: Path = export_path / file_name
+    if export_file.exists():
+        LOGGER.debug("Found data in cache %s", export_file)
+        encoded_text: str = export_file.read_text()
+        return decode_base64_as_pickle(encoded_text)
+    raise FileNotFoundError
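One behavior worth illustrating: when `export_folder` points at a file (for example, the notebook itself) rather than a directory, the cache is redirected to a sibling `artifacts` folder. A sketch, assuming a scratch working directory and illustrative names:

```python
from pathlib import Path

import pandas as pd

from msticpy.common.cache import file as cache_file
from msticpy.datamodel.result import QueryResult

result = QueryResult(name="demo", query="q", raw_results=pd.DataFrame({"x": [1]}))

# Passing a file path (e.g. the notebook) redirects the cache to ./artifacts.
Path("notebook.ipynb").touch()  # illustrative stand-in for a real notebook
cache_file.write_cache(result, "demo_digest", export_folder="notebook.ipynb")
assert Path("artifacts/demo_digest").exists()

restored = cache_file.read_cache("demo_digest", export_folder="notebook.ipynb")
assert restored == result
```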
"""Represent as markdown.""" + return self.raw_results.to_html(index=False) + + def _repr_html_(self: Self) -> str: + """Represent as HTML.""" + return self.raw_results.to_html(index=False) + + def __eq__(self: Self, __value: object) -> bool: + """Return True if self and __value are equal.""" + if not isinstance(__value, QueryResult): + return False + return ( + self.name == __value.name + and self.query == __value.query + and len(self.arguments) == len(__value.arguments) + and self.raw_results.equals(__value.raw_results) + ) diff --git a/requirements-all.txt b/requirements-all.txt index 727c282e7..0c8b3ed4e 100644 --- a/requirements-all.txt +++ b/requirements-all.txt @@ -16,7 +16,9 @@ azure-monitor-query>=1.0.0, <=2.0.0 azure-storage-blob>=12.5.0 beautifulsoup4>=4.0.0 bokeh>=1.4.0, <4.0.0 +compress-pickle >= 2.1.0 cryptography>=3.1 +dataclasses-json >= 0.5.7 deprecated>=1.2.4 dnspython>=2.0.0, <3.0.0 folium>=0.9.0 @@ -35,6 +37,7 @@ msal>=1.12.0 msal_extensions>=0.3.0 msrest>=0.6.0 msrestazure>=0.6.0 +nbformat>=5.9.2 nest_asyncio>=1.4.0 networkx>=2.2 numpy>=1.15.4 # pandas diff --git a/requirements.txt b/requirements.txt index eea84c4fb..1c7583a7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,9 @@ azure-mgmt-subscription>=3.0.0 azure-monitor-query>=1.0.0, <=2.0.0 beautifulsoup4>=4.0.0 bokeh>=1.4.0, <4.0.0 +compress-pickle >= 2.1.0 cryptography>=3.1 +dataclasses-json >= 0.5.7 deprecated>=1.2.4 dnspython>=2.0.0, <3.0.0 folium>=0.9.0 @@ -25,6 +27,7 @@ msal>=1.12.0 msal_extensions>=0.3.0 msrest>=0.6.0 msrestazure>=0.6.0 +nbformat>=5.9.2 nest_asyncio>=1.4.0 networkx>=2.2 numpy>=1.15.4 # pandas diff --git a/tests/common/cache/__init__.py b/tests/common/cache/__init__.py new file mode 100644 index 000000000..7cdd7c8cc --- /dev/null +++ b/tests/common/cache/__init__.py @@ -0,0 +1 @@ +"""Tests for cache functions.""" diff --git a/tests/common/cache/test_cell.py b/tests/common/cache/test_cell.py new file mode 100644 index 000000000..aa43d5d79 --- /dev/null +++ b/tests/common/cache/test_cell.py @@ -0,0 +1,228 @@ +"""Testing cell notebook cache.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest.mock import patch + +import nbformat +import pandas as pd +import pytest +import pytest_check as check +from typing_extensions import Self + +from msticpy.common.cache import cell +from msticpy.datamodel.result import QueryResult + +if TYPE_CHECKING: + from pathlib import Path + + +def test_pickle_encode_decode(simple_dataframeresult: QueryResult) -> None: + """Test method encode_as_base64_pickle and decode_base64_as_pickle.""" + encoded: str = cell.encode_as_base64_pickle(simple_dataframeresult) + check.is_not_none(encoded) + check.is_instance(encoded, str) + check.greater(len(encoded), 100) + + decoded: QueryResult = cell.decode_base64_as_pickle(encoded) + check.is_not_none(decoded) + check.is_instance(decoded, QueryResult) + check.is_false(decoded.raw_results.empty) + check.is_true(simple_dataframeresult.raw_results.equals(decoded.raw_results)) + + +def test_write_cache(simple_dataframeresult: QueryResult) -> None: + """Test method write_cache.""" + digest: str = "digest" + name: str = "name" + + with patch.object(cell, "display") as patched_display: + cell.write_cache( + simple_dataframeresult, + name, + digest, + ) + check.is_true(patched_display.called) + check.equal(patched_display.call_count, 1) + + check.equal(len(patched_display.call_args.args), 1) + + data: pd.DataFrame = patched_display.call_args.args[0] + check.is_instance(data, 
diff --git a/msticpy/datamodel/result.py b/msticpy/datamodel/result.py
new file mode 100644
index 000000000..554ac85d3
--- /dev/null
+++ b/msticpy/datamodel/result.py
@@ -0,0 +1,60 @@
+"""Define standard models for query results."""
+from __future__ import annotations
+
+import datetime as dt
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+from dataclasses_json import dataclass_json
+from typing_extensions import Self
+
+if TYPE_CHECKING:
+    from collections.abc import Hashable
+
+    import pandas as pd
+
+
+@dataclass_json
+@dataclass
+class QueryResult:
+    """Query result model wrapping a DataFrame."""
+
+    name: str
+    query: str
+    raw_results: pd.DataFrame
+    arguments: dict[str, Any] = field(default_factory=dict)
+    timestamp: dt.datetime = field(default_factory=dt.datetime.utcnow)
+
+    @property
+    def normalizer(self: Self) -> str:
+        """Normalizer class name."""
+        return str(self.__class__.__name__)
+
+    @property
+    def total_results(self: Self) -> int:
+        """Total number of results."""
+        return len(self.results)
+
+    @property
+    def results(self: Self) -> list[dict[Hashable, Any]]:
+        """Return results as a list of dicts."""
+        return self.raw_results.to_dict(orient="records")
+
+    def _repr_markdown_(self: Self) -> str:
+        """Represent as markdown."""
+        return self.raw_results.to_html(index=False)
+
+    def _repr_html_(self: Self) -> str:
+        """Represent as HTML."""
+        return self.raw_results.to_html(index=False)
+
+    def __eq__(self: Self, __value: object) -> bool:
+        """Return True if self and __value are equal."""
+        if not isinstance(__value, QueryResult):
+            return False
+        return (
+            self.name == __value.name
+            and self.query == __value.query
+            and len(self.arguments) == len(__value.arguments)
+            and self.raw_results.equals(__value.raw_results)
+        )
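A minimal sketch of the model's convenience surface, with illustrative values; note that `results` re-materializes the DataFrame as records on each access:

```python
import pandas as pd

from msticpy.datamodel.result import QueryResult

result = QueryResult(
    name="demo",
    query="SecurityEvent | take 2",  # illustrative query text
    raw_results=pd.DataFrame([{"host": "web01"}, {"host": "web02"}]),
)

print(result.total_results)  # 2
print(result.results)        # [{'host': 'web01'}, {'host': 'web02'}]
print(result.normalizer)     # QueryResult
```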
diff --git a/requirements-all.txt b/requirements-all.txt
index 727c282e7..0c8b3ed4e 100644
--- a/requirements-all.txt
+++ b/requirements-all.txt
@@ -16,7 +16,9 @@ azure-monitor-query>=1.0.0, <=2.0.0
 azure-storage-blob>=12.5.0
 beautifulsoup4>=4.0.0
 bokeh>=1.4.0, <4.0.0
+compress-pickle>=2.1.0
 cryptography>=3.1
+dataclasses-json>=0.5.7
 deprecated>=1.2.4
 dnspython>=2.0.0, <3.0.0
 folium>=0.9.0
@@ -35,6 +37,7 @@ msal>=1.12.0
 msal_extensions>=0.3.0
 msrest>=0.6.0
 msrestazure>=0.6.0
+nbformat>=5.9.2
 nest_asyncio>=1.4.0
 networkx>=2.2
 numpy>=1.15.4 # pandas
diff --git a/requirements.txt b/requirements.txt
index eea84c4fb..1c7583a7a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,9 @@ azure-mgmt-subscription>=3.0.0
 azure-monitor-query>=1.0.0, <=2.0.0
 beautifulsoup4>=4.0.0
 bokeh>=1.4.0, <4.0.0
+compress-pickle>=2.1.0
 cryptography>=3.1
+dataclasses-json>=0.5.7
 deprecated>=1.2.4
 dnspython>=2.0.0, <3.0.0
 folium>=0.9.0
@@ -25,6 +27,7 @@ msal>=1.12.0
 msal_extensions>=0.3.0
 msrest>=0.6.0
 msrestazure>=0.6.0
+nbformat>=5.9.2
 nest_asyncio>=1.4.0
 networkx>=2.2
 numpy>=1.15.4 # pandas
diff --git a/tests/common/cache/__init__.py b/tests/common/cache/__init__.py
new file mode 100644
index 000000000..7cdd7c8cc
--- /dev/null
+++ b/tests/common/cache/__init__.py
@@ -0,0 +1 @@
+"""Tests for cache functions."""
diff --git a/tests/common/cache/test_cell.py b/tests/common/cache/test_cell.py
new file mode 100644
index 000000000..aa43d5d79
--- /dev/null
+++ b/tests/common/cache/test_cell.py
@@ -0,0 +1,228 @@
+"""Testing notebook cell cache."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+from unittest.mock import patch
+
+import nbformat
+import pandas as pd
+import pytest
+import pytest_check as check
+from typing_extensions import Self
+
+from msticpy.common.cache import cell
+from msticpy.datamodel.result import QueryResult
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+
+def test_pickle_encode_decode(simple_dataframeresult: QueryResult) -> None:
+    """Test methods encode_as_base64_pickle and decode_base64_as_pickle."""
+    encoded: str = cell.encode_as_base64_pickle(simple_dataframeresult)
+    check.is_not_none(encoded)
+    check.is_instance(encoded, str)
+    check.greater(len(encoded), 100)
+
+    decoded: QueryResult = cell.decode_base64_as_pickle(encoded)
+    check.is_not_none(decoded)
+    check.is_instance(decoded, QueryResult)
+    check.is_false(decoded.raw_results.empty)
+    check.is_true(simple_dataframeresult.raw_results.equals(decoded.raw_results))
+
+
+def test_write_cache(simple_dataframeresult: QueryResult) -> None:
+    """Test method write_cache."""
+    digest: str = "digest"
+    name: str = "name"
+
+    with patch.object(cell, "display") as patched_display:
+        cell.write_cache(
+            simple_dataframeresult,
+            name,
+            digest,
+        )
+    check.is_true(patched_display.called)
+    check.equal(patched_display.call_count, 1)
+
+    check.equal(len(patched_display.call_args.args), 1)
+
+    data: pd.DataFrame = patched_display.call_args.args[0]
+    check.is_instance(data, pd.DataFrame)
+
+    kwargs: dict[str, Any] = patched_display.call_args.kwargs
+    check.is_not_none(kwargs)
+    check.is_instance(kwargs, dict)
+    check.equal(len(kwargs), 2)
+    check.is_in("metadata", kwargs)
+    check.is_in("exclude", kwargs)
+
+    metadata: dict[str, str] = kwargs["metadata"]
+    check.is_instance(metadata, dict)
+    check.equal(len(metadata), 6)
+    check.is_in("data", metadata)
+    check.is_in("hash", metadata)
+    check.is_in("name", metadata)
+    check.is_in("query", metadata)
+    check.is_in("arguments", metadata)
+    check.is_in("timestamp", metadata)
+
+    meta_data: str = metadata["data"]
+    check.equal(meta_data, cell.encode_as_base64_pickle(simple_dataframeresult))
+    meta_hash: str = metadata["hash"]
+    check.equal(meta_hash, digest)
+    meta_name: str = metadata["name"]
+    check.equal(meta_name, name)
+    meta_query: str = metadata["query"]
+    check.equal(meta_query, simple_dataframeresult.query)
+    meta_args: str = metadata["arguments"]
+    check.equal(meta_args, simple_dataframeresult.arguments)
+    meta_timestamp: str = metadata["timestamp"]
+    check.equal(meta_timestamp, simple_dataframeresult.timestamp)
+
+
+class MyNotebook:  # pylint:disable=too-few-public-methods
+    """Dummy notebook class."""
+
+    def __init__(self: Self, metadata: dict[str, Any] | None = None) -> None:
+        """Init dummy object."""
+        self.cells: list[dict[str, list[dict[str, dict[str, Any]]]]] = [
+            {"outputs": [{"metadata": metadata or {}}]},
+        ]
+
+
+def test_get_cache_item(tmp_path: Path) -> None:
+    """Test method get_cache_item."""
+    digest: str = "digest"
+    name: str = "name"
+
+    # Create file with digest content
+    (tmp_path / "random.ipynb").write_text(digest, encoding="utf-8")
+
+    with patch.object(
+        nbformat,
+        "reads",
+        return_value=MyNotebook({"hash": digest, "name": name}),
+    ):
+        res: dict[str, Any] = cell.get_cache_item(
+            tmp_path / "random.ipynb",
+            name=name,
+            digest=digest,
+        )
+        check.is_instance(res, dict)
+        check.is_in("hash", res)
+        check.is_in("name", res)
+        check.equal(res["hash"], digest)
+        check.equal(res["name"], name)
+
+
+def test_get_cache_item_wrong_path(tmp_path: Path) -> None:
+    """Test method get_cache_item with invalid notebook path."""
+    with pytest.raises(FileNotFoundError, match="Notebook not found"):
+        cell.get_cache_item(tmp_path / "random.ipynb", "name", "digest")
+
+
+def test_get_cache_item_wrong_digest(tmp_path: Path) -> None:
+    """Test method get_cache_item with invalid digest."""
+    # Create file with some content
+    (tmp_path / "random.ipynb").write_text("", encoding="utf-8")
+    name: str = "name"
+    digest: str = "digest"
+
+    with patch.object(
+        nbformat,
+        "reads",
+        return_value=MyNotebook({"hash": digest, "name": name}),
+    ):
+        res: dict[str, Any] = cell.get_cache_item(tmp_path / "random.ipynb", name, name)
+        check.is_instance(res, dict)
+        check.equal(len(res), 0)
+
+
+def test_get_cache_item_wrong_name(tmp_path: Path) -> None:
+    """Test method get_cache_item with invalid name."""
+    # Create file with some content
+    (tmp_path / "random.ipynb").write_text("", encoding="utf-8")
+    name: str = "name"
+    digest: str = "digest"
+
+    with patch.object(
+        nbformat,
+        "reads",
+        return_value=MyNotebook({"hash": digest, "name": name}),
+    ):
+        res: dict[str, Any] = cell.get_cache_item(
+            tmp_path / "random.ipynb",
+            digest,
+            digest,
+        )
+        check.is_instance(res, dict)
+        check.equal(len(res), 0)
+
+
+def test_read_cache(tmp_path: Path, simple_dataframeresult: QueryResult) -> None:
+    """Test method read_cache."""
+    nb_path: Path = tmp_path / "test.ipynb"
+    digest: str = "digest"
+    encoded: str = cell.encode_as_base64_pickle(simple_dataframeresult)
+    name: str = "name"
+    # Create file with no content
+    nb_path.write_text(digest, encoding="utf-8")
+
+    with patch.object(
+        nbformat,
+        "reads",
+        return_value=MyNotebook({"hash": digest, "name": name, "data": encoded}),
+    ):
+        data: QueryResult = cell.read_cache(
+            name=name,
+            digest=digest,
+            nb_path=str(nb_path),
+        )
+        check.is_instance(data, QueryResult)
+        check.is_true(simple_dataframeresult.raw_results.equals(data.raw_results))
+
+
+def test_read_cache_wrong_digest(tmp_path: Path) -> None:
+    """Test method read_cache with an incorrect digest."""
+    nb_path: Path = tmp_path / "test.ipynb"
+    digest: str = "random"
+    name: str = "name"
+    # Create file with no content
+    nb_path.write_text(digest, encoding="utf-8")
+
+    with patch.object(
+        nbformat,
+        "reads",
+        return_value=MyNotebook({"hash": digest, "name": name}),
+    ), pytest.raises(ValueError, match=f"Cache {name} not found"):
+        cell.read_cache(name, name, str(nb_path))
+
+
+def test_read_cache_wrong_name(tmp_path: Path) -> None:
+    """Test method read_cache with an incorrect name."""
+    nb_path: Path = tmp_path / "test.ipynb"
+    digest: str = "random"
+    name: str = "name"
+    # Create file with no content
+    nb_path.write_text(digest, encoding="utf-8")
+
+    with patch.object(
+        nbformat,
+        "reads",
+        return_value=MyNotebook({"hash": "hash", "name": name}),
+    ), pytest.raises(ValueError, match=f"Cache {digest} not found"):
+        cell.read_cache(digest, digest, str(nb_path))
+
+
+def test_read_cache_wrong_nb(tmp_path: Path) -> None:
+    """Test method read_cache with incorrect notebook path."""
+    nb_path: Path = tmp_path / "test.ipynb"
+    with pytest.raises(FileNotFoundError, match="Notebook not found"):
+        cell.read_cache("name", "digest", str(nb_path))
+
+
+def test_read_cache_no_nb() -> None:
+    """Test method read_cache with an empty notebook path."""
+    with pytest.raises(ValueError, match="Argument nb_path must be defined."):
+        cell.read_cache("name", "digest", "")
diff --git a/tests/common/cache/test_file.py b/tests/common/cache/test_file.py
new file mode 100644
index 000000000..ee098418f
--- /dev/null
+++ b/tests/common/cache/test_file.py
@@ -0,0 +1,98 @@
+"""Testing file-based cache."""
+from pathlib import Path
+from typing import Any
+
+import pandas as pd
+import pytest
+import pytest_check as check
+
+from msticpy.common.cache import file
+from msticpy.datamodel.result import QueryResult
+
+
+def test_read_write_cache(tmp_path: Path, simple_dataframeresult: QueryResult) -> None:
+    """Test methods to read and write the cache."""
+    file_name: str = "digest"
+    file_path: Path = tmp_path / file_name
+
+    file.write_cache(simple_dataframeresult, file_name, str(tmp_path))
+
+    check.is_true(file_path.exists())
+    check.is_true(file_path.stat().st_size > 0)
+    check.is_true(file_path.is_file())
+
+    res: QueryResult = file.read_cache(file_name, export_folder=str(tmp_path))
+    check.is_instance(res.raw_results, pd.DataFrame)
+    check.is_false(res.raw_results.empty)
+    try:
+        check.is_true(res.raw_results.equals(simple_dataframeresult.raw_results))
+    except ValueError as exc:
+        df: pd.DataFrame = res.raw_results
+        for i in range(df.shape[0]):
+            ref: Any = df.value.iloc[i]
+            if isinstance(ref, dict):
+                check.is_instance(
+                    simple_dataframeresult.raw_results.value.iloc[i],
+                    dict,
+                )
+            else:
+                error_msg = "DataFrame comparison is only working for dict."
+                raise NotImplementedError(error_msg) from exc
+
+
+def test_read_cache_from_missing_file() -> None:
+    """Test to read cache when file does not exist."""
+    file_name: str = "does_not_exist"
+
+    with pytest.raises(FileNotFoundError):
+        file.read_cache(file_name)
+
+
+def test_write_cache_without_export_path(
+    tmp_path: Path,
+    monkeypatch: pytest.MonkeyPatch,
+    simple_dataframeresult: QueryResult,
+) -> None:
+    """Test to write cache without providing an export path."""
+    monkeypatch.chdir(tmp_path)
+    file_name: str = "digest"
+    file_path: Path = Path(file.CACHE_FOLDER_NAME) / file_name
+
+    file.write_cache(simple_dataframeresult, file_name)
+
+    check.is_true(file_path.exists())
+    check.is_true(file_path.stat().st_size > 0)
+    check.is_true(file_path.is_file())
+
+
+def test_read_write_cache_with_file(
+    tmp_path: Path,
+    monkeypatch: pytest.MonkeyPatch,
+    simple_dataframeresult: QueryResult,
+) -> None:
+    """Test to write cache when providing a file as an export path."""
+    monkeypatch.chdir(tmp_path)
+
+    export_path_file: Path = Path("random_file")
+    export_path_file.touch()
+    check.is_true(export_path_file.exists())
+    check.is_true(export_path_file.is_file())
+
+    file_name: str = "digest"
+    file_path: Path = export_path_file.parent / file.CACHE_FOLDER_NAME / file_name
+
+    file.write_cache(
+        simple_dataframeresult,
+        file_name,
+        export_folder=str(export_path_file),
+    )
+
+    check.is_true(file_path.exists())
+    check.is_true(file_path.stat().st_size > 0)
+    check.is_true(file_path.is_file())
+
+    res: QueryResult = file.read_cache(
+        file_name,
+        export_folder=str(export_path_file),
+    )
+    check.equal(simple_dataframeresult, res)
diff --git a/tests/common/cache/test_init.py b/tests/common/cache/test_init.py
new file mode 100644
index 000000000..ee47c0dc2
--- /dev/null
+++ b/tests/common/cache/test_init.py
@@ -0,0 +1,167 @@
+"""Testing generic cache methods."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+from unittest.mock import patch
+
+import pytest
+import pytest_check as check
+
+from msticpy.common import cache
+from msticpy.common.cache import cell, file, read_cache, write_cache
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    import pandas as pd
+
+    from msticpy.datamodel.result import QueryResult
+
+
+def test_write_cache_cell(tmp_path: Path, simple_dataframe: pd.DataFrame) -> None:
+    """Test method to write cache to a cell."""
+    params: dict[str, Any] = {"key": "digest"}
+
+    with patch.object(
+        cell,
+        "write_cache",
+        return_value=None,
+    ) as mocked_cache, patch.object(
+        cache,
+        "is_ipython",
+        return_value=True,
+    ) as mocked_ipython:
+        # "Real" tests are managed in the cell.write_cache test
+        write_cache(
+            simple_dataframe,
+            params,
+            str(tmp_path),
+            "name",
+            display=True,
+        )
+
+    check.is_true(mocked_ipython.called)
+    check.equal(mocked_ipython.call_count, 1)
+
+    check.is_true(mocked_cache.called)
+    check.equal(mocked_cache.call_count, 1)
+
+
+def test_write_cache_file(tmp_path: Path, simple_dataframe: pd.DataFrame) -> None:
+    """Test method to write cache to a file."""
+    params: dict[str, Any] = {"key": "digest"}
+
+    with patch.object(cell, "write_cache", return_value=None):
+        # "Real" tests are managed in the file.write_cache test
+        write_cache(
+            data=simple_dataframe,
+            search_params=params,
+            query="query",
+            cache_path=str(tmp_path),
+            name="name",
+        )
+
+
+def test_read_cache_no_path() -> None:
+    """Test method to read cache without an export path."""
+    params: dict[str, Any] = {"key": "digest"}
+
+    with pytest.raises(ValueError, match="Cache not provided."):
+        read_cache(params, "", "name")
+
+
+def test_read_cache_cell(
+    tmp_path: Path,
+    simple_dataframeresult: QueryResult,
+) -> None:
+    """Test method to read cache from a cell."""
+    params: dict[str, Any] = {"key": "digest"}
+
+    with patch.object(
+        cell,
+        "read_cache",
+        return_value=simple_dataframeresult,
+    ) as mocked_read_cache, patch.object(
+        cache,
+        "is_ipython",
+        return_value=True,
+    ) as mocked_ipython:
+        # "Real" tests are managed in the cell.read_cache test
+        read_cache(search_params=params, cache_path=str(tmp_path), name="name")
+
+    check.is_true(mocked_ipython.called)
+    check.equal(mocked_ipython.call_count, 1)
+
+    check.is_true(mocked_read_cache.called)
+    check.equal(mocked_read_cache.call_count, 1)
+
+
+def test_read_cache_cell_cache_not_found(
+    tmp_path: Path,
+    simple_dataframeresult: QueryResult,
+) -> None:
+    """Test method to read cache from a cell when the cell cache is missing."""
+    params: dict[str, Any] = {"key": "digest"}
+
+    with patch.object(
+        cell,
+        "read_cache",
+        side_effect=ValueError,
+    ) as mocked_read_cache_cell, patch.object(
+        file,
+        "read_cache",
+        return_value=simple_dataframeresult,
+    ) as mocked_read_cache_file, patch.object(
+        cache,
+        "is_ipython",
+        return_value=True,
+    ) as mocked_ipython, patch.object(
+        cell,
+        "write_cache",
+    ) as mocked_write_cache_cell:
+        # "Real" tests are managed in the cell.read_cache test
+        read_cache(search_params=params, cache_path=str(tmp_path), name="name")
+
+    check.is_true(mocked_ipython.called)
+    check.equal(mocked_ipython.call_count, 2)
+
+    check.is_true(mocked_read_cache_cell.called)
+    check.equal(mocked_read_cache_cell.call_count, 1)
+
+    # When reading from a cell, if the content is not found,
+    # a failover to a file is attempted
+    check.is_true(mocked_read_cache_file.called)
+    check.equal(mocked_read_cache_file.call_count, 1)
+
+    # Additionally, the cell cache must be re-written
+    check.is_true(mocked_write_cache_cell.called)
+    check.equal(mocked_write_cache_cell.call_count, 1)
+
+
+def test_read_cache_file(
+    tmp_path: Path,
+    simple_dataframeresult: QueryResult,
+) -> None:
+    """Test method to read cache from a file."""
+    params: dict[str, Any] = {"key": "digest"}
+
+    with patch.object(
+        file, "read_cache", return_value=simple_dataframeresult
+    ), patch.object(
+        cache,
+        "is_ipython",
+        return_value=False,
+    ):
+        # "Real" tests are managed in the file.read_cache test
+        read_cache(search_params=params, cache_path=str(tmp_path), name="name")
+
+
+def test_read_cache_file_not_exist(tmp_path: Path) -> None:
+    """Test method to read cache from a non-existing file."""
+    params: dict[str, Any] = {"key": "digest"}
+
+    with patch.object(file, "read_cache", side_effect=FileNotFoundError), patch.object(
+        cache,
+        "is_ipython",
+        return_value=False,
+    ), pytest.raises(ValueError, match="Could not read from cache."):
+        read_cache(params, str(tmp_path), "name")
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 000000000..d7385f19f
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,7 @@
+"""Pytest configuration for tests module."""
+
+from .fixtures import (  # noqa: F401 # pylint: disable=W0611
+    generate_sample_data,
+    generate_simple_dataframe,
+    generate_simple_dataframeresult,
+)
diff --git a/tests/datamodel/test_result.py b/tests/datamodel/test_result.py
new file mode 100644
index 000000000..5abbf570a
--- /dev/null
+++ b/tests/datamodel/test_result.py
@@ -0,0 +1,59 @@
+"""Tests for Query Result datamodel."""
+
+import pandas as pd
+import pytest_check as check
+
+from msticpy.datamodel.result import QueryResult
+
+
+def test_normalizer(simple_dataframeresult: QueryResult) -> None:
+    """Test attribute normalizer from QueryResult."""
+    check.is_instance(simple_dataframeresult.normalizer, str)
+    check.equal(simple_dataframeresult.normalizer, "QueryResult")
+
+
+def test_total_results(simple_dataframeresult: QueryResult) -> None:
+    """Test attribute total_results from QueryResult."""
+    check.is_instance(simple_dataframeresult.total_results, int)
+    check.greater_equal(simple_dataframeresult.total_results, 0)
+
+
+def test_results(simple_dataframeresult: QueryResult) -> None:
+    """Test attribute results from QueryResult."""
+    check.is_instance(simple_dataframeresult.results, list)
+    check.greater_equal(len(simple_dataframeresult.results), 0)
+
+
+def test__repr_markdown_(simple_dataframeresult: QueryResult) -> None:
+    """Test method _repr_markdown_ from QueryResult."""
+    res: str = (
+        QueryResult._repr_markdown_(  # noqa: SLF001 #pylint: disable=protected-access
+            simple_dataframeresult
+        )
+    )
+    check.is_instance(res, str)
+    check.greater_equal(len(res), 0)
+
+
+def test__repr_html_(simple_dataframeresult: QueryResult) -> None:
+    """Test method _repr_html_ from QueryResult."""
+    res: str = (
+        QueryResult._repr_html_(  # noqa: SLF001 #pylint: disable=protected-access
+            simple_dataframeresult
+        )
+    )
+    check.is_instance(res, str)
+    check.greater_equal(len(res), 0)
+
+
+def test__eq_(simple_dataframeresult: QueryResult) -> None:
+    """Test method __eq__ from QueryResult."""
+    check.equal(simple_dataframeresult, simple_dataframeresult)
+    other_sample: QueryResult = QueryResult(
+        name="my name",
+        query="my query",
+        raw_results=pd.DataFrame(),
+        arguments={},
+    )
+    check.not_equal(simple_dataframeresult, other_sample)
+    check.not_equal(simple_dataframeresult, 42)
diff --git a/tests/fixtures.py b/tests/fixtures.py
new file mode 100644
index 000000000..f5788001b
--- /dev/null
+++ b/tests/fixtures.py
@@ -0,0 +1,58 @@
+"""Fixtures for testing msticpy."""
+
+import datetime as dt
+from typing import Union
+
+import pandas as pd
+import pytest
+
+from msticpy.datamodel.result import QueryResult
+
+
+@pytest.fixture(
+    params=[
+        "",
+        "TeSt",
+        "test",
+        42,
+        42.42,
+        ["A", "B", "c"],
+        [1, 2, 3],
+        {
+            "key_str": "Value",
+            "key_int": 42,
+            "key_list": ["A", "B", "c"],
+            "key_dict": {"A": 33, "B": "C"},
+        },
+        dt.datetime.now(tz=dt.timezone.utc),
+    ],
+    name="sample_data",
+)
+def generate_sample_data(
+    request: pytest.FixtureRequest,
+) -> Union[str, int, float, list, dict, dt.datetime]:
+    """Return sample data for pattern matching."""
+    return request.param
+
+
+@pytest.fixture(name="simple_dataframe")
+def generate_simple_dataframe(
+    sample_data: Union[str, int, float, list, dict, dt.datetime],
+) -> pd.DataFrame:
+    """Sample dataframe to test get_raw_data."""
+    return pd.DataFrame(
+        [
+            {"key": "A", "value": sample_data},
+        ],
+    )
+
+
+@pytest.fixture(name="simple_dataframeresult")
+def generate_simple_dataframeresult(simple_dataframe: pd.DataFrame) -> QueryResult:
+    """Sample QueryResult object to test get_raw_data."""
+    return QueryResult(
+        name="name",
+        query="no query",
+        raw_results=simple_dataframe,
+        arguments={},
+    )
diff --git a/tests/test_pkg_imports.py b/tests/test_pkg_imports.py
index fc794e012..7b4fe87cd 100644
--- a/tests/test_pkg_imports.py
+++ b/tests/test_pkg_imports.py
@@ -36,6 +36,7 @@
     "kqlmagiccustom",
     "sumologic-sdk",
    "openpyxl",
+    "compress-pickle",
 }