From 0c570591b7fbac0f437bc3c6437f6da7f5beee07 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 11 Sep 2023 15:50:04 +0300 Subject: [PATCH] implemented: - update table (from dict or query) - insert to table (from query or dataframe) - delete from table --- mindsdb_sdk/database.py | 3 +- mindsdb_sdk/project.py | 7 +- mindsdb_sdk/query.py | 81 ------------------ mindsdb_sdk/table.py | 185 ++++++++++++++++++++++++++++++++++++++++ tests/test_sdk.py | 26 +++++- 5 files changed, 214 insertions(+), 88 deletions(-) create mode 100644 mindsdb_sdk/table.py diff --git a/mindsdb_sdk/database.py b/mindsdb_sdk/database.py index 00c5b65..7c2b744 100644 --- a/mindsdb_sdk/database.py +++ b/mindsdb_sdk/database.py @@ -4,7 +4,8 @@ from mindsdb_sql.parser.ast import Identifier, DropTables -from mindsdb_sdk.query import Query, Table +from .query import Query +from .table import Table from .objects_collection import MethodCollection class Database: diff --git a/mindsdb_sdk/project.py b/mindsdb_sdk/project.py index 123627a..7265bec 100644 --- a/mindsdb_sdk/project.py +++ b/mindsdb_sdk/project.py @@ -6,9 +6,10 @@ from mindsdb_sql.parser.dialects.mindsdb import CreatePredictor, CreateView, DropPredictor, CreateJob, DropJob from mindsdb_sql.parser.ast import DropView, Identifier, Delete, Star, Select -from mindsdb_sdk.utils import dict_to_binary_op -from mindsdb_sdk.model import Model, ModelVersion -from mindsdb_sdk.query import Query, View +from .utils import dict_to_binary_op +from .model import Model, ModelVersion +from .query import Query +from .table import View from .ml_engine import MLEngine from .objects_collection import MethodCollection diff --git a/mindsdb_sdk/query.py b/mindsdb_sdk/query.py index eea142f..2067cba 100644 --- a/mindsdb_sdk/query.py +++ b/mindsdb_sdk/query.py @@ -1,11 +1,6 @@ -import copy import pandas as pd -from mindsdb_sql.parser.ast import Select, Star, Identifier, Constant - -from mindsdb_sdk.utils import dict_to_binary_op - class Query: def __init__(self, api, 
sql, database=None): @@ -28,79 +23,3 @@ def fetch(self) -> pd.DataFrame: """ return self.api.sql_query(self.sql, self.database) - -class Table(Query): - def __init__(self, db, name): - super().__init__(db.api, '', db.name) - self.name = name - self.db = db - self._filters = {} - self._limit = None - self._update_query() - - def _filters_repr(self): - filters = '' - if len(filters) > 0: - filters = ', '.join( - f'{k}={v}' - for k, v in self._filters - ) - filters = ', ' + filters - return filters - - def __repr__(self): - return f'{self.__class__.__name__}({self.name}{self._filters_repr()})' - - def filter(self, **kwargs): - """ - Applies filters on table - table.filter(a=1, b=2) adds where condition to table: - 'select * from table1 where a=1 and b=2' - - :param kwargs: filter - """ - # creates new object - query = copy.deepcopy(self) - query._filters.update(kwargs) - query._update_query() - return query - - def limit(self, val: int): - """ - Applies limit condition to table query - - :param val: limit size - """ - query = copy.deepcopy(self) - query._limit = val - query._update_query() - return query - - def _update_query(self): - ast_query = Select( - targets=[Star()], - from_table=Identifier(self.name), - where=dict_to_binary_op(self._filters) - ) - if self._limit is not None: - ast_query.limit = Constant(self._limit) - self.sql = ast_query.to_string() - - -class View(Table): - # The same as table - pass - -# TODO getting view sql from api not implemented yet -# class View(Table): -# def __init__(self, api, data, project): -# super().__init__(api, data['name'], project) -# self.view_sql = data['sql'] -# -# def __repr__(self): -# # -# sql = self.view_sql.replace('\n', ' ') -# if len(sql) > 40: -# sql = sql[:37] + '...' 
import copy
from typing import Union

import pandas as pd

from mindsdb_sql.parser.ast import Select, Star, Identifier, Constant, Delete, Insert, Update

from mindsdb_sdk.utils import dict_to_binary_op

from .query import Query


class Table(Query):
    """
    Query over a single table that can also insert, update and delete rows.

    The select statement is kept in ``self.sql`` and is regenerated every
    time filters or the limit change. ``filter`` and ``limit`` do not modify
    the object they are called on; they return an adjusted copy.
    """

    def __init__(self, db, name):
        """
        :param db: database object; its ``api`` and ``name`` attributes are used
        :param name: name of the table
        """
        # sql is passed as an empty string because _update_query() fills it in below
        super().__init__(db.api, '', db.name)
        self.name = name
        self.db = db
        self._filters = {}
        self._limit = None
        self._update_query()

    def _filters_repr(self):
        """
        Render filters for __repr__ as ', k1=v1, k2=v2'; empty string if there are no filters.
        """
        if not self._filters:
            return ''
        pairs = ', '.join(
            f'{k}={v}'
            for k, v in self._filters.items()
        )
        return ', ' + pairs

    def __repr__(self):
        limit_str = ''
        if self._limit is not None:
            limit_str = f'; limit={self._limit}'
        return f'{self.__class__.__name__}({self.name}{self._filters_repr()}{limit_str})'

    def _clone(self):
        """
        Copy of this object for the chaining methods.

        The copy is shallow: the api and db objects are shared with the
        original instead of being duplicated on every filter()/limit() call.
        Only the filters dict is duplicated, so changing filters of the copy
        does not affect the original.
        """
        clone = copy.copy(self)
        clone._filters = dict(self._filters)
        return clone

    def filter(self, **kwargs) -> 'Table':
        """
        Applies filters on table
        table.filter(a=1, b=2) adds where condition to table:
        'select * from table1 where a=1 and b=2'

        Filters accumulate over chained calls; a value passed for a column
        that is already filtered replaces the previous value.

        :param kwargs: filter
        :return: new object with the filters applied, the original is not changed
        """
        query = self._clone()
        query._filters.update(kwargs)
        query._update_query()
        return query

    def limit(self, val: int) -> 'Table':
        """
        Applies limit condition to table query

        :param val: limit size
        :return: new object with the limit applied, the original is not changed
        """
        query = self._clone()
        query._limit = val
        query._update_query()
        return query

    def _update_query(self):
        """
        Rebuild self.sql (select * from the table) from current filters and limit.
        """
        ast_query = Select(
            targets=[Star()],
            from_table=Identifier(self.name),
            where=dict_to_binary_op(self._filters)
        )
        if self._limit is not None:
            ast_query.limit = Constant(self._limit)
        self.sql = ast_query.to_string()

    def insert(self, query: Union[pd.DataFrame, Query]):
        """
        Insert data from query or dataframe

        If 'query' is a dataframe:
          - its columns and rows are sent as 'insert into table (columns) values ...'
          - the command is executed in the database of the table
          - a dataframe without rows is a no-op, nothing is sent

        If 'query' is a Query:
          - used command: insert into database.table (query)
          - the command is executed in the database of the query

        :param query: dataframe or query to take the rows from
        :raises NotImplementedError: if 'query' is neither a dataframe nor a Query
        """

        if isinstance(query, pd.DataFrame):
            if len(query.index) == 0:
                # there are no rows to put into VALUES, so no statement is sent
                return

            # insert data
            data_split = query.to_dict('split')

            ast_query = Insert(
                table=Identifier(self.name),
                columns=data_split['columns'],
                values=data_split['data']
            )

            sql = ast_query.to_string()
            self.api.sql_query(sql, self.database)
        elif isinstance(query, Query):
            # insert from select
            # table name is qualified with its database because the statement
            # is executed in the database of the source query
            table = Identifier(parts=[self.database, self.name])
            self.api.sql_query(
                f'INSERT INTO {table.to_string()} ({query.sql})',
                database=query.database
            )
        else:
            raise NotImplementedError(
                f'insert is supported from DataFrame or Query, got {type(query).__name__}'
            )

    def delete(self, **kwargs):
        """
        Deletes record from table using filters table.delete(a=1, b=2)

        :param kwargs: filter
        """
        identifier = Identifier(self.name)
        # add database
        identifier.parts.insert(0, self.database)

        ast_query = Delete(
            table=identifier,
            where=dict_to_binary_op(kwargs)
        )

        sql = ast_query.to_string()
        # NOTE(review): executed with 'mindsdb' as database, not self.database;
        # presumably fine because the table name is fully qualified above - confirm
        self.api.sql_query(sql, 'mindsdb')

    def update(self, values: Union[dict, Query], on: Union[list, str, None] = None, filters: dict = None):
        '''
        Update table by condition or from other table.
        If 'values' is a dict:
          - it will be an update by condition
          - 'filters' is required
          - used command: update table set a=1 where x=1

        If 'values' is a Query:
          - it will be an update from select
          - 'on' is required
          - used command: update table on a,b from (query)

        :param values: input for update, can be dict or query
        :param on: list of column to map subselect to table ['a', 'b', ...];
                   a single column can be passed as a string
        :param filters: dict to filter updated rows, {'column': 'value', ...}
        :raises ValueError: if the parameter required for the chosen mode is missing or empty
        :raises NotImplementedError: if 'values' is neither a dict nor a Query
        '''

        if isinstance(values, Query):
            # is update from select
            if isinstance(on, str):
                # a string is one column name, joining it would split it into characters
                on = [on]
            if not on or not all(on):
                raise ValueError('"on" parameter is required for update from query')

            # table name is qualified with its database because the statement
            # is executed in the database of the source query
            table = Identifier(parts=[self.database, self.name])
            map_cols = ', '.join(on)
            self.api.sql_query(
                f'UPDATE {table.to_string()} ON {map_cols} FROM ({values.sql})',
                database=values.database
            )
        elif isinstance(values, dict):
            # is regular update
            if filters is None:
                raise ValueError('"filters" parameter is required for update')

            update_columns = {
                k: Constant(v)
                for k, v in values.items()
            }

            ast_query = Update(
                table=Identifier(self.name),
                update_columns=update_columns,
                where=dict_to_binary_op(filters)
            )

            sql = ast_query.to_string()
            self.api.sql_query(sql, self.database)
        else:
            raise NotImplementedError(
                f'update is supported from dict or Query, got {type(values).__name__}'
            )


class View(Table):
    """
    The same as table: no behaviour is added or overridden.
    """
    pass

# TODO getting view sql from api not implemented yet
# class View(Table):
#     def __init__(self, api, data, project):
#         super().__init__(api, data['name'], project)
#         self.view_sql = data['sql']
#
#     def __repr__(self):
#         #
#         sql = self.view_sql.replace('\n', ' ')
#         if len(sql) > 40:
#             sql = sql[:37] + '...'
+# +# return f'{self.__class__.__name__}({self.name}{self._filters_repr()}, sql={sql})' diff --git a/tests/test_sdk.py b/tests/test_sdk.py index f60f300..49edde2 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -172,7 +172,7 @@ def check_table(self, table, mock_post): table = table.filter(a=3, b='2') table = table.limit(3) table.fetch() - + str(table) check_sql_call(mock_post, f'SELECT * FROM {table.name} WHERE (a = 3) AND (b = \'2\') LIMIT 3') @@ -801,7 +801,6 @@ def check_project_models(self, project, database, mock_post): } ) - @patch('requests.Session.post') def check_project_models_versions(self, project, database, mock_post): # ----------- model version -------------- @@ -867,6 +866,28 @@ def check_database(self, database, mock_post): assert table2.name == 't2' self.check_table(table2) + # -- insert into table -- + # from dataframe + table2.insert(pd.DataFrame([{'s': '1', 'x': 1}, {'s': 'a', 'x': 2}])) + check_sql_call(mock_post, "INSERT INTO t2(s, x) VALUES ('1', 1), ('a', 2)") + + # from query + table2.insert(query) + check_sql_call(mock_post, f"INSERT INTO {database.name}.t2 (select * from tbl1)") + + # -- delete in table -- + table2.delete(a=1, b='2') + check_sql_call(mock_post, f"DELETE FROM {database.name}.t2 WHERE (a = 1) AND (b = '2')") + + # -- update table -- + # from query + table2.update(query, on=['a', 'b']) + check_sql_call(mock_post, f"UPDATE {database.name}.t2 ON a, b FROM (select * from tbl1)") + + # from dict + table2.update({'a': '1', 'b': 1}, filters={'x': 3}) + check_sql_call(mock_post, f"UPDATE t2 SET a='1', b=1 WHERE x=3") + # create from table table1 = database.tables.t1 table1 = table1.filter(b=2) @@ -880,7 +901,6 @@ def check_database(self, database, mock_post): database.tables.drop('t3') check_sql_call(mock_post, f'drop table t3') - @patch('requests.Session.post') def check_project_jobs(self, project, mock_post):