Skip to content

Add support for new user fields #1638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ env:
WEAVIATE_127: 1.27.14
WEAVIATE_128: 1.28.8
WEAVIATE_129: 1.29.1
WEAVIATE_130: 1.30.0-rc.0-c1830a7-amd64
WEAVIATE_130: preview-db-users-add-last-used-time-0184fce.amd64

jobs:
lint-and-format:
Expand Down
5 changes: 3 additions & 2 deletions integration/test_collection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,8 +1357,9 @@ def is_supported():
update()

config = collection.config.get()
assert config.properties[0].description == "Name of the person"
assert config.properties[1].description == "Age of the person"
props = {prop.name: prop for prop in config.properties}
assert props["name"].description == "Name of the person"
assert props["age"].description == "Age of the person"
else:
with pytest.raises(UnexpectedStatusCodeError):
update()
Expand Down
48 changes: 45 additions & 3 deletions integration/test_users.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import random
import pytest

Expand Down Expand Up @@ -82,16 +83,37 @@ def test_create_user_and_get(client_factory: ClientFactory) -> None:
if client._connection._weaviate_version.is_lower_than(1, 30, 0):
pytest.skip("This test requires Weaviate 1.30.0 or higher")

before = datetime.datetime.now(tz=datetime.timezone.utc)

randomUserName = "new-user" + str(random.randint(1, 1000))
apiKey = client.users.db.create(user_id=randomUserName)

after_creation = datetime.datetime.now(tz=datetime.timezone.utc)

with weaviate.connect_to_local(
port=RBAC_PORTS[0], grpc_port=RBAC_PORTS[1], auth_credentials=Auth.api_key(apiKey)
) as client2:
user = client2.users.get_my_user()
assert user.user_id == randomUserName

after_login = datetime.datetime.now(tz=datetime.timezone.utc)

user = client.users.db.get(user_id=randomUserName)
assert user.user_id == randomUserName
assert user.user_type == UserTypes.DB_DYNAMIC
assert user.last_used is None

user = client.users.db.get(user_id=randomUserName, include_last_used_at_time=True)
assert user.active
assert user.last_used is not None
assert user.last_used > after_creation
assert user.last_used < after_login

assert len(user.apikey_first_letters) == 3
assert user.apikey_first_letters == apiKey[:3]
assert user.created_at < after_creation
assert user.created_at > before

assert client.users.db.delete(user_id=randomUserName)


Expand Down Expand Up @@ -150,6 +172,11 @@ def test_de_activate(client_factory: ClientFactory) -> None:
) # second activation returns a conflict => false
user = client.users.db.get(user_id=randomUserName)
assert user.active
assert user.last_used is None

user = client.users.db.get(user_id=randomUserName, include_last_used_at_time=True)
assert user.active
assert user.last_used is not None

client.users.db.delete(user_id=randomUserName)

Expand Down Expand Up @@ -206,12 +233,27 @@ def test_list_all_users(client_factory: ClientFactory) -> None:
if client._connection._weaviate_version.is_lower_than(1, 30, 0):
pytest.skip("This test requires Weaviate 1.30.0 or higher")

before = datetime.datetime.now(tz=datetime.timezone.utc)

for i in range(5):
client.users.db.delete(user_id=f"list-all-user-{i}")
client.users.db.create(user_id=f"list-all-user-{i}")

users = client.users.db.list_all()
dynamic_users = [user for user in users if user.user_id.startswith("list-all-")]
assert len(dynamic_users) == 5
after = datetime.datetime.now(tz=datetime.timezone.utc)

for include in [True, False]:
users = client.users.db.list_all(include_last_used_at_time=include)
dynamic_users = [user for user in users if user.user_id.startswith("list-all-")]
assert len(dynamic_users) == 5
assert all(user.user_type == UserTypes.DB_DYNAMIC for user in dynamic_users)
assert all(user.active for user in dynamic_users)
assert all(len(user.apikey_first_letters) == 3 for user in dynamic_users)
assert all(user.created_at < after for user in dynamic_users)
assert all(user.created_at > before for user in dynamic_users)
if include:
assert all(user.last_used is not None for user in dynamic_users)
else:
assert all(user.last_used is None for user in dynamic_users)

for i in range(5):
client.users.db.delete(user_id=f"list-all-{i}")
3 changes: 3 additions & 0 deletions weaviate/rbac/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ class WeaviateDBUserRoleNames(TypedDict):
groups: List[str]
active: bool
dbUserType: str
lastUsedAt: Optional[str]
createdAt: str
apikeyFirstLetters: Optional[str]


class _Action:
Expand Down
27 changes: 25 additions & 2 deletions weaviate/users/async_.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from typing import Dict, List, Literal, Union, overload
from weaviate.connect.v4 import ConnectionAsync
from weaviate.users.executor import _DeprecatedExecutor, _DBExecutor, _OIDCExecutor
Expand Down Expand Up @@ -42,8 +43,30 @@ class _UsersDBAsync(_DBExecutor[ConnectionAsync]):
async def rotate_key(self, *, user_id: str) -> str: ...
async def deactivate(self, *, user_id: str, revoke_key: bool = False) -> bool: ...
async def activate(self, *, user_id: str) -> bool: ...
async def get(self, *, user_id: str) -> UserDB: ...
async def list_all(self) -> List[UserDB]: ...
@overload
async def get(
self, *, user_id: str, include_last_used_at_time: Literal[True]
) -> UserDB[datetime.datetime]: ...
@overload
async def get(
self, *, user_id: str, include_last_used_at_time: Literal[False] = False
) -> UserDB[None]: ...
@overload
async def get(
self, *, user_id: str, include_last_used_at_time: bool = False
) -> Union[UserDB[None], UserDB[datetime.datetime]]: ...
@overload
async def list_all(
self, *, include_last_used_at_time: Literal[True]
) -> List[UserDB[datetime.datetime]]: ...
@overload
async def list_all(
self, *, include_last_used_at_time: Literal[False] = False
) -> List[UserDB[None]]: ...
@overload
async def list_all(
self, *, include_last_used_at_time: bool = False
) -> Union[List[UserDB[None]], List[UserDB[datetime.datetime]]]: ...

class _UsersAsync(_DeprecatedExecutor[ConnectionAsync]):
async def get_my_user(self) -> OwnUser: ...
Expand Down
58 changes: 47 additions & 11 deletions weaviate/users/executor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from typing import Any, Dict, Generic, List, Optional, Union, cast

from httpx import Response
Expand All @@ -17,7 +18,7 @@
UserDB,
OwnUser,
)
from weaviate.util import _decode_json_response_dict
from weaviate.util import _datetime_from_weaviate_str, _decode_json_response_dict


class _BaseExecutor(Generic[ConnectionType]):
Expand Down Expand Up @@ -401,53 +402,88 @@ def resp(res: Response) -> bool:
status_codes=_ExpectedStatusCodes(ok_in=[200, 409], error="Deactivate user"),
)

def get(self, *, user_id: str) -> executor.Result[Optional[UserDB]]:
def get(
self, *, user_id: str, include_last_used_at_time: bool = False
) -> executor.Result[Optional[Union[UserDB[None], UserDB[datetime.datetime]]]]:
"""Get all information about an user.

Args:
user_id: The id of the user.
"""

def resp(res: Response) -> Optional[UserDB]:
def resp(res: Response) -> Optional[Union[UserDB[None], UserDB[datetime.datetime]]]:
if res.status_code == 404:
return None
parsed = _decode_json_response_dict(res, "Get user")
assert parsed is not None
return UserDB(
user_id=parsed["userId"],
role_names=parsed["roles"],
active=parsed["active"],
user_type=UserTypes(parsed["dbUserType"]),
user = cast(WeaviateDBUserRoleNames, parsed)
ret = UserDB(
user_id=user["userId"],
role_names=user["roles"],
active=user["active"],
user_type=UserTypes(user["dbUserType"]),
created_at=_datetime_from_weaviate_str(user["createdAt"]),
last_used=get_last_used_at_time(user=user) if include_last_used_at_time else None,
apikey_first_letters=get_api_key_first_letters(user=user),
)
if include_last_used_at_time:
return cast(UserDB[datetime.datetime], ret)
return cast(UserDB[None], ret)

return executor.execute(
response_callback=resp,
method=self._connection.get,
params={"includeLastUsedTime": include_last_used_at_time},
path=f"/users/db/{user_id}",
error_msg=f"Could not get user '{user_id}'",
status_codes=_ExpectedStatusCodes(ok_in=[200, 404], error="get user"),
)

def list_all(self) -> executor.Result[List[UserDB]]:
def list_all(
self, *, include_last_used_at_time: bool = False
) -> executor.Result[Union[List[UserDB[None]], List[UserDB[datetime.datetime]]]]:
"""List all DB users."""

def resp(res: Response) -> List[UserDB]:
def resp(res: Response) -> Union[List[UserDB[None]], List[UserDB[datetime.datetime]]]:
parsed = _decode_json_response_dict(res, "Get user")
assert parsed is not None
return [

ret = [
UserDB(
user_id=user["userId"],
role_names=user["roles"],
active=user["active"],
user_type=UserTypes(user["dbUserType"]),
created_at=_datetime_from_weaviate_str(user["createdAt"]),
last_used=(
get_last_used_at_time(user=user) if include_last_used_at_time else None
),
apikey_first_letters=get_api_key_first_letters(user=user),
)
for user in cast(List[WeaviateDBUserRoleNames], parsed)
]

if include_last_used_at_time:
return cast(List[UserDB[datetime.datetime]], ret)
return cast(List[UserDB[None]], ret)

return executor.execute(
response_callback=resp,
method=self._connection.get,
params={"includeLastUsedTime": include_last_used_at_time},
path="/users/db",
error_msg="Could not list all users",
status_codes=_ExpectedStatusCodes(ok_in=[200], error="list all users"),
)


def get_last_used_at_time(user: WeaviateDBUserRoleNames) -> datetime.datetime:
lastused = user.get("lastUsedAt", None)
if lastused is None:
return datetime.datetime.min
return _datetime_from_weaviate_str(lastused)


def get_api_key_first_letters(user: WeaviateDBUserRoleNames) -> str:
first_letters = user.get("apiKeyFirstLetters", "")
return first_letters if first_letters else ""
27 changes: 25 additions & 2 deletions weaviate/users/sync.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from typing import Dict, List, Literal, Union, overload
from weaviate.connect.v4 import ConnectionSync
from weaviate.users.executor import _DeprecatedExecutor, _DBExecutor, _OIDCExecutor
Expand Down Expand Up @@ -42,8 +43,30 @@ class _UsersDB(_DBExecutor[ConnectionSync]):
def rotate_key(self, *, user_id: str) -> str: ...
def deactivate(self, *, user_id: str, revoke_key: bool = False) -> bool: ...
def activate(self, *, user_id: str) -> bool: ...
def get(self, *, user_id: str) -> UserDB: ...
def list_all(self) -> List[UserDB]: ...
@overload
def get(
self, *, user_id: str, include_last_used_at_time: Literal[True]
) -> UserDB[datetime.datetime]: ...
@overload
def get(
self, *, user_id: str, include_last_used_at_time: Literal[False] = False
) -> UserDB[None]: ...
@overload
def get(
self, *, user_id: str, include_last_used_at_time: bool = False
) -> Union[UserDB[None], UserDB[datetime.datetime]]: ...
@overload
def list_all(
self, *, include_last_used_at_time: Literal[True]
) -> List[UserDB[datetime.datetime]]: ...
@overload
def list_all(
self, *, include_last_used_at_time: Literal[False] = False
) -> List[UserDB[None]]: ...
@overload
def list_all(
self, *, include_last_used_at_time: bool = False
) -> Union[List[UserDB[None]], List[UserDB[datetime.datetime]]]: ...

class _Users(_DeprecatedExecutor[ConnectionSync]):
def get_my_user(self) -> OwnUser: ...
Expand Down
12 changes: 10 additions & 2 deletions weaviate/users/users.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Dict, Final, List, Literal
import datetime
from typing import Dict, Final, Generic, List, Literal, TypeVar

from weaviate.rbac.models import (
Role,
Expand All @@ -24,10 +25,17 @@ class UserBase:
user_type: UserTypes


# generic type for UserDB
T = TypeVar("T")


@dataclass
class UserDB(UserBase):
class UserDB(UserBase, Generic[T]):
user_type: UserTypes
active: bool
created_at: datetime.datetime
last_used: T
apikey_first_letters: str


@dataclass
Expand Down