Skip to content

Commit 2c771fb

Browse files
FeiueFeiue
and
Feiue
authored
Complete DataSet SDK implementation (#2171)
### What problem does this PR solve? Complete DataSet SDK implementation #1102 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Feiue <[email protected]>
1 parent 667632b commit 2c771fb

File tree

3 files changed

+75
-33
lines changed

3 files changed

+75
-33
lines changed

api/apps/sdk/dataset.py

+71-21
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
1617
from flask import request
1718

1819
from api.db import StatusEnum, FileSource
@@ -33,7 +34,7 @@ def save(tenant_id):
3334
req = request.json
3435
e, t = TenantService.get_by_id(tenant_id)
3536
if "id" not in req:
36-
if "tenant_id" in req or "embd_id" in req:
37+
if "tenant_id" in req or "embedding_model" in req:
3738
return get_data_error_result(
3839
retmsg="Tenant_id or embedding_model must not be provided")
3940
if "name" not in req:
@@ -47,22 +48,39 @@ def save(tenant_id):
4748
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
4849
return get_data_error_result(
4950
retmsg="Duplicated knowledgebase name in creating dataset.")
50-
req["tenant_id"] = tenant_id
51-
req['created_by'] = tenant_id
52-
req['embd_id'] = t.embd_id
51+
req["tenant_id"] = req['created_by'] = tenant_id
52+
req['embedding_model'] = t.embd_id
53+
key_mapping = {
54+
"chunk_num": "chunk_count",
55+
"doc_num": "document_count",
56+
"parser_id": "parse_method",
57+
"embd_id": "embedding_model"
58+
}
59+
mapped_keys = {new_key: req[old_key] for new_key, old_key in key_mapping.items() if old_key in req}
60+
req.update(mapped_keys)
5361
if not KnowledgebaseService.save(**req):
5462
return get_data_error_result(retmsg="Create dataset error.(Database error)")
55-
return get_json_result(data=req)
63+
renamed_data={}
64+
e, k = KnowledgebaseService.get_by_id(req["id"])
65+
for key, value in k.to_dict().items():
66+
new_key = key_mapping.get(key, key)
67+
renamed_data[new_key] = value
68+
return get_json_result(data=renamed_data)
5669
else:
70+
invalid_keys = {"embd_id", "chunk_num", "doc_num", "parser_id"}
71+
if any(key in req for key in invalid_keys):
72+
return get_data_error_result(retmsg="The input parameters are invalid.")
73+
5774
if "tenant_id" in req:
5875
if req["tenant_id"] != tenant_id:
5976
return get_data_error_result(
6077
retmsg="Can't change tenant_id.")
6178

62-
if "embd_id" in req:
63-
if req["embd_id"] != t.embd_id:
79+
if "embedding_model" in req:
80+
if req["embedding_model"] != t.embd_id:
6481
return get_data_error_result(
6582
retmsg="Can't change embedding_model.")
83+
req.pop("embedding_model")
6684

6785
if not KnowledgebaseService.query(
6886
created_by=tenant_id, id=req["id"]):
@@ -72,20 +90,23 @@ def save(tenant_id):
7290

7391
e, kb = KnowledgebaseService.get_by_id(req["id"])
7492

75-
if "chunk_num" in req:
76-
if req["chunk_num"] != kb.chunk_num:
93+
if "chunk_count" in req:
94+
if req["chunk_count"] != kb.chunk_num:
7795
return get_data_error_result(
7896
retmsg="Can't change chunk_count.")
97+
req.pop("chunk_count")
7998

80-
if "doc_num" in req:
81-
if req['doc_num'] != kb.doc_num:
99+
if "document_count" in req:
100+
if req['document_count'] != kb.doc_num:
82101
return get_data_error_result(
83102
retmsg="Can't change document_count.")
103+
req.pop("document_count")
84104

85-
if "parser_id" in req:
86-
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
105+
if "parse_method" in req:
106+
if kb.chunk_num != 0 and req['parse_method'] != kb.parser_id:
87107
return get_data_error_result(
88-
retmsg="if chunk count is not 0, parse method is not changable.")
108+
retmsg="If chunk count is not 0, parse method is not changable.")
109+
req['parser_id'] = req.pop('parse_method')
89110
if "name" in req:
90111
if req["name"].lower() != kb.name.lower() \
91112
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
@@ -103,6 +124,9 @@ def save(tenant_id):
103124
@token_required
104125
def delete(tenant_id):
105126
req = request.args
127+
if "id" not in req:
128+
return get_data_error_result(
129+
retmsg="id is required")
106130
kbs = KnowledgebaseService.query(
107131
created_by=tenant_id, id=req["id"])
108132
if not kbs:
@@ -120,7 +144,7 @@ def delete(tenant_id):
120144

121145
if not KnowledgebaseService.delete_by_id(req["id"]):
122146
return get_data_error_result(
123-
retmsg="Delete dataset error.(Database error)")
147+
retmsg="Delete dataset error.(Database serror)")
124148
return get_json_result(data=True)
125149

126150

@@ -134,37 +158,63 @@ def list_datasets(tenant_id):
134158
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
135159
kbs = KnowledgebaseService.get_by_tenant_ids(
136160
[m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc)
137-
return get_json_result(data=kbs)
161+
renamed_list = []
162+
for kb in kbs:
163+
key_mapping = {
164+
"chunk_num": "chunk_count",
165+
"doc_num": "document_count",
166+
"parser_id": "parse_method",
167+
"embd_id": "embedding_model"
168+
}
169+
renamed_data = {}
170+
for key, value in kb.items():
171+
new_key = key_mapping.get(key, key)
172+
renamed_data[new_key] = value
173+
renamed_list.append(renamed_data)
174+
return get_json_result(data=renamed_list)
138175

139176

140177
@manager.route('/detail', methods=['GET'])
141178
@token_required
142179
def detail(tenant_id):
143180
req = request.args
181+
key_mapping = {
182+
"chunk_num": "chunk_count",
183+
"doc_num": "document_count",
184+
"parser_id": "parse_method",
185+
"embd_id": "embedding_model"
186+
}
187+
renamed_data = {}
144188
if "id" in req:
145189
id = req["id"]
146190
kb = KnowledgebaseService.query(created_by=tenant_id, id=req["id"])
147191
if not kb:
148192
return get_json_result(
149-
data=False, retmsg='You do not own the dataset',
193+
data=False, retmsg='You do not own the dataset.',
150194
retcode=RetCode.OPERATING_ERROR)
151195
if "name" in req:
152196
name = req["name"]
153197
if kb[0].name != name:
154198
return get_json_result(
155-
data=False, retmsg='You do not own the dataset',
199+
data=False, retmsg='You do not own the dataset.',
156200
retcode=RetCode.OPERATING_ERROR)
157201
e, k = KnowledgebaseService.get_by_id(id)
158-
return get_json_result(data=k.to_dict())
202+
for key, value in k.to_dict().items():
203+
new_key = key_mapping.get(key, key)
204+
renamed_data[new_key] = value
205+
return get_json_result(data=renamed_data)
159206
else:
160207
if "name" in req:
161208
name = req["name"]
162209
e, k = KnowledgebaseService.get_by_name(kb_name=name, tenant_id=tenant_id)
163210
if not e:
164211
return get_json_result(
165-
data=False, retmsg='You do not own the dataset',
212+
data=False, retmsg='You do not own the dataset.',
166213
retcode=RetCode.OPERATING_ERROR)
167-
return get_json_result(data=k.to_dict())
214+
for key, value in k.to_dict().items():
215+
new_key = key_mapping.get(key, key)
216+
renamed_data[new_key] = value
217+
return get_json_result(data=renamed_data)
168218
else:
169219
return get_data_error_result(
170220
retmsg="At least one of `id` or `name` must be provided.")

sdk/python/ragflow/modules/dataset.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,16 @@ def __init__(self, rag, res_dict):
2424
self.parse_method = "naive"
2525
self.parser_config = None
2626
for k in list(res_dict.keys()):
27-
if k == "embd_id":
28-
res_dict["embedding_model"] = res_dict[k]
29-
if k == "parser_id":
30-
res_dict['parse_method'] = res_dict[k]
31-
if k == "doc_num":
32-
res_dict["document_count"] = res_dict[k]
33-
if k == "chunk_num":
34-
res_dict["chunk_count"] = res_dict[k]
3527
if k not in self.__dict__:
3628
res_dict.pop(k)
3729
super().__init__(rag, res_dict)
3830

3931
def save(self) -> bool:
4032
res = self.post('/dataset/save',
4133
{"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
42-
"description": self.description, "language": self.language, "embd_id": self.embedding_model,
34+
"description": self.description, "language": self.language, "embedding_model": self.embedding_model,
4335
"permission": self.permission,
44-
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parse_method,
36+
"document_count": self.document_count, "chunk_count": self.chunk_count, "parse_method": self.parse_method,
4537
"parser_config": self.parser_config.to_json()
4638
})
4739
res = res.json()

sdk/python/ragflow/ragflow.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def create_dataset(self, name: str, avatar: str = "", description: str = "", lan
5252
res = self.post("/dataset/save",
5353
{"name": name, "avatar": avatar, "description": description, "language": language,
5454
"permission": permission,
55-
"doc_num": document_count, "chunk_num": chunk_count, "parser_id": parse_method,
55+
"document_count": document_count, "chunk_count": chunk_count, "parse_method": parse_method,
5656
"parser_config": parser_config
5757
}
5858
)
@@ -61,7 +61,7 @@ def create_dataset(self, name: str, avatar: str = "", description: str = "", lan
6161
return DataSet(self, res["data"])
6262
raise Exception(res["retmsg"])
6363

64-
def list_datasets(self, page: int = 1, page_size: int = 150, orderby: str = "create_time", desc: bool = True) -> \
64+
def list_datasets(self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True) -> \
6565
List[DataSet]:
6666
res = self.get("/dataset/list", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
6767
res = res.json()

0 commit comments

Comments
 (0)