-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into add-more-smb2-tests
- Loading branch information
Showing
2 changed files
with
106 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,111 +1,146 @@ | ||
from smb.SMBConnection import SMBConnection # type: ignore | ||
from smb import smb_structs, base # type: ignore | ||
from smbprotocol.exceptions import SMBException # type: ignore | ||
import smbclient # type: ignore | ||
from smbprotocol.connection import Connection # type: ignore | ||
import typing | ||
import io | ||
import uuid | ||
|
||
rw_chunk_size = 1 << 21 # 2MB | ||
|
||
|
||
class SMBClient: | ||
"""Use pysmb to access the SMB server""" | ||
"""Use smbprotocol python module to access the SMB server""" | ||
|
||
def __init__(self, hostname: str, share: str, username: str, passwd: str): | ||
def __init__( | ||
self, | ||
hostname: str, | ||
share: str, | ||
username: str, | ||
passwd: str, | ||
port: int = 445, | ||
): | ||
self.server = hostname | ||
self.share = share | ||
self.username = username | ||
self.password = passwd | ||
self.port = port | ||
self.connection_cache: dict = {} | ||
self.client_params = { | ||
"username": username, | ||
"password": passwd, | ||
"connection_cache": self.connection_cache, | ||
} | ||
self.prepath = f"\\\\{self.server}\\{self.share}\\" | ||
self.connected = False | ||
self.connect() | ||
|
||
def _path(self, path: str = "/") -> str: | ||
path.replace("/", "\\") | ||
return self.prepath + path | ||
|
||
def connect(self) -> None: | ||
if self.connected: | ||
return | ||
try: | ||
self.ctx = SMBConnection( | ||
self.username, | ||
self.password, | ||
"smbclient", | ||
self.server, | ||
use_ntlm_v2=True, | ||
# Manually setup connection to avoid re-using guid through | ||
# the global configuration | ||
connection_key = f"{self.server.lower()}:{self.port}" | ||
connection = Connection(uuid.uuid4(), self.server, self.port) | ||
connection.connect() | ||
self.connection_cache[connection_key] = connection | ||
smbclient.register_session( | ||
self.server, port=self.port, **self.client_params | ||
) | ||
self.ctx.connect(self.server) | ||
self.connected = True | ||
except base.SMBTimeout as error: | ||
except SMBException as error: | ||
raise IOError(f"failed to connect: {error}") | ||
|
||
def disconnect(self) -> None: | ||
self.connected = False | ||
try: | ||
self.ctx.close() | ||
except base.SMBTimeout as error: | ||
raise TimeoutError(f"disconnect: {error}") | ||
smbclient.reset_connection_cache( | ||
connection_cache=self.connection_cache | ||
) | ||
|
||
def _check_connected(self, action: str) -> None: | ||
if not self.connected: | ||
raise ConnectionError(f"{action}: server not connected") | ||
|
||
def listdir(self, path: str = "/") -> typing.List[str]: | ||
self._check_connected("listdir") | ||
try: | ||
dentries = self.ctx.listPath(self.share, path) | ||
except smb_structs.OperationFailure as error: | ||
raise IOError(f"failed to readdir: {error}") | ||
except base.SMBTimeout as error: | ||
raise TimeoutError(f"listdir: {error}") | ||
except base.NotConnectedError as error: | ||
raise ConnectionError(f"listdir: {error}") | ||
|
||
return [dent.filename for dent in dentries] | ||
filenames = smbclient.listdir( | ||
self._path(path), **self.client_params | ||
) | ||
except SMBException as error: | ||
raise IOError(f"listdir: {error}") | ||
return filenames | ||
|
||
def mkdir(self, dpath: str) -> None: | ||
self._check_connected("mkdir") | ||
if not self.connected: | ||
raise ConnectionError("listdir: server not connected") | ||
try: | ||
self.ctx.createDirectory(self.share, dpath) | ||
except smb_structs.OperationFailure as error: | ||
raise IOError(f"failed to mkdir: {error}") | ||
except base.SMBTimeout as error: | ||
raise TimeoutError(f"mkdir: {error}") | ||
except base.NotConnectedError as error: | ||
raise ConnectionError(f"mkdir: {error}") | ||
smbclient.mkdir(self._path(dpath), **self.client_params) | ||
except SMBException as error: | ||
raise IOError(f"mkdir: {error}") | ||
|
||
def rmdir(self, dpath: str) -> None: | ||
self._check_connected("rmdir") | ||
try: | ||
self.ctx.deleteDirectory(self.share, dpath) | ||
except smb_structs.OperationFailure as error: | ||
raise IOError(f"failed to rmdir: {error}") | ||
except base.SMBTimeout as error: | ||
raise TimeoutError(f"rmdir: {error}") | ||
except base.NotConnectedError as error: | ||
raise ConnectionError(f"rmdir: {error}") | ||
smbclient.rmdir(self._path(dpath), **self.client_params) | ||
except SMBException as error: | ||
raise IOError(f"rmdir: {error}") | ||
|
||
def unlink(self, fpath: str) -> None: | ||
self._check_connected("unlink") | ||
try: | ||
self.ctx.deleteFiles(self.share, fpath) | ||
except smb_structs.OperationFailure as error: | ||
raise IOError(f"failed to unlink: {error}") | ||
except base.SMBTimeout as error: | ||
raise TimeoutError(f"unlink: {error}") | ||
except base.NotConnectedError as error: | ||
raise ConnectionError(f"unlink: {error}") | ||
smbclient.remove(self._path(fpath), **self.client_params) | ||
except SMBException as error: | ||
raise IOError(f"unlink: {error}") | ||
|
||
def _read_write_fd(self, fd_from: typing.IO, fd_to: typing.IO) -> None: | ||
while True: | ||
data = fd_from.read(rw_chunk_size) | ||
if not data: | ||
break | ||
n = 0 | ||
while n < len(data): | ||
n += fd_to.write(data[n:]) | ||
|
||
def write(self, fpath: str, writeobj: typing.IO) -> None: | ||
self._check_connected("write") | ||
try: | ||
self.ctx.storeFile(self.share, fpath, writeobj) | ||
except smb_structs.OperationFailure as error: | ||
raise IOError(f"failed in write_text: {error}") | ||
except base.SMBTimeout as error: | ||
raise TimeoutError(f"write_text: {error}") | ||
except base.NotConnectedError as error: | ||
raise ConnectionError(f"write: {error}") | ||
with smbclient.open_file( | ||
self._path(fpath), mode="wb", **self.client_params | ||
) as fd: | ||
self._read_write_fd(writeobj, fd) | ||
except SMBException as error: | ||
raise IOError(f"write: {error}") | ||
|
||
def read(self, fpath: str, readobj: typing.IO) -> None: | ||
self._check_connected("read") | ||
try: | ||
self.ctx.retrieveFile(self.share, fpath, readobj) | ||
except smb_structs.OperationFailure as error: | ||
raise IOError(f"failed in read_text: {error}") | ||
except base.SMBTimeout as error: | ||
raise TimeoutError(f"read_text: {error}") | ||
except base.NotConnectedError as error: | ||
raise ConnectionError(f"read: {error}") | ||
with smbclient.open_file( | ||
self._path(fpath), mode="rb", **self.client_params | ||
) as fd: | ||
self._read_write_fd(fd, readobj) | ||
except SMBException as error: | ||
raise IOError(f"write: {error}") | ||
|
||
def write_text(self, fpath: str, teststr: str) -> None: | ||
with io.BytesIO(teststr.encode()) as writeobj: | ||
self.write(fpath, writeobj) | ||
self._check_connected("write_text") | ||
try: | ||
with smbclient.open_file( | ||
self._path(fpath), mode="w", **self.client_params | ||
) as fd: | ||
fd.write(teststr) | ||
except SMBException as error: | ||
raise IOError(f"write: {error}") | ||
|
||
def read_text(self, fpath: str) -> str: | ||
with io.BytesIO() as readobj: | ||
self.read(fpath, readobj) | ||
ret = readobj.getvalue().decode("utf8") | ||
self._check_connected("read_text") | ||
try: | ||
with smbclient.open_file( | ||
self._path(fpath), **self.client_params | ||
) as fd: | ||
ret = fd.read() | ||
except SMBException as error: | ||
raise IOError(f"write: {error}") | ||
return ret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters