diff --git a/app_users/models.py b/app_users/models.py index f9744492a..ddd0563dc 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -82,16 +82,14 @@ def _update_user_balance_in_txn(txn: Transaction): # avoid updating twice for same invoice return - with transaction.atomic(): - self.balance += amount - self.save() + obj = self.add_balance_direct(amount) # create invoice entry txn.create( invoice_ref, { "amount": amount, - "end_balance": self.balance, + "end_balance": obj.balance, "timestamp": datetime.datetime.utcnow(), **invoice_items, }, @@ -99,6 +97,13 @@ def _update_user_balance_in_txn(txn: Transaction): _update_user_balance_in_txn(db.get_client().transaction()) + @transaction.atomic + def add_balance_direct(self, amount): + obj: AppUser = self.__class__.objects.select_for_update().get(pk=self.pk) + obj.balance += amount + obj.save(update_fields=["balance"]) + return obj + def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser": # copy data from firebase user self.uid = user.uid diff --git a/bots/migrations/0023_alter_savedrun_workflow.py b/bots/migrations/0023_alter_savedrun_workflow.py new file mode 100644 index 000000000..f49b22be7 --- /dev/null +++ b/bots/migrations/0023_alter_savedrun_workflow.py @@ -0,0 +1,46 @@ +# Generated by Django 4.2.1 on 2023-07-14 11:38 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0022_remove_botintegration_analysis_url_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="savedrun", + name="workflow", + field=models.IntegerField( + choices=[ + (1, "doc-search"), + (2, "doc-summary"), + (3, "google-gpt"), + (4, "video-bots"), + (5, "LipsyncTTS"), + (6, "TextToSpeech"), + (7, "asr"), + (8, "Lipsync"), + (9, "DeforumSD"), + (10, "CompareText2Img"), + (11, "text2audio"), + (12, "Img2Img"), + (13, "FaceInpainting"), + (14, "GoogleImageGen"), + (15, "compare-ai-upscalers"), + (16, "SEOSummary"), + (17, "EmailFaceInpainting"), + (18, "SocialLookupEmail"), + (19, "ObjectInpainting"), + (20, "ImageSegmentation"), + (21, "CompareLLM"), + (22, "ChyronPlant"), + (23, "LetterWriter"), + (24, "SmartGPT"), + (25, "QRCodeGenerator"), + ], + default=4, + ), + ), + ] diff --git a/bots/models.py b/bots/models.py index 60adbe1db..6b70874e1 100644 --- a/bots/models.py +++ b/bots/models.py @@ -50,6 +50,7 @@ class Workflow(models.IntegerChoices): CHYRONPLANT = (22, "ChyronPlant") LETTERWRITER = (23, "LetterWriter") SMARTGPT = (24, "SmartGPT") + QRCODE = (25, "QRCodeGenerator") def get_app_url(self, example_id: str, run_id: str, uid: str): """return the url to the gooey app""" diff --git a/bots/tests.py b/bots/tests.py index fdd2fc667..0bc887389 100644 --- a/bots/tests.py +++ b/bots/tests.py @@ -1,102 +1,78 @@ -from django.test import TestCase +import random + +from app_users.models import AppUser +from daras_ai_v2.functional import map_parallel from .models import ( Message, Conversation, BotIntegration, Platform, - Workflow, ConvoState, ) -from django.db import transaction -from django.contrib import messages CHATML_ROLE_USER = "user" CHATML_ROLE_ASSISSTANT = "assistant" -# python manage.py test - - -class MessageModelTest(TestCase): - - """def test_create_and_save_message(self): - - # Create a new conversation - conversation = Conversation.objects.create() - - # Create and save a new message - message = Message(content="Hello, world!", conversation=conversation) - message.save() - - # Retrieve all messages from the database - all_messages = Message.objects.all() - self.assertEqual(len(all_messages), 1) - - # Check that the message's content is correct - only_message = all_messages[0] - self.assertEqual(only_message, message) - - # Check the content - self.assertEqual(only_message.content, "Hello, world!")""" - - -class BotIntegrationTest(TestCase): - @classmethod - def setUpClass(cls): - super(BotIntegrationTest, cls).setUpClass() - cls.keepdb = True - - @transaction.atomic - def test_create_bot_integration_conversation_message(self): - # Create a new BotIntegration with WhatsApp as the platform - bot_integration = BotIntegration.objects.create( - name="My Bot Integration", - saved_run=None, - billing_account_uid="fdnacsFSBQNKVW8z6tzhBLHKpAm1", # digital green's account id - user_language="en", - show_feedback_buttons=True, - platform=Platform.WHATSAPP, - wa_phone_number="my_whatsapp_number", - wa_phone_number_id="my_whatsapp_number_id", - ) - - # Create a Conversation that uses the BotIntegration - conversation = Conversation.objects.create( - bot_integration=bot_integration, - state=ConvoState.INITIAL, - wa_phone_number="user_whatsapp_number", - ) - - # Create a User Message within the Conversation - message_u = Message.objects.create( - conversation=conversation, - role=CHATML_ROLE_USER, - content="What types of chilies can be grown in Mumbai?", - display_content="What types of chilies can be grown in Mumbai?", - ) - - # Create a Bot Message within the Conversation - message_b = Message.objects.create( - conversation=conversation, - role=CHATML_ROLE_ASSISSTANT, - content="Red, green, and yellow grow the best.", - display_content="Red, green, and yellow grow the best.", - ) - - # Assert that the User Message was created successfully - self.assertEqual(Message.objects.count(), 2) - self.assertEqual(message_u.conversation, conversation) - self.assertEqual(message_u.role, CHATML_ROLE_USER) - self.assertEqual( - message_u.content, "What types of chilies can be grown in Mumbai?" - ) - self.assertEqual( - message_u.display_content, "What types of chilies can be grown in Mumbai?" - ) - # Assert that the Bot Message was created successfully - self.assertEqual(message_b.conversation, conversation) - self.assertEqual(message_b.role, CHATML_ROLE_ASSISSTANT) - self.assertEqual(message_b.content, "Red, green, and yellow grow the best.") - self.assertEqual( - message_b.display_content, "Red, green, and yellow grow the best." - ) +def test_add_balance_direct(): + pk = AppUser.objects.create(balance=0, is_anonymous=False).pk + amounts = [[random.randint(-100, 10_000) for _ in range(100)] for _ in range(5)] + + def worker(amts): + user = AppUser.objects.get(pk=pk) + for amt in amts: + user.add_balance_direct(amt) + + map_parallel(worker, amounts) + + assert AppUser.objects.get(pk=pk).balance == sum(map(sum, amounts)) + + +def test_create_bot_integration_conversation_message(): + # Create a new BotIntegration with WhatsApp as the platform + bot_integration = BotIntegration.objects.create( + name="My Bot Integration", + saved_run=None, + billing_account_uid="fdnacsFSBQNKVW8z6tzhBLHKpAm1", # digital green's account id + user_language="en", + show_feedback_buttons=True, + platform=Platform.WHATSAPP, + wa_phone_number="my_whatsapp_number", + wa_phone_number_id="my_whatsapp_number_id", + ) + + # Create a Conversation that uses the BotIntegration + conversation = Conversation.objects.create( + bot_integration=bot_integration, + state=ConvoState.INITIAL, + wa_phone_number="user_whatsapp_number", + ) + + # Create a User Message within the Conversation + message_u = Message.objects.create( + conversation=conversation, + role=CHATML_ROLE_USER, + content="What types of chilies can be grown in Mumbai?", + display_content="What types of chilies can be grown in Mumbai?", + ) + + # Create a Bot Message within the Conversation + message_b = Message.objects.create( + conversation=conversation, + role=CHATML_ROLE_ASSISSTANT, + content="Red, green, and yellow grow the best.", + display_content="Red, green, and yellow grow the best.", + ) + + # Assert that the User Message was created successfully + assert Message.objects.count() == 2 + assert message_u.conversation == conversation + assert message_u.role == CHATML_ROLE_USER + assert message_u.content == "What types of chilies can be grown in Mumbai?" + assert message_u.display_content == "What types of chilies can be grown in Mumbai?" + + # Assert that the Bot Message was created successfully + assert message_b.conversation == conversation + assert message_b.role == CHATML_ROLE_ASSISSTANT + assert message_b.content == "Red, green, and yellow grow the best." + assert message_b.display_content == "Red, green, and yellow grow the best." diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index b5c7d2006..faf9dd759 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -10,6 +10,7 @@ from celeryapp.celeryconfig import app from daras_ai_v2.base import StateKeys, err_msg_for_exc, BasePage from gooey_ui.pubsub import realtime_push +from gooey_ui.state import set_query_params @app.task @@ -21,6 +22,7 @@ def gui_runner( uid: str, state: dict, channel: str, + query_params: dict = None, ): self = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id))) @@ -29,6 +31,7 @@ def gui_runner( yield_val = None error_msg = None url = self.app_url(run_id=run_id, uid=uid) + set_query_params(query_params or {}) def save(done=False): if done: diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index a4cd887a7..5aa6eaaa7 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -628,6 +628,7 @@ def _render_output_col(self, submitted: bool): uid=uid, state=st.session_state, channel=f"gooey-outputs/{self.doc_name}/{uid}/{run_id}", + query_params=gooey_get_query_params(), ) raise QueryParamsRedirectException( diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 92dcbbcda..47e2c3727 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -30,6 +30,9 @@ else: SECRET_KEY = config("SECRET_KEY") +# https://hashids.org/ +HASHIDS_SALT = config("HASHIDS_SALT", default="") + ALLOWED_HOSTS = ["*"] INTERNAL_IPS = ["127.0.0.1"] SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https") @@ -48,6 +51,7 @@ "django.forms", # needed to override admin forms "django.contrib.admin", "app_users", + "url_shortener", ] MIDDLEWARE = [ diff --git a/pytest.ini b/pytest.ini index b78a83efe..64558d1cf 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,4 @@ [pytest] -addopts = --tb=native -vv -n 8 +addopts = --tb=native -vv -n 8 --disable-warnings DJANGO_SETTINGS_MODULE = daras_ai_v2.settings python_files = tests.py test_*.py *_tests.py \ No newline at end of file diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 3f911e3c7..840641274 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -10,6 +10,8 @@ from pyzbar import pyzbar import gooey_ui as st +from app_users.models import AppUser +from bots.models import Workflow from daras_ai.image_input import ( upload_file_from_bytes, bytes_to_cv2_img, @@ -29,6 +31,7 @@ Img2ImgModels, Schedulers, ) +from url_shortener.models import ShortenedURL ATTEMPTS = 1 @@ -273,19 +276,47 @@ def render_settings(self): def render_output(self): state = st.session_state self._render_outputs(state) - st.caption(f'URL: {state.get("qr_code_data")}') - if state.get("shortened_url", False): - st.caption(f'Shortened: {state.get("shortened_url")}') + + def render_example(self, state: dict): + col1, col2 = st.columns(2) + with col1: + st.markdown( + f""" + ```text + {state.get("text_prompt", "")} + ``` + """ + ) + with col2: + self._render_outputs(state) def _render_outputs(self, state: dict): for img in state.get("output_images", []): st.image(img) + qr_code_data = state.get("qr_code_data") + if not qr_code_data: + continue + shortened_url = state.get("shortened_url") + if not shortened_url: + st.caption(qr_code_data) + continue + hashid = furl(shortened_url.strip("/")).path.segments[-1] + try: + clicks = ShortenedURL.objects.get_by_hashid(hashid).clicks + except ShortenedURL.DoesNotExist: + clicks = None + if clicks is not None: + st.caption(f"{shortened_url} → {qr_code_data} (Views: {clicks})") + else: + st.caption(f"{shortened_url} → {qr_code_data}") def run(self, state: dict) -> typing.Iterator[str | None]: request: QRCodeGeneratorPage.RequestModel = self.RequestModel.parse_obj(state) yield "Generating QR Code..." - image, qr_code_data, did_shorten = generate_and_upload_qr_code(request) + image, qr_code_data, did_shorten = generate_and_upload_qr_code( + request, self.request.user + ) if did_shorten: state["shortened_url"] = qr_code_data state["cleaned_qr_code"] = image @@ -321,22 +352,6 @@ def run(self, state: dict) -> typing.Iterator[str | None]: # 'Doh! That didn\'t work. Sometimes the AI produces bad QR codes. Please press "Regenerate" to try again.' # ) - def render_example(self, state: dict): - col1, col2 = st.columns(2) - with col1: - st.markdown( - f""" - ```text - {state.get("text_prompt", "")} - ``` - """ - ) - st.caption(f'URL: {state.get("qr_code_data")}') - if state.get("shortened_url", False): - st.caption(f'Shortened: {state.get("shortened_url")}') - with col2: - self._render_outputs(state) - def preview_description(self, state: dict) -> str: return "Enter your URL (or text) and an image prompt and we'll generate an arty QR code with your artistic style and content in about 30 seconds. This is a rad way to advertise your website in IRL or print on a poster." @@ -351,7 +366,19 @@ def get_raw_price(self, state: dict) -> int: return total * state.get("num_outputs", 1) -def generate_and_upload_qr_code(request: QRCodeGeneratorPage.RequestModel): +def is_url(url: str) -> bool: + try: + URLValidator(schemes=["http", "https"])(url) + except ValidationError: + return False + else: + return True + + +def generate_and_upload_qr_code( + request: QRCodeGeneratorPage.RequestModel, + user: AppUser, +) -> tuple[str, str, bool]: qr_code_data = request.qr_code_data if not qr_code_data: qr_code_input_image = request.qr_code_input_image @@ -360,9 +387,13 @@ def generate_and_upload_qr_code(request: QRCodeGeneratorPage.RequestModel): qr_code_data = download_qr_code_data(qr_code_input_image) qr_code_data = qr_code_data.strip() - should_shorten = request.use_url_shortener and qr_code_data.startswith("http") - if should_shorten: - qr_code_data, should_shorten = shorten_url(qr_code_data) + shortened = request.use_url_shortener and is_url(qr_code_data) + if shortened: + qr_code_data = ShortenedURL.objects.get_or_create_for_workflow( + url=qr_code_data, + user=user, + workflow=Workflow.QRCODE, + )[0].shortened_url() img_cv2 = generate_qr_code(qr_code_data) @@ -377,7 +408,7 @@ def generate_and_upload_qr_code(request: QRCodeGeneratorPage.RequestModel): ) img_url = upload_file_from_bytes("cleaned_qr.png", cv2_img_to_bytes(img_cv2)) - return img_url, qr_code_data, should_shorten + return img_url, qr_code_data, shortened def generate_qr_code(qr_code_data: str) -> np.ndarray: @@ -386,29 +417,6 @@ def generate_qr_code(qr_code_data: str) -> np.ndarray: return np.array(qr.make_image().convert("RGB")) -def shorten_url(qr_code_data: str) -> tuple[str, bool]: - try: - r = requests.get( - furl( - "https://is.gd/create.php", - query_params={"format": "simple", "url": qr_code_data}, - ).url, - timeout=2.50, - ) - r.raise_for_status() - # Validate that the response is a URL - try: - URLValidator()(r.text) - except ValidationError: - pass - else: - return r.text, True - except requests.RequestException as e: - print(f"ignoring shortened url error: {e}") - pass # We can keep going without the shortened url and just use the original url - return qr_code_data, False - - def download_qr_code_data(url: str) -> str: r = requests.get(url) r.raise_for_status() diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index cfbedaadc..e9c7e098b 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1,5 +1,4 @@ import datetime -import json import os import os.path import re @@ -9,6 +8,7 @@ from pydantic import BaseModel import gooey_ui as st +from bots.models import Workflow from daras_ai.image_input import ( truncate_text_words, ) @@ -52,6 +52,7 @@ ) from recipes.Lipsync import LipsyncPage from recipes.TextToSpeech import TextToSpeechPage +from url_shortener.models import ShortenedURL BOT_SCRIPT_RE = re.compile( # start of line @@ -533,6 +534,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]: references = yield from get_top_k_references( DocSearchPage.RequestModel.parse_obj(state) ) + for reference in references: + reference["url"] = ShortenedURL.objects.get_or_create_for_workflow( + url=reference["url"], + user=self.request.user, + workflow=Workflow.VIDEOBOTS, + )[0].shortened_url() state["references"] = references # if doc search is successful, add the search results to the user prompt if references: diff --git a/server.py b/server.py index 9e4bc63bc..fcc4b2efa 100644 --- a/server.py +++ b/server.py @@ -23,6 +23,7 @@ ) from daras_ai_v2 import settings from routers import billing, facebook, talkjs, api, root +import url_shortener.routers as url_shortener app = FastAPI(title="GOOEY.AI", docs_url=None, redoc_url="/docs") @@ -33,6 +34,7 @@ app.include_router(talkjs.router, include_in_schema=False) app.include_router(facebook.router, include_in_schema=False) app.include_router(root.app, include_in_schema=False) +app.include_router(url_shortener.app, include_in_schema=False) app.add_middleware( CORSMiddleware, diff --git a/url_shortener/__init__.py b/url_shortener/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/url_shortener/admin.py b/url_shortener/admin.py new file mode 100644 index 000000000..8b9a0385a --- /dev/null +++ b/url_shortener/admin.py @@ -0,0 +1,46 @@ +from django.contrib import admin + +from bots.admin import SavedRunAdmin, export_to_csv, export_to_excel +from bots.admin_links import list_related_html_url +from url_shortener import models + + +@admin.register(models.ShortenedURL) +class ShortenedURLAdmin(admin.ModelAdmin): + autocomplete_fields = ["user"] + list_filter = [ + "clicks", + "created_at", + ] + search_fields = ["url", "user"] + [ + f"saved_run__{field}" for field in SavedRunAdmin.search_fields + ] + list_display = [ + "url", + "user", + "shortened_url", + "clicks", + "get_max_clicks", + "disabled", + "created_at", + "updated_at", + ] + readonly_fields = [ + "clicks", + "created_at", + "updated_at", + "shortened_url", + "get_saved_runs", + ] + exclude = ["saved_runs"] + ordering = ["created_at"] + actions = [export_to_csv, export_to_excel] + + @admin.display(ordering="max_clicks", description="Max Clicks") + def get_max_clicks(self, obj): + max_clicks = obj.max_clicks + return max_clicks or "∞" + + @admin.display(description="Saved Runs") + def get_saved_runs(self, obj: models.ShortenedURL): + return list_related_html_url(obj.saved_runs, show_add=False) diff --git a/url_shortener/apps.py b/url_shortener/apps.py new file mode 100644 index 000000000..822d5f54e --- /dev/null +++ b/url_shortener/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class UrlShortenerConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "url_shortener" + verbose_name = "URL Shortner" diff --git a/url_shortener/migrations/0001_initial.py b/url_shortener/migrations/0001_initial.py new file mode 100644 index 000000000..695066608 --- /dev/null +++ b/url_shortener/migrations/0001_initial.py @@ -0,0 +1,77 @@ +# Generated by Django 4.2.1 on 2023-07-14 11:39 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + ("bots", "0023_alter_savedrun_workflow"), + ] + + operations = [ + migrations.CreateModel( + name="ShortenedURL", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("url", models.URLField()), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "clicks", + models.IntegerField( + default=0, help_text="The number of clicks on this url" + ), + ), + ( + "max_clicks", + models.IntegerField( + default=0, + help_text="The maximum number of clicks allowed. Set to 0 for no limit.", + ), + ), + ( + "disabled", + models.BooleanField( + default=False, help_text="Disable this shortened url" + ), + ), + ( + "saved_runs", + models.ManyToManyField( + blank=True, + help_text="The runs that are using this shortened url", + related_name="shortened_urls", + to="bots.savedrun", + ), + ), + ( + "user", + models.ForeignKey( + blank=True, + default=None, + help_text="The user that generated this shortened url", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="shortened_urls", + to="app_users.appuser", + ), + ), + ], + options={ + "verbose_name": "Shortened URL", + "ordering": ("-created_at",), + "get_latest_by": "created_at", + }, + ), + ] diff --git a/url_shortener/migrations/__init__.py b/url_shortener/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/url_shortener/models.py b/url_shortener/models.py new file mode 100644 index 000000000..c1d73442d --- /dev/null +++ b/url_shortener/models.py @@ -0,0 +1,135 @@ +import hashids +import pytz +from django.db import models, transaction, IntegrityError +from furl import furl + +from app_users.models import AppUser +from bots.models import Workflow, SavedRun +from daras_ai_v2 import settings +from daras_ai_v2.query_params import gooey_get_query_params +from daras_ai_v2.query_params_util import extract_query_params + + +class ShortenedURLQuerySet(models.QuerySet): + def get_or_create_for_workflow( + self, *, user: AppUser, url: str, workflow: Workflow + ) -> tuple["ShortenedURL", bool]: + surl, created = self.filter_first_or_create(url=url, user=user) + example_id, run_id, uid = extract_query_params( + gooey_get_query_params(), default="" + ) + surl.saved_runs.add( + SavedRun.objects.get_or_create( + workflow=workflow, example_id=example_id, run_id=run_id, uid=uid + )[0], + ) + return surl, created + + def filter_first_or_create(self, defaults=None, **kwargs): + """ + Look up an object with the given kwargs, creating one if necessary. + Return a tuple of (object, created), where created is a boolean + specifying whether an object was created. + """ + # The get() needs to be targeted at the write database in order + # to avoid potential transaction consistency problems. + self._for_write = True + try: + return self.filter(**kwargs)[0], False + except IndexError: + params = self._extract_model_params(defaults, **kwargs) + # Try to create an object using passed params. + try: + with transaction.atomic(using=self.db): + # params = dict(resolve_callables(params)) + return self.create(**params), True + except IntegrityError: + try: + return self.filter(**kwargs)[0], False + except IndexError: + pass + raise + + def get_by_hashid(self, hashid: str) -> "ShortenedURL": + try: + obj_id = _hashids.decode(hashid)[0] + except IndexError as e: + raise self.model.DoesNotExist from e + else: + return self.get(id=obj_id) + + def to_df(self, tz=pytz.timezone(settings.TIME_ZONE)) -> "pd.DataFrame": + import pandas as pd + + qs = self.all().prefetch_related("saved_run") + rows = [] + for surl in qs[:1000]: + surl: ShortenedURL + rows.append( + { + "ID": surl.id, + "URL": surl.url, + "SHORTENED_URL": surl.shortened_url(), + "CREATED_AT": surl.created_at.astimezone(tz).replace(tzinfo=None), + "UPDATED_AT": surl.updated_at.astimezone(tz).replace(tzinfo=None), + "SAVED_RUN": str(surl.saved_run), + "CLICKS": surl.clicks, + "MAX_CLICKS": surl.max_clicks, + "DISABLED": surl.disabled, + } + ) + df = pd.DataFrame.from_records(rows) + return df + + +_hashids = hashids.Hashids(salt=settings.HASHIDS_SALT) + + +class ShortenedURL(models.Model): + url = models.URLField() + + user = models.ForeignKey( + "app_users.AppUser", + on_delete=models.SET_NULL, + related_name="shortened_urls", + null=True, + blank=True, + default=None, + help_text="The user that generated this shortened url", + ) + + saved_runs = models.ManyToManyField( + "bots.SavedRun", + related_name="shortened_urls", + blank=True, + help_text="The runs that are using this shortened url", + ) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + clicks = models.IntegerField( + default=0, help_text="The number of clicks on this url" + ) + max_clicks = models.IntegerField( + default=0, + help_text="The maximum number of clicks allowed. Set to 0 for no limit.", + ) + disabled = models.BooleanField( + default=False, help_text="Disable this shortened url" + ) + + objects = ShortenedURLQuerySet.as_manager() + + def shortened_url(self) -> str: + return str(furl(settings.APP_BASE_URL) / "2" / _hashids.encode(self.id)) + + shortened_url.short_description = "Shortened URL" + + class Meta: + ordering = ("-created_at",) + get_latest_by = "created_at" + verbose_name = "Shortened URL" + + def __str__(self): + return self.shortened_url() + " -> " + self.url diff --git a/url_shortener/routers.py b/url_shortener/routers.py new file mode 100644 index 000000000..e9872c9e5 --- /dev/null +++ b/url_shortener/routers.py @@ -0,0 +1,25 @@ +from django.db.models import F +from fastapi import APIRouter +from fastapi.responses import RedirectResponse +from fastapi.responses import Response + +from url_shortener.models import ShortenedURL + +app = APIRouter() + + +@app.api_route("/2/{hashid}", methods=["GET", "POST"]) +@app.api_route("/2/{hashid}/", methods=["GET", "POST"]) +def url_shortener(hashid: str): + try: + surl = ShortenedURL.objects.get_by_hashid(hashid) + except ShortenedURL.DoesNotExist: + return Response(status_code=404) + # ensure that the url is not disabled and has not exceeded max clicks + if surl.disabled or (surl.max_clicks and surl.clicks >= surl.max_clicks): + return Response(status_code=410, content="This link has expired") + # increment the click count + ShortenedURL.objects.filter(id=surl.id).update(clicks=F("clicks") + 1) + return RedirectResponse( + url=surl.url, status_code=303 # because youtu.be redirects are 303 + ) diff --git a/url_shortener/tests.py b/url_shortener/tests.py new file mode 100644 index 000000000..5e028329a --- /dev/null +++ b/url_shortener/tests.py @@ -0,0 +1,57 @@ +from starlette.testclient import TestClient + +from daras_ai_v2.functional import map_parallel, flatmap_parallel +from server import app +from url_shortener.models import ShortenedURL + +TEST_URL = "https://www.google.com" + +client = TestClient(app) + + +def test_url_shortener(): + surl = ShortenedURL.objects.create(url=TEST_URL) + short_url = surl.shortened_url() + r = client.get(short_url, allow_redirects=False) + assert r.is_redirect and r.headers["location"] == TEST_URL + + +def test_url_shortener_max_clicks(): + surl = ShortenedURL.objects.create(url=TEST_URL, max_clicks=5) + short_url = surl.shortened_url() + for _ in range(5): + r = client.get(short_url, allow_redirects=False) + assert r.is_redirect and r.headers["location"] == TEST_URL + r = client.get(short_url, allow_redirects=False) + assert r.status_code == 410 + + +def test_url_shortener_disabled(): + surl = ShortenedURL.objects.create(url=TEST_URL, disabled=True) + short_url = surl.shortened_url() + r = client.get(short_url, allow_redirects=False) + assert r.status_code == 410 + + +def test_url_shortener_create_atomic(): + def create(_): + return [ + ShortenedURL.objects.create(url=TEST_URL).shortened_url() + for _ in range(100) + ] + + assert len(set(flatmap_parallel(create, range(5)))) == 500 + + +def test_url_shortener_clicks_decrement_atomic(): + surl = ShortenedURL.objects.create(url=TEST_URL) + short_url = surl.shortened_url() + + def make_clicks(_): + for _ in range(100): + r = client.get(short_url, allow_redirects=False) + assert r.is_redirect and r.headers["location"] == TEST_URL + + map_parallel(make_clicks, range(5)) + + assert ShortenedURL.objects.get(pk=surl.pk).clicks == 500