Skip to content

Commit

Permalink
implemented:
Browse files Browse the repository at this point in the history
- update table (from dict or query)
- insert to table (from query or dataframe)
- delete from table
  • Loading branch information
ea-rus committed Sep 11, 2023
1 parent 2241a35 commit 0c57059
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 88 deletions.
3 changes: 2 additions & 1 deletion mindsdb_sdk/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from mindsdb_sql.parser.ast import Identifier, DropTables

from mindsdb_sdk.query import Query, Table
from .query import Query
from .table import Table
from .objects_collection import MethodCollection

class Database:
Expand Down
7 changes: 4 additions & 3 deletions mindsdb_sdk/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from mindsdb_sql.parser.dialects.mindsdb import CreatePredictor, CreateView, DropPredictor, CreateJob, DropJob
from mindsdb_sql.parser.ast import DropView, Identifier, Delete, Star, Select

from mindsdb_sdk.utils import dict_to_binary_op
from mindsdb_sdk.model import Model, ModelVersion
from mindsdb_sdk.query import Query, View
from .utils import dict_to_binary_op
from .model import Model, ModelVersion
from .query import Query
from .table import View
from .ml_engine import MLEngine
from .objects_collection import MethodCollection

Expand Down
81 changes: 0 additions & 81 deletions mindsdb_sdk/query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import copy

import pandas as pd

from mindsdb_sql.parser.ast import Select, Star, Identifier, Constant

from mindsdb_sdk.utils import dict_to_binary_op


class Query:
def __init__(self, api, sql, database=None):
Expand All @@ -28,79 +23,3 @@ def fetch(self) -> pd.DataFrame:
"""
return self.api.sql_query(self.sql, self.database)


class Table(Query):
def __init__(self, db, name):
super().__init__(db.api, '', db.name)
self.name = name
self.db = db
self._filters = {}
self._limit = None
self._update_query()

def _filters_repr(self):
filters = ''
if len(filters) > 0:
filters = ', '.join(
f'{k}={v}'
for k, v in self._filters
)
filters = ', ' + filters
return filters

def __repr__(self):
return f'{self.__class__.__name__}({self.name}{self._filters_repr()})'

def filter(self, **kwargs):
"""
Applies filters on table
table.filter(a=1, b=2) adds where condition to table:
'select * from table1 where a=1 and b=2'
:param kwargs: filter
"""
# creates new object
query = copy.deepcopy(self)
query._filters.update(kwargs)
query._update_query()
return query

def limit(self, val: int):
"""
Applies limit condition to table query
:param val: limit size
"""
query = copy.deepcopy(self)
query._limit = val
query._update_query()
return query

def _update_query(self):
ast_query = Select(
targets=[Star()],
from_table=Identifier(self.name),
where=dict_to_binary_op(self._filters)
)
if self._limit is not None:
ast_query.limit = Constant(self._limit)
self.sql = ast_query.to_string()


class View(Table):
# The same as table
pass

# TODO getting view sql from api not implemented yet
# class View(Table):
# def __init__(self, api, data, project):
# super().__init__(api, data['name'], project)
# self.view_sql = data['sql']
#
# def __repr__(self):
# #
# sql = self.view_sql.replace('\n', ' ')
# if len(sql) > 40:
# sql = sql[:37] + '...'
#
# return f'{self.__class__.__name__}({self.name}{self._filters_repr()}, sql={sql})'
185 changes: 185 additions & 0 deletions mindsdb_sdk/table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import copy
from typing import Union

import pandas as pd

from mindsdb_sql.parser.ast import Select, Star, Identifier, Constant, Delete, Insert, Update

from mindsdb_sdk.utils import dict_to_binary_op

from .query import Query

class Table(Query):
def __init__(self, db, name):
super().__init__(db.api, '', db.name)
self.name = name
self.db = db
self._filters = {}
self._limit = None
self._update_query()

def _filters_repr(self):
filters = ''
if len(self._filters) > 0:
filters = ', '.join(
f'{k}={v}'
for k, v in self._filters.items()
)
filters = ', ' + filters
return filters

def __repr__(self):
limit_str = ''
if self._limit is not None:
limit_str = f'; limit={self._limit}'
return f'{self.__class__.__name__}({self.name}{self._filters_repr()}{limit_str})'

def filter(self, **kwargs):
"""
Applies filters on table
table.filter(a=1, b=2) adds where condition to table:
'select * from table1 where a=1 and b=2'
:param kwargs: filter
"""
# creates new object
query = copy.deepcopy(self)
query._filters.update(kwargs)
query._update_query()
return query

def limit(self, val: int):
"""
Applies limit condition to table query
:param val: limit size
"""
query = copy.deepcopy(self)
query._limit = val
query._update_query()
return query

def _update_query(self):
ast_query = Select(
targets=[Star()],
from_table=Identifier(self.name),
where=dict_to_binary_op(self._filters)
)
if self._limit is not None:
ast_query.limit = Constant(self._limit)
self.sql = ast_query.to_string()

def insert(self, query: Union[pd.DataFrame, Query]):
"""
Insert data from query of dataframe
:param query: dataframe of
:return:
"""

if isinstance(query, pd.DataFrame):
# insert data
data_split = query.to_dict('split')

ast_query = Insert(
table=Identifier(self.name),
columns=data_split['columns'],
values=data_split['data']
)

sql = ast_query.to_string()
self.api.sql_query(sql, self.database)
else:
# insert from select
table = Identifier(parts=[self.database, self.name])
self.api.sql_query(
f'INSERT INTO {table.to_string()} ({query.sql})',
database=query.database
)

def delete(self, **kwargs):
"""
Deletes record from table using filters table.delete(a=1, b=2)
:param kwargs: filter
"""
identifier = Identifier(self.name)
# add database
identifier.parts.insert(0, self.database)

ast_query = Delete(
table=identifier,
where=dict_to_binary_op(kwargs)
)

sql = ast_query.to_string()
self.api.sql_query(sql, 'mindsdb')

def update(self, values: Union[dict, Query], on: list = None, filters: dict = None):
'''
Update table by condition of from other table.
If 'values' is a dict:
- it will be an update by condition
- 'filters' is required
- used command: update table set a=1 where x=1
If 'values' is a Query:
- it will be an update from select
- 'on' is required
- used command: update table on a,b from (query)
:param values: input for update, can be dict or query
:param on: list of column to map subselect to table ['a', 'b', ...]
:param filters: dict to filter updated rows, {'column': 'value', ...}
'''

if isinstance(values, Query):
# is update from select
if on is None:
raise ValueError('"on" parameter is required for update from query')

# insert from select
table = Identifier(parts=[self.database, self.name])
map_cols = ', '.join(on)
self.api.sql_query(
f'UPDATE {table.to_string()} ON {map_cols} FROM ({values.sql})',
database=values.database
)
elif isinstance(values, dict):
# is regular update
if filters is None:
raise ValueError('"filters" parameter is required for update')

update_columns = {
k: Constant(v)
for k, v in values.items()
}

ast_query = Update(
table=Identifier(self.name),
update_columns=update_columns,
where=dict_to_binary_op(filters)
)

sql = ast_query.to_string()
self.api.sql_query(sql, self.database)
else:
raise NotImplementedError


class View(Table):
# The same as table
pass

# TODO getting view sql from api not implemented yet
# class View(Table):
# def __init__(self, api, data, project):
# super().__init__(api, data['name'], project)
# self.view_sql = data['sql']
#
# def __repr__(self):
# #
# sql = self.view_sql.replace('\n', ' ')
# if len(sql) > 40:
# sql = sql[:37] + '...'
#
# return f'{self.__class__.__name__}({self.name}{self._filters_repr()}, sql={sql})'
26 changes: 23 additions & 3 deletions tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def check_table(self, table, mock_post):
table = table.filter(a=3, b='2')
table = table.limit(3)
table.fetch()

str(table)
check_sql_call(mock_post, f'SELECT * FROM {table.name} WHERE (a = 3) AND (b = \'2\') LIMIT 3')


Expand Down Expand Up @@ -801,7 +801,6 @@ def check_project_models(self, project, database, mock_post):
}
)


@patch('requests.Session.post')
def check_project_models_versions(self, project, database, mock_post):
# ----------- model version --------------
Expand Down Expand Up @@ -867,6 +866,28 @@ def check_database(self, database, mock_post):
assert table2.name == 't2'
self.check_table(table2)

# -- insert into table --
# from dataframe
table2.insert(pd.DataFrame([{'s': '1', 'x': 1}, {'s': 'a', 'x': 2}]))
check_sql_call(mock_post, "INSERT INTO t2(s, x) VALUES ('1', 1), ('a', 2)")

# from query
table2.insert(query)
check_sql_call(mock_post, f"INSERT INTO {database.name}.t2 (select * from tbl1)")

# -- delete in table --
table2.delete(a=1, b='2')
check_sql_call(mock_post, f"DELETE FROM {database.name}.t2 WHERE (a = 1) AND (b = '2')")

# -- update table --
# from query
table2.update(query, on=['a', 'b'])
check_sql_call(mock_post, f"UPDATE {database.name}.t2 ON a, b FROM (select * from tbl1)")

# from dict
table2.update({'a': '1', 'b': 1}, filters={'x': 3})
check_sql_call(mock_post, f"UPDATE t2 SET a='1', b=1 WHERE x=3")

# create from table
table1 = database.tables.t1
table1 = table1.filter(b=2)
Expand All @@ -880,7 +901,6 @@ def check_database(self, database, mock_post):
database.tables.drop('t3')
check_sql_call(mock_post, f'drop table t3')


@patch('requests.Session.post')
def check_project_jobs(self, project, mock_post):

Expand Down

0 comments on commit 0c57059

Please sign in to comment.