diff --git a/.github/workflows/mongodb_settings.py b/.github/workflows/mongodb_settings.py index f33d458a..bdcc696c 100644 --- a/.github/workflows/mongodb_settings.py +++ b/.github/workflows/mongodb_settings.py @@ -3,7 +3,7 @@ from django_mongodb_backend import parse_uri if mongodb_uri := os.getenv("MONGODB_URI"): - db_settings = parse_uri(mongodb_uri) + db_settings = parse_uri(mongodb_uri, db_name="dummy") # Workaround for https://github.com/mongodb-labs/mongo-orchestration/issues/268 if db_settings["USER"] and db_settings["PASSWORD"]: diff --git a/README.md b/README.md index 5ccb2fc5..3449b905 100644 --- a/README.md +++ b/README.md @@ -46,24 +46,17 @@ $ django-admin startproject example --template https://github.com/mongodb-labs/d ### Connect to the database -Navigate to your `example/settings.py` file and find the variable named -`DATABASES` Replace the `DATABASES` variable with this: +Navigate to your `example/settings.py` file and replace the `DATABASES` +setting like so: ```python DATABASES = { - "default": django_mongodb_backend.parse_uri(""), + "default": django_mongodb_backend.parse_uri( + "", db_name="example" + ), } ``` -The MongoDB `` must also specify a database for the -`parse_uri` function to work. -If not already included, make sure you provide a value for `` -in your URI as shown in the example below: -```bash -mongodb+srv://myDatabaseUser:D1fficultP%40ssw0rd@cluster0.example.mongodb.net/?retryWrites=true&w=majority -``` - - ### Run the server To verify that you installed Django MongoDB Backend and correctly configured your project, run the following command from your project root: ```bash diff --git a/django_mongodb_backend/utils.py b/django_mongodb_backend/utils.py index c389d93b..95e3b5a8 100644 --- a/django_mongodb_backend/utils.py +++ b/django_mongodb_backend/utils.py @@ -28,7 +28,7 @@ def check_django_compatability(): ) -def parse_uri(uri, conn_max_age=0, test=None): +def parse_uri(uri, *, db_name=None, conn_max_age=0, test=None): """ Convert the given uri into a dictionary suitable for Django's DATABASES setting. @@ -45,9 +45,12 @@ def parse_uri(uri, conn_max_age=0, test=None): host, port = nodelist[0] elif len(nodelist) > 1: host = ",".join([f"{host}:{port}" for host, port in nodelist]) + db_name = db_name or uri["database"] + if not db_name: + raise ImproperlyConfigured("You must provide the db_name parameter.") settings_dict = { "ENGINE": "django_mongodb_backend", - "NAME": uri["database"], + "NAME": db_name, "HOST": host, "PORT": port, "USER": uri.get("username"), @@ -55,6 +58,8 @@ def parse_uri(uri, conn_max_age=0, test=None): "OPTIONS": uri.get("options"), "CONN_MAX_AGE": conn_max_age, } + if "authSource" not in settings_dict["OPTIONS"] and uri["database"]: + settings_dict["OPTIONS"]["authSource"] = uri["database"] if test: settings_dict["TEST"] = test return settings_dict diff --git a/docs/source/ref/utils.rst b/docs/source/ref/utils.rst index 16a61165..a5fb8ff3 100644 --- a/docs/source/ref/utils.rst +++ b/docs/source/ref/utils.rst @@ -12,7 +12,7 @@ following parts can be considered stable. ``parse_uri()`` =============== -.. function:: parse_uri(uri, conn_max_age=0, test=None) +.. function:: parse_uri(uri, db_name=None, conn_max_age=0, test=None) Parses a MongoDB `connection string`_ into a dictionary suitable for Django's :setting:`DATABASES` setting. @@ -23,8 +23,11 @@ Example:: import django_mongodb_backend - MONGODB_URI = "mongodb+srv://my_user:my_password@cluster0.example.mongodb.net/myDatabase?retryWrites=true&w=majority&tls=false" - DATABASES["default"] = django_mongodb_backend.parse_uri(MONGODB_URI) + MONGODB_URI = "mongodb+srv://my_user:my_password@cluster0.example.mongodb.net/defaultauthdb?retryWrites=true&w=majority&tls=false" + DATABASES["default"] = django_mongodb_backend.parse_uri(MONGODB_URI, db_name="example") + +You must specify ``db_name`` (the :setting:`NAME` of your database) if the URI +doesn't specify ``defaultauthdb``. You can use the parameters to customize the resulting :setting:`DATABASES` setting: diff --git a/tests/backend_/utils/test_parse_uri.py b/tests/backend_/utils/test_parse_uri.py index c4d475f1..a2898359 100644 --- a/tests/backend_/utils/test_parse_uri.py +++ b/tests/backend_/utils/test_parse_uri.py @@ -1,6 +1,7 @@ from unittest.mock import patch import pymongo +from django.core.exceptions import ImproperlyConfigured from django.test import SimpleTestCase from django_mongodb_backend import parse_uri @@ -12,11 +13,28 @@ def test_simple_uri(self): self.assertEqual(settings_dict["ENGINE"], "django_mongodb_backend") self.assertEqual(settings_dict["NAME"], "myDatabase") self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net") + self.assertEqual(settings_dict["OPTIONS"], {"authSource": "myDatabase"}) - def test_no_database(self): - settings_dict = parse_uri("mongodb://cluster0.example.mongodb.net") - self.assertIsNone(settings_dict["NAME"]) + def test_db_name(self): + settings_dict = parse_uri("mongodb://cluster0.example.mongodb.net/", db_name="myDatabase") + self.assertEqual(settings_dict["ENGINE"], "django_mongodb_backend") + self.assertEqual(settings_dict["NAME"], "myDatabase") + self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net") + self.assertEqual(settings_dict["OPTIONS"], {}) + + def test_db_name_overrides_default_auth_db(self): + settings_dict = parse_uri( + "mongodb://cluster0.example.mongodb.net/default_auth_db", db_name="myDatabase" + ) + self.assertEqual(settings_dict["ENGINE"], "django_mongodb_backend") + self.assertEqual(settings_dict["NAME"], "myDatabase") self.assertEqual(settings_dict["HOST"], "cluster0.example.mongodb.net") + self.assertEqual(settings_dict["OPTIONS"], {"authSource": "default_auth_db"}) + + def test_no_database(self): + msg = "You must provide the db_name parameter." + with self.assertRaisesMessage(ImproperlyConfigured, msg): + parse_uri("mongodb://cluster0.example.mongodb.net") def test_srv_uri_with_options(self): uri = "mongodb+srv://my_user:my_password@cluster0.example.mongodb.net/my_database?retryWrites=true&w=majority" @@ -30,35 +48,46 @@ def test_srv_uri_with_options(self): self.assertEqual(settings_dict["PASSWORD"], "my_password") self.assertIsNone(settings_dict["PORT"]) self.assertEqual( - settings_dict["OPTIONS"], {"retryWrites": True, "w": "majority", "tls": True} + settings_dict["OPTIONS"], + {"authSource": "my_database", "retryWrites": True, "w": "majority", "tls": True}, ) def test_localhost(self): - settings_dict = parse_uri("mongodb://localhost") + settings_dict = parse_uri("mongodb://localhost/db") self.assertEqual(settings_dict["HOST"], "localhost") self.assertEqual(settings_dict["PORT"], 27017) def test_localhost_with_port(self): - settings_dict = parse_uri("mongodb://localhost:27018") + settings_dict = parse_uri("mongodb://localhost:27018/db") self.assertEqual(settings_dict["HOST"], "localhost") self.assertEqual(settings_dict["PORT"], 27018) def test_hosts_with_ports(self): - settings_dict = parse_uri("mongodb://localhost:27017,localhost:27018") + settings_dict = parse_uri("mongodb://localhost:27017,localhost:27018/db") self.assertEqual(settings_dict["HOST"], "localhost:27017,localhost:27018") self.assertEqual(settings_dict["PORT"], None) def test_hosts_without_ports(self): - settings_dict = parse_uri("mongodb://host1.net,host2.net") + settings_dict = parse_uri("mongodb://host1.net,host2.net/db") self.assertEqual(settings_dict["HOST"], "host1.net:27017,host2.net:27017") self.assertEqual(settings_dict["PORT"], None) + def test_auth_source_in_query_string(self): + settings_dict = parse_uri("mongodb://localhost/?authSource=auth", db_name="db") + self.assertEqual(settings_dict["NAME"], "db") + self.assertEqual(settings_dict["OPTIONS"], {"authSource": "auth"}) + + def test_auth_source_in_query_string_overrides_defaultauthdb(self): + settings_dict = parse_uri("mongodb://localhost/db?authSource=auth") + self.assertEqual(settings_dict["NAME"], "db") + self.assertEqual(settings_dict["OPTIONS"], {"authSource": "auth"}) + def test_conn_max_age(self): - settings_dict = parse_uri("mongodb://localhost", conn_max_age=600) + settings_dict = parse_uri("mongodb://localhost/db", conn_max_age=600) self.assertEqual(settings_dict["CONN_MAX_AGE"], 600) def test_test_kwarg(self): - settings_dict = parse_uri("mongodb://localhost", test={"NAME": "test_db"}) + settings_dict = parse_uri("mongodb://localhost/db", test={"NAME": "test_db"}) self.assertEqual(settings_dict["TEST"], {"NAME": "test_db"}) def test_invalid_credentials(self):