Skip to content

Commit

Permalink
Merge pull request #34871 from dimagi/cs/SC-3771-ds-unsubscribe-endpoint
Browse files Browse the repository at this point in the history
Add unsubscribe data source endpoint
  • Loading branch information
Charl1996 authored Jul 16, 2024
2 parents cb08c50 + bdb1679 commit d85707c
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 0 deletions.
142 changes: 142 additions & 0 deletions corehq/apps/userreports/tests/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,3 +716,145 @@ def test_subscribe_unsuccessful_with_missing_params(self):
request.content.decode("utf-8"),
"Missing parameters: client_id, token_url",
)


class TestUnsubscribeFromDataSource(TestCase):

DOMAIN = "test-domain"
CLIENT_ID = "client_id"
USERNAME = "username"

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.project = Domain.get_or_create_with_name(cls.DOMAIN, is_active=True)

cls.api_user_role = UserRole.create(
cls.DOMAIN, 'api-user', permissions=HqPermissions(access_api=True, view_reports=True)
)
cls.user = WebUser.create(cls.DOMAIN, cls.USERNAME, "password", None, None,
role_id=cls.api_user_role.get_id)
cls.api_key, _ = HQApiKey.objects.get_or_create(user=WebUser.get_django_user(cls.user))
cls.domain_api_key, _ = HQApiKey.objects.get_or_create(user=WebUser.get_django_user(cls.user),
name='domain-scoped',
domain=cls.DOMAIN)

@classmethod
def tearDownClass(cls):
cls.user.delete(deleted_by_domain=cls.DOMAIN, deleted_by=None)
cls.project.delete()
super().tearDownClass()

def _construct_api_auth_header(self, api_key):
return f'ApiKey {self.USERNAME}:{api_key.key}'

def _post_request(self, domain, data_source_id, data=None, **extras):
path = reverse("unsubscribe_from_configurable_data_source", args=(domain, data_source_id,))
return self.client.post(path, data=data, **extras)

def _subscribe_to_datasource(self, datasource_id):
conn_settings, __ = ConnectionSettings.objects.update_or_create(
client_id=self.CLIENT_ID,
defaults={
'domain': self.DOMAIN,
'name': "testy",
'auth_type': "oauth2_client",
'client_secret': 'client_secret',
'url': "",
'token_url': 'token_url',
}
)
DataSourceRepeater.objects.create(
name=f"{datasource_id} name",
domain=self.DOMAIN,
data_source_id=datasource_id,
connection_settings_id=conn_settings.id,
)

@flag_enabled('SUPERSET_ANALYTICS')
@flag_enabled('API_THROTTLE_WHITELIST')
def test_basic_unsubscribe_successful(self):
data_source_id = "data_source_id"
self._subscribe_to_datasource(data_source_id)

conn_settings = ConnectionSettings.objects.get(client_id=self.CLIENT_ID)
connection_settings_id = conn_settings.id

repeaters = DataSourceRepeater.objects.filter(connection_settings_id=connection_settings_id)
self.assertEqual(repeaters.count(), 1)

response = self._post_request(
domain=self.DOMAIN,
data={"client_id": self.CLIENT_ID},
data_source_id=data_source_id,
HTTP_AUTHORIZATION=self._construct_api_auth_header(self.domain_api_key),
)
self.assertEqual(response.status_code, 200)

repeaters = DataSourceRepeater.objects.filter(connection_settings_id=connection_settings_id)
self.assertEqual(repeaters.count(), 0)
self.assertEqual(ConnectionSettings.objects.filter(id=connection_settings_id).count(), 0)

@flag_enabled('SUPERSET_ANALYTICS')
@flag_enabled('API_THROTTLE_WHITELIST')
def test_unsubscribe_when_multiple_repeaters(self):
data_source_id_1 = "data_source_id1"
data_source_id_2 = "data_source_id2"
self._subscribe_to_datasource(data_source_id_1)
self._subscribe_to_datasource(data_source_id_2)

conn_settings = ConnectionSettings.objects.get(client_id=self.CLIENT_ID)
connection_settings_id = conn_settings.id

repeaters = DataSourceRepeater.objects.filter(connection_settings_id=connection_settings_id)
self.assertEqual(repeaters.count(), 2)

response = self._post_request(
domain=self.DOMAIN,
data={"client_id": self.CLIENT_ID},
data_source_id=data_source_id_1,
HTTP_AUTHORIZATION=self._construct_api_auth_header(self.domain_api_key),
)
self.assertEqual(response.status_code, 200)

repeaters = DataSourceRepeater.objects.filter(connection_settings_id=connection_settings_id)
self.assertEqual(repeaters.count(), 1)
self.assertEqual(ConnectionSettings.objects.filter(id=connection_settings_id).count(), 1)

@flag_enabled('SUPERSET_ANALYTICS')
@flag_enabled('API_THROTTLE_WHITELIST')
def test_missing_client_id(self):
response = self._post_request(
domain=self.DOMAIN,
data={},
data_source_id='datasource_id',
HTTP_AUTHORIZATION=self._construct_api_auth_header(self.domain_api_key),
)
self.assertEqual(response.status_code, 422)
self.assertEqual(response.content.decode("utf-8"), "The client_id parameter is required")

@flag_enabled('SUPERSET_ANALYTICS')
@flag_enabled('API_THROTTLE_WHITELIST')
def test_invalid_client_id(self):
response = self._post_request(
domain=self.DOMAIN,
data={'client_id': 'client_id'},
data_source_id='datasource_id',
HTTP_AUTHORIZATION=self._construct_api_auth_header(self.domain_api_key),
)
self.assertEqual(response.status_code, 422)
self.assertEqual(response.content.decode("utf-8"), "Invalid client_id")

@flag_enabled('SUPERSET_ANALYTICS')
@flag_enabled('API_THROTTLE_WHITELIST')
def test_invalid_data_source_id(self):
self._subscribe_to_datasource('datasource_id')

response = self._post_request(
domain=self.DOMAIN,
data={'client_id': 'client_id'},
data_source_id='invalid_datasource_id',
HTTP_AUTHORIZATION=self._construct_api_auth_header(self.domain_api_key),
)
self.assertEqual(response.status_code, 422)
self.assertEqual(response.content.decode("utf-8"), "Invalid data source ID")
3 changes: 3 additions & 0 deletions corehq/apps/userreports/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
undelete_report,
update_report_description,
subscribe_to_data_source_changes,
unsubscribe_from_data_source,
)

urlpatterns = [
Expand Down Expand Up @@ -72,6 +73,8 @@
name='export_configurable_data_source'),
url(r'^data_sources/subscribe/(?P<config_id>[\w-]+)/$', subscribe_to_data_source_changes,
name='subscribe_to_configurable_data_source'),
url(r'^data_sources/unsubscribe/(?P<config_id>[\w-]+)/$', unsubscribe_from_data_source,
name='unsubscribe_from_configurable_data_source'),
url(r'^expression_debugger/$', ExpressionDebuggerView.as_view(),
name='expression_debugger'),
url(r'^data_source_debugger/$', DataSourceDebuggerView.as_view(),
Expand Down
41 changes: 41 additions & 0 deletions corehq/apps/userreports/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,47 @@ def subscribe_to_data_source_changes(request, domain, config_id):
return HttpResponse(status=201)


@csrf_exempt
@require_POST
@api_auth()
@require_permission(HqPermissions.view_reports)
@toggles.SUPERSET_ANALYTICS.required_decorator()
@api_throttle
def unsubscribe_from_data_source(request, domain, config_id):
if 'client_id' not in request.POST:
return HttpResponse(
status=422,
content="The client_id parameter is required",
)
client_id = request.POST['client_id']

try:
conn_settings = ConnectionSettings.objects.get(client_id=client_id)
except ConnectionSettings.DoesNotExist:
return HttpResponse(
status=422,
content="Invalid client_id"
)

repeater = DataSourceRepeater.objects.filter(
domain=domain,
connection_settings_id=conn_settings.id,
options={"data_source_id": config_id},
)
if not repeater.exists():
return HttpResponse(
status=422,
content="Invalid data source ID"
)
repeater.delete()
conn_settings.clear_caches()

if not conn_settings.used_by:
conn_settings.delete()

return HttpResponse(status=200)


def _get_report_filter(domain, report_id, filter_id):
report = get_report_config_or_404(report_id, domain)[0]
report_filter = report.get_ui_filter(filter_id)
Expand Down
6 changes: 6 additions & 0 deletions corehq/motech/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,12 @@ def soft_delete(self):
self.is_deleted = True
self.save()

def clear_caches(self):
try:
del self.used_by
except AttributeError:
pass


class RequestLog(models.Model):
"""
Expand Down

0 comments on commit d85707c

Please sign in to comment.