Skip to content

Commit

Permalink
Merge branch 'main' into add-more-smb2-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
spuiuk authored Jun 28, 2024
2 parents 3f7ca98 + bd02302 commit 82b8fa0
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 71 deletions.
171 changes: 103 additions & 68 deletions testhelper/smbclient.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,146 @@
from smb.SMBConnection import SMBConnection # type: ignore
from smb import smb_structs, base # type: ignore
from smbprotocol.exceptions import SMBException # type: ignore
import smbclient # type: ignore
from smbprotocol.connection import Connection # type: ignore
import typing
import io
import uuid

rw_chunk_size = 1 << 21 # 2MB


class SMBClient:
"""Use pysmb to access the SMB server"""
"""Use smbprotocol python module to access the SMB server"""

def __init__(self, hostname: str, share: str, username: str, passwd: str):
def __init__(
self,
hostname: str,
share: str,
username: str,
passwd: str,
port: int = 445,
):
self.server = hostname
self.share = share
self.username = username
self.password = passwd
self.port = port
self.connection_cache: dict = {}
self.client_params = {
"username": username,
"password": passwd,
"connection_cache": self.connection_cache,
}
self.prepath = f"\\\\{self.server}\\{self.share}\\"
self.connected = False
self.connect()

def _path(self, path: str = "/") -> str:
path.replace("/", "\\")
return self.prepath + path

def connect(self) -> None:
if self.connected:
return
try:
self.ctx = SMBConnection(
self.username,
self.password,
"smbclient",
self.server,
use_ntlm_v2=True,
# Manually setup connection to avoid re-using guid through
# the global configuration
connection_key = f"{self.server.lower()}:{self.port}"
connection = Connection(uuid.uuid4(), self.server, self.port)
connection.connect()
self.connection_cache[connection_key] = connection
smbclient.register_session(
self.server, port=self.port, **self.client_params
)
self.ctx.connect(self.server)
self.connected = True
except base.SMBTimeout as error:
except SMBException as error:
raise IOError(f"failed to connect: {error}")

def disconnect(self) -> None:
self.connected = False
try:
self.ctx.close()
except base.SMBTimeout as error:
raise TimeoutError(f"disconnect: {error}")
smbclient.reset_connection_cache(
connection_cache=self.connection_cache
)

def _check_connected(self, action: str) -> None:
if not self.connected:
raise ConnectionError(f"{action}: server not connected")

def listdir(self, path: str = "/") -> typing.List[str]:
self._check_connected("listdir")
try:
dentries = self.ctx.listPath(self.share, path)
except smb_structs.OperationFailure as error:
raise IOError(f"failed to readdir: {error}")
except base.SMBTimeout as error:
raise TimeoutError(f"listdir: {error}")
except base.NotConnectedError as error:
raise ConnectionError(f"listdir: {error}")

return [dent.filename for dent in dentries]
filenames = smbclient.listdir(
self._path(path), **self.client_params
)
except SMBException as error:
raise IOError(f"listdir: {error}")
return filenames

def mkdir(self, dpath: str) -> None:
self._check_connected("mkdir")
if not self.connected:
raise ConnectionError("listdir: server not connected")
try:
self.ctx.createDirectory(self.share, dpath)
except smb_structs.OperationFailure as error:
raise IOError(f"failed to mkdir: {error}")
except base.SMBTimeout as error:
raise TimeoutError(f"mkdir: {error}")
except base.NotConnectedError as error:
raise ConnectionError(f"mkdir: {error}")
smbclient.mkdir(self._path(dpath), **self.client_params)
except SMBException as error:
raise IOError(f"mkdir: {error}")

def rmdir(self, dpath: str) -> None:
self._check_connected("rmdir")
try:
self.ctx.deleteDirectory(self.share, dpath)
except smb_structs.OperationFailure as error:
raise IOError(f"failed to rmdir: {error}")
except base.SMBTimeout as error:
raise TimeoutError(f"rmdir: {error}")
except base.NotConnectedError as error:
raise ConnectionError(f"rmdir: {error}")
smbclient.rmdir(self._path(dpath), **self.client_params)
except SMBException as error:
raise IOError(f"rmdir: {error}")

def unlink(self, fpath: str) -> None:
self._check_connected("unlink")
try:
self.ctx.deleteFiles(self.share, fpath)
except smb_structs.OperationFailure as error:
raise IOError(f"failed to unlink: {error}")
except base.SMBTimeout as error:
raise TimeoutError(f"unlink: {error}")
except base.NotConnectedError as error:
raise ConnectionError(f"unlink: {error}")
smbclient.remove(self._path(fpath), **self.client_params)
except SMBException as error:
raise IOError(f"unlink: {error}")

def _read_write_fd(self, fd_from: typing.IO, fd_to: typing.IO) -> None:
while True:
data = fd_from.read(rw_chunk_size)
if not data:
break
n = 0
while n < len(data):
n += fd_to.write(data[n:])

def write(self, fpath: str, writeobj: typing.IO) -> None:
self._check_connected("write")
try:
self.ctx.storeFile(self.share, fpath, writeobj)
except smb_structs.OperationFailure as error:
raise IOError(f"failed in write_text: {error}")
except base.SMBTimeout as error:
raise TimeoutError(f"write_text: {error}")
except base.NotConnectedError as error:
raise ConnectionError(f"write: {error}")
with smbclient.open_file(
self._path(fpath), mode="wb", **self.client_params
) as fd:
self._read_write_fd(writeobj, fd)
except SMBException as error:
raise IOError(f"write: {error}")

def read(self, fpath: str, readobj: typing.IO) -> None:
self._check_connected("read")
try:
self.ctx.retrieveFile(self.share, fpath, readobj)
except smb_structs.OperationFailure as error:
raise IOError(f"failed in read_text: {error}")
except base.SMBTimeout as error:
raise TimeoutError(f"read_text: {error}")
except base.NotConnectedError as error:
raise ConnectionError(f"read: {error}")
with smbclient.open_file(
self._path(fpath), mode="rb", **self.client_params
) as fd:
self._read_write_fd(fd, readobj)
except SMBException as error:
raise IOError(f"write: {error}")

def write_text(self, fpath: str, teststr: str) -> None:
with io.BytesIO(teststr.encode()) as writeobj:
self.write(fpath, writeobj)
self._check_connected("write_text")
try:
with smbclient.open_file(
self._path(fpath), mode="w", **self.client_params
) as fd:
fd.write(teststr)
except SMBException as error:
raise IOError(f"write: {error}")

def read_text(self, fpath: str) -> str:
with io.BytesIO() as readobj:
self.read(fpath, readobj)
ret = readobj.getvalue().decode("utf8")
self._check_connected("read_text")
try:
with smbclient.open_file(
self._path(fpath), **self.client_params
) as fd:
ret = fd.read()
except SMBException as error:
raise IOError(f"write: {error}")
return ret
6 changes: 3 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ deps =
pyyaml
pytest-randomly
iso8601
pysmb
smbprotocol
commands = pytest -vrfEsxXpP testcases/

[testenv:pytest-unprivileged]
Expand All @@ -22,7 +22,7 @@ deps =
pyyaml
pytest-randomly
iso8601
pysmb
smbprotocol
commands = pytest -vrfEsxXpP -k 'not privileged' testcases/

[testenv:sanity]
Expand All @@ -31,7 +31,7 @@ deps =
pyyaml
pytest-randomly
iso8601
pysmb
smbprotocol
changedir = {toxinidir}
commands = pytest -vrfEsxXpP testcases/consistency

Expand Down

0 comments on commit 82b8fa0

Please sign in to comment.