Skip to content

Commit

Permalink
Merge pull request #452 from xchem/black
Browse files Browse the repository at this point in the history
Ran black on the b/e repo
  • Loading branch information
alanbchristie authored Nov 30, 2023
2 parents da7e5a7 + b56255e commit a6f6453
Show file tree
Hide file tree
Showing 83 changed files with 5,823 additions and 2,332 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ repos:
# - types-pytz
# - types-requests

# Black (uncompromising) Python code formatter
- repo: https://github.com/psf/black
rev: 23.11.0
hooks:
- id: black
args:
- --skip-string-normalization
- --target-version
- py311

# Pylint
# To check import errors we need to install every package
# used by the DM. This is often impractical on the client,
Expand Down
88 changes: 53 additions & 35 deletions api/remote_ispyb_connector.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,72 @@
import threading
import mysql.connector
from ispyb.connector.mysqlsp.main import ISPyBMySQLSPConnector as Connector
from ispyb.exception import (ISPyBConnectionException, ISPyBNoResultException,
ISPyBRetrieveFailed, ISPyBWriteFailed)
from ispyb.exception import (
ISPyBConnectionException,
ISPyBNoResultException,
ISPyBRetrieveFailed,
ISPyBWriteFailed,
)
import sshtunnel
import time
import pymysql


class SSHConnector(Connector):
def __init__(self,
user=None,
pw=None,
host="localhost",
db=None,
port=3306,
reconn_attempts=6,
reconn_delay=1,
remote=False,
ssh_user=None,
ssh_password=None,
ssh_host=None,
conn_inactivity=360,
):
def __init__(
self,
user=None,
pw=None,
host="localhost",
db=None,
port=3306,
reconn_attempts=6,
reconn_delay=1,
remote=False,
ssh_user=None,
ssh_password=None,
ssh_host=None,
conn_inactivity=360,
):
self.conn_inactivity = conn_inactivity
self.lock = threading.Lock()
self.server = None

if remote:
creds = {'ssh_host': ssh_host,
'ssh_user': ssh_user,
'ssh_pass': ssh_password,
'db_host': host,
'db_port': int(port),
'db_user': user,
'db_pass': pw,
'db_name': db}
creds = {
'ssh_host': ssh_host,
'ssh_user': ssh_user,
'ssh_pass': ssh_password,
'db_host': host,
'db_port': int(port),
'db_user': user,
'db_pass': pw,
'db_name': db,
}
self.remote_connect(**creds)

else:
self.connect(user=user, pw=pw, host=host, db=db, port=port, conn_inactivity=conn_inactivity)

def remote_connect(self, ssh_host, ssh_user, ssh_pass, db_host, db_port, db_user, db_pass, db_name):

self.connect(
user=user,
pw=pw,
host=host,
db=db,
port=port,
conn_inactivity=conn_inactivity,
)

def remote_connect(
self, ssh_host, ssh_user, ssh_pass, db_host, db_port, db_user, db_pass, db_name
):
sshtunnel.SSH_TIMEOUT = 10.0
sshtunnel.TUNNEL_TIMEOUT = 10.0
self.conn_inactivity = int(self.conn_inactivity)

self.server = sshtunnel.SSHTunnelForwarder(
(ssh_host),
ssh_username=ssh_user, ssh_password=ssh_pass,
remote_bind_address=(db_host, db_port)
ssh_username=ssh_user,
ssh_password=ssh_pass,
remote_bind_address=(db_host, db_port),
)

# stops hanging connections in transport
Expand All @@ -60,8 +76,10 @@ def remote_connect(self, ssh_host, ssh_user, ssh_pass, db_host, db_port, db_user
self.server.start()

self.conn = pymysql.connect(
user=db_user, password=db_pass,
host='127.0.0.1', port=self.server.local_bind_port,
user=db_user,
password=db_pass,
host='127.0.0.1',
port=self.server.local_bind_port,
database=db_name,
)

Expand Down Expand Up @@ -91,13 +109,13 @@ def call_sp_retrieve(self, procname, args):
try:
cursor.callproc(procname=procname, args=args)
except DataError as e:
raise ISPyBRetrieveFailed("DataError({0}): {1}".format(e.errno, traceback.format_exc()))
raise ISPyBRetrieveFailed(
"DataError({0}): {1}".format(e.errno, traceback.format_exc())
)

result = cursor.fetchall()

cursor.close()
if result == []:
raise ISPyBNoResultException
return result


68 changes: 42 additions & 26 deletions api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@


def get_remote_conn():

ispyb_credentials = {
"user": os.environ.get("ISPYB_USER"),
"pw": os.environ.get("ISPYB_PASSWORD"),
Expand All @@ -52,7 +51,7 @@ def get_remote_conn():
'ssh_host': os.environ.get("SSH_HOST"),
'ssh_user': os.environ.get("SSH_USER"),
'ssh_password': os.environ.get("SSH_PASSWORD"),
'remote': True
'remote': True,
}

ispyb_credentials.update(**ssh_credentials)
Expand Down Expand Up @@ -102,7 +101,6 @@ def get_conn():


class ISpyBSafeQuerySet(viewsets.ReadOnlyModelViewSet):

def get_queryset(self):
"""
Optionally restricts the returned purchases to a given proposals
Expand All @@ -115,8 +113,11 @@ def get_queryset(self):
if open_proposal not in proposal_list:
proposal_list.append(open_proposal)

logger.debug('is_authenticated=%s, proposal_list=%s',
self.request.user.is_authenticated, proposal_list)
logger.debug(
'is_authenticated=%s, proposal_list=%s',
self.request.user.is_authenticated,
proposal_list,
)

# Must have a foreign key to a Project for this filter to work.
# get_q_filter() returns a Q expression for filtering
Expand All @@ -130,7 +131,7 @@ def get_open_proposals(self):
if os.environ.get("TEST_SECURITY_FLAG", False):
return ["lb00000"]
else:
# All of well-known (built-in) public Projects (Proposals/Visits)
# All of well-known (built-in) public Projects (Proposals/Visits)
return ["lb27156"]

def get_proposals_for_user_from_django(self, user):
Expand All @@ -142,8 +143,12 @@ def get_proposals_for_user_from_django(self, user):
prop_ids = list(
Project.objects.filter(user_id=user.pk).values_list("title", flat=True)
)
logger.debug("Got %s proposals for user %s: %s",
len(prop_ids), user.username, prop_ids)
logger.debug(
"Got %s proposals for user %s: %s",
len(prop_ids),
user.username,
prop_ids,
)
return prop_ids

def needs_updating(self, user):
Expand All @@ -163,7 +168,6 @@ def needs_updating(self, user):
return False

def run_query_with_connector(self, conn, user):

core = conn.core
try:
rs = core.retrieve_sessions_for_person_login(user.username)
Expand Down Expand Up @@ -242,22 +246,30 @@ def get_proposals_for_user_from_ispyb(self, user):

# Always display the collected results for the user.
# These will be cached.
logger.info("Got %s proposals from %s records for user %s: %s",
len(prop_id_set), len(rs), user.username, prop_id_set)
logger.info(
"Got %s proposals from %s records for user %s: %s",
len(prop_id_set),
len(rs),
user.username,
prop_id_set,
)

# Cache the result and return the result for the user
USER_LIST_DICT[user.username]["RESULTS"] = list(prop_id_set)
return USER_LIST_DICT[user.username]["RESULTS"]
else:
# Return the previous query (cached for an hour)
cached_prop_ids = USER_LIST_DICT[user.username]["RESULTS"]
logger.info("Got %s cached proposals for user %s: %s",
len(cached_prop_ids), user.username, cached_prop_ids)
logger.info(
"Got %s cached proposals for user %s: %s",
len(cached_prop_ids),
user.username,
cached_prop_ids,
)
return cached_prop_ids

def get_proposals_for_user(self, user):
"""Returns a list of proposals (public and private) that the user has access to.
"""
"""Returns a list of proposals (public and private) that the user has access to."""
assert user

ispyb_user = os.environ.get("ISPYB_USER")
Expand All @@ -267,21 +279,23 @@ def get_proposals_for_user(self, user):
logger.info("Getting proposals from ISPyB...")
return self.get_proposals_for_user_from_ispyb(user)
else:
logger.info("No proposals (user %s is not authenticated)", user.username)
logger.info(
"No proposals (user %s is not authenticated)", user.username
)
return []
else:
logger.info("Getting proposals from Django...")
return self.get_proposals_for_user_from_django(user)

def get_q_filter(self, proposal_list):
"""Returns a Q expression representing a (potentially complex) table filter.
"""
"""Returns a Q expression representing a (potentially complex) table filter."""
if self.filter_permissions:
# Q-filter is based on the filter_permissions string
# whether the resultant Project title in the proposal list
# whether the resultant Project title in the proposal list
# OR where the Project is 'open_to_public'
return Q(**{self.filter_permissions + "__title__in": proposal_list}) |\
Q(**{self.filter_permissions + "__open_to_public": True})
return Q(**{self.filter_permissions + "__title__in": proposal_list}) | Q(
**{self.filter_permissions + "__open_to_public": True}
)
else:
# No filter permission?
# Assume this QuerySet is used for the Project model.
Expand All @@ -293,7 +307,6 @@ def get_q_filter(self, proposal_list):


class ISpyBSafeStaticFiles:

def get_queryset(self):
query = ISpyBSafeQuerySet()
query.request = self.request
Expand All @@ -318,12 +331,16 @@ def get_response(self):
logger.info("Path to pass to nginx: %s", self.prefix + file_name)

if hasattr(self, 'file_format'):
if self.file_format=='raw':
if self.file_format == 'raw':
file_field = getattr(object, self.field_name)
filepath = file_field.path
zip_file = open(filepath, 'rb')
response = HttpResponse(FileWrapper(zip_file), content_type='application/zip')
response['Content-Disposition'] = 'attachment; filename="%s"' % file_name
response = HttpResponse(
FileWrapper(zip_file), content_type='application/zip'
)
response['Content-Disposition'] = (
'attachment; filename="%s"' % file_name
)

else:
response = HttpResponse()
Expand All @@ -338,7 +355,6 @@ def get_response(self):


class ISpyBSafeStaticFiles2(ISpyBSafeStaticFiles):

def get_response(self):
logger.info("+ get_response called with: %s", self.input_string)
# it wasn't working because found two objects with test file name
Expand Down
Loading

0 comments on commit a6f6453

Please sign in to comment.