|
11 | 11 | from datetime import datetime |
12 | 12 | from functools import cached_property |
13 | 13 | from pathlib import Path |
14 | | -from subprocess import PIPE, CalledProcessError, Popen |
| 14 | +from subprocess import PIPE, Popen |
15 | 15 | from time import sleep |
16 | 16 | from types import FrameType |
17 | 17 | from typing import ( |
@@ -933,11 +933,10 @@ def _buffered_sql_enable_extensions(self, dump: gzip.GzipFile | bz2.BZ2File | IO |
933 | 933 | else: |
934 | 934 | self.unaccent() |
935 | 935 |
|
936 | | - self.pg_vector() |
937 | 936 | dump.seek(0) |
938 | 937 |
|
939 | 938 | @ensure_connected |
940 | | - def unaccent(self): |
| 939 | + def unaccent(self) -> bool: |
941 | 940 | """Install the unaccent extension on the database.""" |
942 | 941 | unaccent_queries = [ |
943 | 942 | "CREATE SCHEMA IF NOT EXISTS unaccent_schema", |
@@ -968,21 +967,33 @@ def unaccent(self): |
968 | 967 | return res |
969 | 968 |
|
970 | 969 | @ensure_connected |
971 | | - def pg_trgm(self): |
| 970 | + def pg_trgm(self) -> bool: |
972 | 971 | """Install the pg_trgm extension on the database.""" |
973 | 972 | pg_trgm_query = "CREATE EXTENSION IF NOT EXISTS pg_trgm WITH SCHEMA public" |
974 | 973 | return self.query(pg_trgm_query) |
975 | 974 |
|
976 | | - def pg_vector(self): |
977 | | - """Install the pgvector extension on the database.""" |
978 | | - pgvector_query = "CREATE EXTENSION IF NOT EXISTS vector WITH SCHEMA public" |
| 975 | + def pg_vector(self) -> bool: |
| 976 | + """Install the vector extension on the database. |
| 977 | + Try creating the extension through a regular query first, that will only work if the extension is whitelisted |
| 978 | + in `pgextwlist` or if the user is owner of the extension. If this fails, try enabling it as the `postgres` |
| 979 | + user. |
| 980 | + """ |
| 981 | + pg_vector_query = "CREATE EXTENSION IF NOT EXISTS vector WITH SCHEMA public" |
979 | 982 |
|
980 | 983 | try: |
981 | | - return bash.execute(f"sudo -u postgres psql -d {self.name} -c '{pgvector_query}'") |
982 | | - except CalledProcessError as error: |
983 | | - raise OdevError( |
984 | | - "Failed to install pgvector extension, make sure it is installde on your system first" |
985 | | - ) from error |
| 984 | + self.query(pg_vector_query) |
| 985 | + except RuntimeError: |
| 986 | + link = string.link( |
| 987 | + "pgextwlist", |
| 988 | + "https://github.com/dimitri/pgextwlist?tab=readme-ov-file#postgresql-extension-whitelist", |
| 989 | + ) |
| 990 | + logger.error( |
| 991 | + "Failed to install 'pgvector' extension, please ensure it is installed on your system " |
| 992 | + f"and whitelisted with {link}" |
| 993 | + ) |
| 994 | + return False |
| 995 | + |
| 996 | + return True |
986 | 997 |
|
987 | 998 | @ensure_connected |
988 | 999 | def neuter_filestore(self): |
|
0 commit comments