diff --git a/.gitignore b/.gitignore index 73996cd..424c48b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # Benches (created by `bench new`) benches/ +# Research +research-folder/ + # Python __pycache__/ *.py[cod] diff --git a/bench_cli/core/site.py b/bench_cli/core/site.py index 6a8167c..6ffead3 100644 --- a/bench_cli/core/site.py +++ b/bench_cli/core/site.py @@ -48,6 +48,9 @@ def create(self) -> None: cmd += ["--db-host", mariadb.host, "--db-port", str(mariadb.port)] if mariadb.root_password: cmd += ["--db-root-password", mariadb.root_password] + # Use '%' host scope for TCP connections to allow connections from any host. + # MySQL treats 'localhost' (unix socket) and '127.0.0.1' (TCP) as different hosts. + cmd += ["--mariadb-user-host-login-scope", "%"] run_command(cmd, cwd=self.bench.sites_path, stream_output=True) diff --git a/bench_cli/managers/mariadb_manager.py b/bench_cli/managers/mariadb_manager.py index 8fe1e02..262e176 100644 --- a/bench_cli/managers/mariadb_manager.py +++ b/bench_cli/managers/mariadb_manager.py @@ -39,6 +39,43 @@ def _apt_package(self) -> str: return f"mariadb-server-{self.config.version}" return "mariadb-server" + def _grant_host(self) -> str: + """Return the MySQL grant host for CREATE USER / GRANT statements. + + Use '%' when connecting over TCP so that the site DB user can reach + MariaDB from 127.0.0.1 (TCP loopback). Fall back to 'localhost' + only when a unix socket is detected — in MySQL's privilege tables + 'localhost' matches socket connections exclusively. + """ + return "localhost" if self._detect_socket() else "%" + + def create_user(self, username: str, password: str, db_name: str) -> None: + """Create the DB user and grant all privileges with the correct host scope. + + Frappe's new-site always creates the user as @'localhost', which + breaks TCP connections (127.0.0.1 != localhost in privilege tables). + This method uses the grant host derived from _grant_host() so TCP + connections work when no unix socket is available. + """ + grant_host = self._grant_host() + sql = ( + f"CREATE USER IF NOT EXISTS '{username}'@'{grant_host}'" + f" IDENTIFIED BY '{password}';" + f"GRANT ALL PRIVILEGES ON `{db_name}`.* TO '{username}'@'{grant_host}';" + "FLUSH PRIVILEGES;" + ) + mysql_bin = shutil.which("mariadb") or shutil.which("mysql") or "mysql" + cmd = [mysql_bin, f"-u{self.config.admin_user}"] + if self.config.root_password: + cmd.append(f"-p{self.config.root_password}") + socket_path = self._detect_socket() + if socket_path: + cmd.append(f"--socket={socket_path}") + else: + cmd += [f"-h{self.config.host}", f"-P{self.config.port}"] + cmd += ["-e", sql] + run_command(cmd) + def _detect_socket(self) -> str: if self.config.socket_path: return self.config.socket_path diff --git a/tests/test_mariadb_manager.py b/tests/test_mariadb_manager.py new file mode 100644 index 0000000..00b962f --- /dev/null +++ b/tests/test_mariadb_manager.py @@ -0,0 +1,150 @@ +"""Tests for MariaDBManager.create_user() and _grant_host().""" +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, call, patch + +import pytest + +from bench_cli.config.bench_config import BenchConfig +from bench_cli.config.mariadb_config import MariaDBConfig +from bench_cli.config.site_config import SiteConfig +from bench_cli.core.bench import Bench +from bench_cli.core.site import Site +from bench_cli.managers.mariadb_manager import MariaDBManager + + +def make_manager(host: str = "localhost", port: int = 3306, root_password: str = "secret") -> MariaDBManager: + return MariaDBManager(MariaDBConfig(host=host, port=port, root_password=root_password)) + + +# ── _grant_host() ───────────────────────────────────────────────────────────── + + +def test_grant_host_returns_percent_when_no_socket() -> None: + manager = make_manager() + with patch.object(manager, "_detect_socket", return_value=""): + assert manager._grant_host() == "%" + + +def test_grant_host_returns_localhost_when_socket_detected() -> None: + manager = make_manager() + with patch.object(manager, "_detect_socket", return_value="/tmp/mysql.sock"): + assert manager._grant_host() == "localhost" + + +# ── create_user() ───────────────────────────────────────────────────────────── + + +def test_create_user_tcp_uses_percent_host() -> None: + manager = make_manager(host="127.0.0.1", root_password="root") + with patch.object(manager, "_detect_socket", return_value=""): + with patch("shutil.which", side_effect=lambda b: "/usr/bin/mysql" if b == "mysql" else None): + with patch("bench_cli.managers.mariadb_manager.run_command") as mock_run: + manager.create_user("mydb", "mypass", "mydb") + called_cmd = mock_run.call_args[0][0] + assert any("'mydb'@'%'" in part for part in called_cmd) + + +def test_create_user_socket_uses_localhost_host() -> None: + manager = make_manager(root_password="root") + with patch.object(manager, "_detect_socket", return_value="/tmp/mysql.sock"): + with patch("shutil.which", side_effect=lambda b: "/usr/bin/mysql" if b == "mysql" else None): + with patch("bench_cli.managers.mariadb_manager.run_command") as mock_run: + manager.create_user("mydb", "mypass", "mydb") + called_cmd = mock_run.call_args[0][0] + assert any("'mydb'@'localhost'" in part for part in called_cmd) + + +def test_create_user_prefers_mariadb_binary() -> None: + manager = make_manager(root_password="root") + with patch.object(manager, "_detect_socket", return_value=""): + with patch("shutil.which", side_effect=lambda b: "/usr/bin/mariadb" if b == "mariadb" else None): + with patch("bench_cli.managers.mariadb_manager.run_command") as mock_run: + manager.create_user("mydb", "mypass", "mydb") + called_cmd = mock_run.call_args[0][0] + assert called_cmd[0] == "/usr/bin/mariadb" + + +def test_create_user_falls_back_to_mysql_binary() -> None: + manager = make_manager(root_password="root") + with patch.object(manager, "_detect_socket", return_value=""): + with patch("shutil.which", side_effect=lambda b: "/usr/bin/mysql" if b == "mysql" else None): + with patch("bench_cli.managers.mariadb_manager.run_command") as mock_run: + manager.create_user("mydb", "mypass", "mydb") + called_cmd = mock_run.call_args[0][0] + assert called_cmd[0] == "/usr/bin/mysql" + + +def test_create_user_tcp_passes_host_and_port() -> None: + manager = make_manager(host="127.0.0.1", port=3307, root_password="root") + with patch.object(manager, "_detect_socket", return_value=""): + with patch("shutil.which", return_value="/usr/bin/mysql"): + with patch("bench_cli.managers.mariadb_manager.run_command") as mock_run: + manager.create_user("mydb", "mypass", "mydb") + called_cmd = mock_run.call_args[0][0] + assert "-h127.0.0.1" in called_cmd + assert "-P3307" in called_cmd + + +def test_create_user_socket_passes_socket_path() -> None: + manager = make_manager(root_password="root") + with patch.object(manager, "_detect_socket", return_value="/tmp/mysql.sock"): + with patch("shutil.which", return_value="/usr/bin/mysql"): + with patch("bench_cli.managers.mariadb_manager.run_command") as mock_run: + manager.create_user("mydb", "mypass", "mydb") + called_cmd = mock_run.call_args[0][0] + assert "--socket=/tmp/mysql.sock" in called_cmd + + +def test_create_user_includes_root_password() -> None: + manager = make_manager(root_password="s3cr3t") + with patch.object(manager, "_detect_socket", return_value=""): + with patch("shutil.which", return_value="/usr/bin/mysql"): + with patch("bench_cli.managers.mariadb_manager.run_command") as mock_run: + manager.create_user("mydb", "mypass", "mydb") + called_cmd = mock_run.call_args[0][0] + assert "-ps3cr3t" in called_cmd + + +# ── site.create() with --mariadb-user-host-login-scope ────────────────────── + + +def make_bench(tmp_path: Path) -> Bench: + from bench_cli.config.app_config import AppConfig + from bench_cli.config.redis_config import RedisConfig + from bench_cli.config.worker_config import WorkerConfig + config = BenchConfig( + name="test-bench", + python_version="3.14", + apps=[AppConfig(name="frappe", repo="https://github.com/frappe/frappe", branch="version-16")], + mariadb=MariaDBConfig(host="127.0.0.1", port=3306, root_password="root"), + redis=RedisConfig(cache_port=13000, queue_port=11000, socketio_port=12000), + workers=__import__("bench_cli.config.worker_config", fromlist=["WorkerConfig"]).WorkerConfig(default_count=1, short_count=1, long_count=1), + ) + return Bench(config, tmp_path) + + +def test_site_create_passes_host_scope_for_tcp(tmp_path: Path) -> None: + bench = make_bench(tmp_path) + bench.create_directories() + site = Site(SiteConfig(name="site1.localhost", apps=["frappe"], admin_password="admin"), bench) + + with patch("bench_cli.core.site.run_command") as mock_run: + with patch("bench_cli.managers.mariadb_manager.MariaDBManager._detect_socket", return_value=""): + site.create() + called_cmd = mock_run.call_args[0][0] + assert "--mariadb-user-host-login-scope" in called_cmd + assert "%" in called_cmd + + +def test_site_create_skips_host_scope_for_socket(tmp_path: Path) -> None: + bench = make_bench(tmp_path) + bench.create_directories() + site = Site(SiteConfig(name="site1.localhost", apps=["frappe"], admin_password="admin"), bench) + + with patch("bench_cli.core.site.run_command") as mock_run: + with patch("bench_cli.managers.mariadb_manager.MariaDBManager._detect_socket", return_value="/tmp/mysql.sock"): + site.create() + called_cmd = mock_run.call_args[0][0] + assert "--mariadb-user-host-login-scope" not in called_cmd