Skip to content

Commit 3eec945

Browse files
FeiueFeiueKevinHuSh
authored
complete implementation of dataset SDK (infiniflow#2147)
### What problem does this PR solve? Complete implementation of dataset SDK. infiniflow#1102 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Feiue <[email protected]> Co-authored-by: Kevin Hu <[email protected]>
1 parent 7d44054 commit 3eec945

File tree

6 files changed

+259
-85
lines changed

6 files changed

+259
-85
lines changed

api/apps/sdk/dataset.py

+118-44
Original file line numberDiff line numberDiff line change
@@ -15,82 +15,156 @@
1515
#
1616
from flask import request
1717

18-
from api.db import StatusEnum
19-
from api.db.db_models import APIToken
18+
from api.db import StatusEnum, FileSource
19+
from api.db.db_models import File
20+
from api.db.services.document_service import DocumentService
21+
from api.db.services.file2document_service import File2DocumentService
22+
from api.db.services.file_service import FileService
2023
from api.db.services.knowledgebase_service import KnowledgebaseService
2124
from api.db.services.user_service import TenantService
2225
from api.settings import RetCode
2326
from api.utils import get_uuid
24-
from api.utils.api_utils import get_data_error_result
25-
from api.utils.api_utils import get_json_result
27+
from api.utils.api_utils import get_json_result, token_required, get_data_error_result
2628

2729

2830
@manager.route('/save', methods=['POST'])
29-
def save():
31+
@token_required
32+
def save(tenant_id):
3033
req = request.json
31-
token = request.headers.get('Authorization').split()[1]
32-
objs = APIToken.query(token=token)
33-
if not objs:
34-
return get_json_result(
35-
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
36-
tenant_id = objs[0].tenant_id
3734
e, t = TenantService.get_by_id(tenant_id)
38-
if not e:
39-
return get_data_error_result(retmsg="Tenant not found.")
4035
if "id" not in req:
36+
if "tenant_id" in req or "embd_id" in req:
37+
return get_data_error_result(
38+
retmsg="Tenant_id or embedding_model must not be provided")
39+
if "name" not in req:
40+
return get_data_error_result(
41+
retmsg="Name is not empty!")
4142
req['id'] = get_uuid()
4243
req["name"] = req["name"].strip()
4344
if req["name"] == "":
4445
return get_data_error_result(
45-
retmsg="Name is not empty")
46-
if KnowledgebaseService.query(name=req["name"]):
46+
retmsg="Name is not empty string!")
47+
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
4748
return get_data_error_result(
48-
retmsg="Duplicated knowledgebase name")
49+
retmsg="Duplicated knowledgebase name in creating dataset.")
4950
req["tenant_id"] = tenant_id
5051
req['created_by'] = tenant_id
5152
req['embd_id'] = t.embd_id
5253
if not KnowledgebaseService.save(**req):
53-
return get_data_error_result(retmsg="Data saving error")
54-
req.pop('created_by')
55-
keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
56-
'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
57-
for old_key,new_key in keys_to_rename.items():
58-
if old_key in req:
59-
req[new_key]=req.pop(old_key)
54+
return get_data_error_result(retmsg="Create dataset error.(Database error)")
6055
return get_json_result(data=req)
6156
else:
62-
if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id:
63-
return get_data_error_result(
64-
retmsg="Can't change tenant_id or embedding_model")
57+
if "tenant_id" in req:
58+
if req["tenant_id"] != tenant_id:
59+
return get_data_error_result(
60+
retmsg="Can't change tenant_id.")
6561

66-
e, kb = KnowledgebaseService.get_by_id(req["id"])
67-
if not e:
68-
return get_data_error_result(
69-
retmsg="Can't find this knowledgebase!")
62+
if "embd_id" in req:
63+
if req["embd_id"] != t.embd_id:
64+
return get_data_error_result(
65+
retmsg="Can't change embedding_model.")
7066

7167
if not KnowledgebaseService.query(
7268
created_by=tenant_id, id=req["id"]):
7369
return get_json_result(
74-
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
70+
data=False, retmsg='You do not own the dataset.',
7571
retcode=RetCode.OPERATING_ERROR)
7672

77-
if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num:
78-
return get_data_error_result(
79-
retmsg="Can't change document_count or chunk_count ")
73+
e, kb = KnowledgebaseService.get_by_id(req["id"])
8074

81-
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
82-
return get_data_error_result(
83-
retmsg="if chunk count is not 0, parser method is not changable. ")
75+
if "chunk_num" in req:
76+
if req["chunk_num"] != kb.chunk_num:
77+
return get_data_error_result(
78+
retmsg="Can't change chunk_count.")
8479

80+
if "doc_num" in req:
81+
if req['doc_num'] != kb.doc_num:
82+
return get_data_error_result(
83+
retmsg="Can't change document_count.")
8584

86-
if req["name"].lower() != kb.name.lower() \
87-
and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'],
88-
status=StatusEnum.VALID.value)) > 0:
89-
return get_data_error_result(
90-
retmsg="Duplicated knowledgebase name.")
85+
if "parser_id" in req:
86+
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
87+
return get_data_error_result(
88+
retmsg="if chunk count is not 0, parse method is not changable.")
89+
if "name" in req:
90+
if req["name"].lower() != kb.name.lower() \
91+
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
92+
status=StatusEnum.VALID.value)) > 0:
93+
return get_data_error_result(
94+
retmsg="Duplicated knowledgebase name in updating dataset.")
9195

9296
del req["id"]
93-
req['created_by'] = tenant_id
9497
if not KnowledgebaseService.update_by_id(kb.id, req):
95-
return get_data_error_result(retmsg="Data update error ")
98+
return get_data_error_result(retmsg="Update dataset error.(Database error)")
9699
return get_json_result(data=True)
100+
101+
102+
@manager.route('/delete', methods=['DELETE'])
103+
@token_required
104+
def delete(tenant_id):
105+
req = request.args
106+
kbs = KnowledgebaseService.query(
107+
created_by=tenant_id, id=req["id"])
108+
if not kbs:
109+
return get_json_result(
110+
data=False, retmsg='You do not own the dataset',
111+
retcode=RetCode.OPERATING_ERROR)
112+
113+
for doc in DocumentService.query(kb_id=req["id"]):
114+
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
115+
return get_data_error_result(
116+
retmsg="Remove document error.(Database error)")
117+
f2d = File2DocumentService.get_by_document_id(doc.id)
118+
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
119+
File2DocumentService.delete_by_document_id(doc.id)
120+
121+
if not KnowledgebaseService.delete_by_id(req["id"]):
122+
return get_data_error_result(
123+
retmsg="Delete dataset error.(Database error)")
124+
return get_json_result(data=True)
125+
126+
127+
@manager.route('/list', methods=['GET'])
128+
@token_required
129+
def list_datasets(tenant_id):
130+
page_number = int(request.args.get("page", 1))
131+
items_per_page = int(request.args.get("page_size", 1024))
132+
orderby = request.args.get("orderby", "create_time")
133+
desc = bool(request.args.get("desc", True))
134+
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
135+
kbs = KnowledgebaseService.get_by_tenant_ids(
136+
[m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc)
137+
return get_json_result(data=kbs)
138+
139+
140+
@manager.route('/detail', methods=['GET'])
141+
@token_required
142+
def detail(tenant_id):
143+
req = request.args
144+
if "id" in req:
145+
id = req["id"]
146+
kb = KnowledgebaseService.query(created_by=tenant_id, id=req["id"])
147+
if not kb:
148+
return get_json_result(
149+
data=False, retmsg='You do not own the dataset',
150+
retcode=RetCode.OPERATING_ERROR)
151+
if "name" in req:
152+
name = req["name"]
153+
if kb[0].name != name:
154+
return get_json_result(
155+
data=False, retmsg='You do not own the dataset',
156+
retcode=RetCode.OPERATING_ERROR)
157+
e, k = KnowledgebaseService.get_by_id(id)
158+
return get_json_result(data=k.to_dict())
159+
else:
160+
if "name" in req:
161+
name = req["name"]
162+
e, k = KnowledgebaseService.get_by_name(kb_name=name, tenant_id=tenant_id)
163+
if not e:
164+
return get_json_result(
165+
data=False, retmsg='You do not own the dataset',
166+
retcode=RetCode.OPERATING_ERROR)
167+
return get_json_result(data=k.to_dict())
168+
else:
169+
return get_data_error_result(
170+
retmsg="At least one of `id` or `name` must be provided.")

api/utils/api_utils.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,32 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
import functools
1617
import json
1718
import random
1819
import time
20+
from base64 import b64encode
1921
from functools import wraps
22+
from hmac import HMAC
2023
from io import BytesIO
24+
from urllib.parse import quote, urlencode
25+
from uuid import uuid1
26+
27+
import requests
2128
from flask import (
2229
Response, jsonify, send_file, make_response,
2330
request as flask_request,
2431
)
2532
from werkzeug.http import HTTP_STATUS_CODES
2633

27-
from api.utils import json_dumps
28-
from api.settings import RetCode
34+
from api.db.db_models import APIToken
2935
from api.settings import (
3036
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
3137
stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
3238
)
33-
import requests
34-
import functools
39+
from api.settings import RetCode
3540
from api.utils import CustomJSONEncoder
36-
from uuid import uuid1
37-
from base64 import b64encode
38-
from hmac import HMAC
39-
from urllib.parse import quote, urlencode
41+
from api.utils import json_dumps
4042

4143
requests.models.complexjson.dumps = functools.partial(
4244
json.dumps, cls=CustomJSONEncoder)
@@ -96,7 +98,6 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
9698

9799
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
98100
data=None, job_id=None, meta=None):
99-
import re
100101
result_dict = {
101102
"retcode": retcode,
102103
"retmsg": retmsg,
@@ -145,7 +146,8 @@ def server_error_response(e):
145146
return get_json_result(
146147
retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
147148
if repr(e).find("index_not_found_exception") >= 0:
148-
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg="No chunk found, please upload file and parse it.")
149+
return get_json_result(retcode=RetCode.EXCEPTION_ERROR,
150+
retmsg="No chunk found, please upload file and parse it.")
149151

150152
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
151153

@@ -190,7 +192,9 @@ def decorated_function(*_args, **_kwargs):
190192
return get_json_result(
191193
retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
192194
return func(*_args, **_kwargs)
195+
193196
return decorated_function
197+
194198
return wrapper
195199

196200

@@ -217,7 +221,7 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
217221

218222

219223
def construct_response(retcode=RetCode.SUCCESS,
220-
retmsg='success', data=None, auth=None):
224+
retmsg='success', data=None, auth=None):
221225
result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
222226
response_dict = {}
223227
for key, value in result_dict.items():
@@ -235,6 +239,7 @@ def construct_response(retcode=RetCode.SUCCESS,
235239
response.headers["Access-Control-Expose-Headers"] = "Authorization"
236240
return response
237241

242+
238243
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
239244
import re
240245
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
@@ -263,7 +268,23 @@ def construct_error_response(e):
263268
pass
264269
if len(e.args) > 1:
265270
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
266-
if repr(e).find("index_not_found_exception") >=0:
267-
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
271+
if repr(e).find("index_not_found_exception") >= 0:
272+
return construct_json_result(code=RetCode.EXCEPTION_ERROR,
273+
message="No chunk found, please upload file and parse it.")
268274

269275
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
276+
277+
278+
def token_required(func):
279+
@wraps(func)
280+
def decorated_function(*args, **kwargs):
281+
token = flask_request.headers.get('Authorization').split()[1]
282+
objs = APIToken.query(token=token)
283+
if not objs:
284+
return get_json_result(
285+
data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR
286+
)
287+
kwargs['tenant_id'] = objs[0].tenant_id
288+
return func(*args, **kwargs)
289+
290+
return decorated_function

sdk/python/ragflow/modules/base.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ def to_json(self):
1818
pr[name] = value
1919
return pr
2020

21-
2221
def post(self, path, param):
23-
res = self.rag.post(path,param)
22+
res = self.rag.post(path, param)
2423
return res
2524

26-
def get(self, path, params=''):
27-
res = self.rag.get(path,params)
25+
def get(self, path, params):
26+
res = self.rag.get(path, params)
2827
return res
2928

29+
def rm(self, path, params):
30+
res = self.rag.delete(path, params)
31+
return res
3032

33+
def __str__(self):
34+
return str(self.to_json())

sdk/python/ragflow/modules/dataset.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,36 @@ def __init__(self, rag, res_dict):
2121
self.permission = "me"
2222
self.document_count = 0
2323
self.chunk_count = 0
24-
self.parser_method = "naive"
24+
self.parse_method = "naive"
2525
self.parser_config = None
26+
for k in list(res_dict.keys()):
27+
if k == "embd_id":
28+
res_dict["embedding_model"] = res_dict[k]
29+
if k == "parser_id":
30+
res_dict['parse_method'] = res_dict[k]
31+
if k == "doc_num":
32+
res_dict["document_count"] = res_dict[k]
33+
if k == "chunk_num":
34+
res_dict["chunk_count"] = res_dict[k]
35+
if k not in self.__dict__:
36+
res_dict.pop(k)
2637
super().__init__(rag, res_dict)
2738

28-
def save(self):
39+
def save(self) -> bool:
2940
res = self.post('/dataset/save',
3041
{"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
3142
"description": self.description, "language": self.language, "embd_id": self.embedding_model,
3243
"permission": self.permission,
33-
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method,
44+
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parse_method,
3445
"parser_config": self.parser_config.to_json()
3546
})
3647
res = res.json()
37-
if not res.get("retmsg"): return True
38-
raise Exception(res["retmsg"])
48+
if res.get("retmsg") == "success": return True
49+
raise Exception(res["retmsg"])
50+
51+
def delete(self) -> bool:
52+
res = self.rm('/dataset/delete',
53+
{"id": self.id})
54+
res = res.json()
55+
if res.get("retmsg") == "success": return True
56+
raise Exception(res["retmsg"])

0 commit comments

Comments
 (0)