From a3f5a6324490730b4d70ab0e7e092dc0cc31a8e0 Mon Sep 17 00:00:00 2001 From: Ryan Cranfill Date: Wed, 27 Sep 2017 21:52:37 -0500 Subject: [PATCH] optimization --- cena/recognition.py | 25 +++--- cena/settings.py | 12 ++- cena/song_manager.py | 17 ++++ cena/utils.py | 179 ++++++++++++++++++++++++++++++++++++++++++- check_songs.py | 3 + face_detector.py | 44 ++++++----- feature_server.py | 13 ++-- 7 files changed, 254 insertions(+), 39 deletions(-) create mode 100644 cena/song_manager.py create mode 100644 check_songs.py diff --git a/cena/recognition.py b/cena/recognition.py index dad1aab..6d31b83 100644 --- a/cena/recognition.py +++ b/cena/recognition.py @@ -9,7 +9,8 @@ import pandas as pd from sklearn.svm import SVC -from cena.settings import SHAPE_PREDICTOR_FILE_PATH, CASCADE_FILE_PATH, FEATURE_EXTRACTOR_FILE_PATH, DEV, LABELS_FILE_PATH, REPS_FILE_PATH +from cena.settings import (SHAPE_PREDICTOR_FILE_PATH, CASCADE_FILE_PATH, FEATURE_EXTRACTOR_FILE_PATH, + DEV, LABELS_FILE_PATH, REPS_FILE_PATH, ANNOTATE_FRAME) class FaceRecognizer(object): @@ -21,20 +22,21 @@ def __init__(self): self.face_cascade = cv2.CascadeClassifier(CASCADE_FILE_PATH) print('loaded face cascade') - self.clf = self.train_model() - print('classifier trained') + self.clf, self.user_list = self.train_model() self.net = openface.TorchNeuralNet(FEATURE_EXTRACTOR_FILE_PATH) print('loaded torch nn') def train_model(self): - labels = pd.read_csv(LABELS_FILE_PATH, header=None).rename(columns={0:'label', 1:'user'}) + labels = pd.read_csv(LABELS_FILE_PATH, header=None).rename(columns={0: 'label', 1: 'user'}) labels.user = labels.user.apply(lambda x: x.split('/')[-2]) + user_list = labels.user.unique() x = pd.read_csv(REPS_FILE_PATH, header=None) clf = SVC(C=1, kernel='linear', probability=True) clf.fit(x, labels.user) print('classifier trained') - return clf + print('users found:', user_list) + return clf, user_list def output_training_features(self, inpath, outpath): frame = cv2.imread(inpath) @@ -66,14 +68,19 @@ def recognize_faces(self, frame, list_o_faces): highest_prob_index = np.argmax(pred_probs) pred_name = self.clf.classes_[highest_prob_index] pred_prob = max(pred_probs) - pred_names.append({pred_name:pred_prob}) + pred_names.append({pred_name: pred_prob}) - if DEV: + if DEV and ANNOTATE_FRAME: pose_landmarks = self.face_pose_predictor(frame, rect) cv2.putText(frame, '{}: {}'.format(pred_name, pred_prob), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (102, 204, 102), thickness=2) for point in pose_landmarks.parts(): x, y = point.x, point.y cv2.circle(frame, (x, y), 5, (0, 255, 0), -1) - end = datetime.now() - return frame, pred_names, (end - start).microseconds / 1000 \ No newline at end of file + + end = datetime.now() + return frame, pred_names, (end - start).microseconds / 1000 + else: + end = datetime.now() + return pred_names, (end - start).microseconds / 1000 + diff --git a/cena/settings.py b/cena/settings.py index 8d4bed1..8771090 100644 --- a/cena/settings.py +++ b/cena/settings.py @@ -2,13 +2,15 @@ ENVIRONMENT = 'dev' # ENVIRONMENT = 'nah dude' - DEV = ENVIRONMENT == 'dev' +YOLO_MODE = True +ANNOTATE_FRAME = False + +API_SERVER_NAME = 'face-api' MODELS_DIR = 'data/models/' SONGS_DIR = 'data/songs/' - CASCADE_FILE_NAME = 'haarcascade_frontalface_default.xml' SHAPE_PREDICTOR_FILE_NAME = 'shape_predictor_68_face_landmarks.dat' if DEV: @@ -33,5 +35,9 @@ RYAN_FILE_NAME = 'dun_dun_dun.mp3' RYAN_SONG_PATH = os.path.join(SONGS_DIR, RYAN_FILE_NAME) +from cena.utils import get_api_server_ip_address # SERVER_URL = 'http://localhost:5000/recognize' -SERVER_URL = 'http://52.207.182.58:5000/recognize' +# SERVER_IP = 'localhost' +SERVER_IP = get_api_server_ip_address() +SERVER_URL = 'http://{}:5000/recognize'.format(SERVER_IP) +# SERVER_URL = 'http://107.20.57.175:5000/recognize' diff --git a/cena/song_manager.py b/cena/song_manager.py new file mode 100644 index 0000000..828f65d --- /dev/null +++ b/cena/song_manager.py @@ -0,0 +1,17 @@ +from glob import glob + +from cena.settings import SONGS_DIR + + +def get_name_from_path(file_path): + file_name = file_path.split('/')[-1] + person_name = file_name.split('.')[0] + return person_name + + +class SongManager(object): + def __init__(self): + self.song_files = song_files =glob(SONGS_DIR + '*.*') + self.person_songs = {get_name_from_path(file_name): file_name + for file_name in song_files} + diff --git a/cena/utils.py b/cena/utils.py index 13fe45a..841602b 100644 --- a/cena/utils.py +++ b/cena/utils.py @@ -1,5 +1,9 @@ -import numpy as np +import subprocess import base64 +import boto3 +import numpy as np + +from cena.settings import API_SERVER_NAME def encode_image(image): @@ -10,3 +14,176 @@ def encode_image(image): def decode_image(encoded_str, shape): decoded_arr = np.fromstring(base64.b64decode(encoded_str), dtype=np.uint8) return decoded_arr.reshape(shape) + + +def play_mp3(path): + process = subprocess.Popen(['mpg123', '-q', path]) + + +def get_api_server_id(): + filters = [ + { + 'Name': 'tag:Name', + 'Values': [API_SERVER_NAME] + } + ] + + ec2_client = boto3.client('ec2') + + response = ec2_client.describe_instances(Filters=filters) + instances = response['Reservations'] + if instances: + instance_id = instances[0]['Instances'][0]['InstanceId'] + return instance_id + else: + raise ValueError('No api server instances found!') + + +def get_api_server_ip_address(): + api_server_id = get_api_server_id() + ec2_manager = boto3.resource('ec2') + + instance = ec2_manager.Instance(api_server_id) + + if instance.state['Name'] in ['stopped', 'stopping']: + start_instance(instance) + instance_ip = instance.public_ip_address + else: + instance_ip = instance.public_ip_address + print(f'instance already running at {instance_ip}') + + # start_if_not_started(instance) + # instance_ip = instance.public_ip_address + return instance_ip + + +def start_if_not_started(instance): + if instance.state['Name'] in ['stopped', 'stopping']: + start_instance(instance) + else: + instance_ip = instance.public_ip_address + print(f'instance already running at {instance_ip}') + + +def start_instance(instance): + print(f'Starting instance {instance}...') + response = instance.start() + instance.wait_until_running() + print(f'Instance started at {instance.public_ip_address}') + +# +# class EC2Manager(object): +# def __init__(self, config_path=None, start_if_stopped=True, stop_if_started=False): +# self.config = config = Config(config_path) +# self.storm_name, self.storm_enabled = config.storm_name, config.storm_installed +# self.project_name, self.user_name = config.project_name, config.user_name +# +# self.ec2_client = boto3.client('ec2') +# self.ec2_manager = boto3.resource('ec2') +# +# self.instance_id = self.get_instance_id_from_project_name() +# self.instance = instance = self.ec2_manager.Instance(self.instance_id) +# if start_if_stopped: +# self.start_if_not_started() +# elif stop_if_started: +# self.stop_if_not_stopped() +# return +# +# self.instance_ip = instance.public_ip_address +# +# self.public_key_name, self.public_key_path, self.public_key = self.get_public_key() +# +# if self.storm_name and self.storm_enabled: +# self.update_storm() +# +# def start_if_not_started(self): +# if self.instance.state['Name'] in ['stopped', 'stopping']: +# self.start_instance() +# else: +# self.instance_ip = self.instance.public_ip_address +# print(f'instance already started at {self.instance_ip}') +# +# def stop_if_not_stopped(self): +# state = self.instance.state['Name'] +# if state in ['pending', 'running']: +# self.stop_instance() +# else: +# print('instance already stopped or is stopping!') +# +# def terminate_instance(self): +# print(f'alrighty then, terminating instance {self.instance_id}...') +# self.instance.terminate() +# self.instance.wait_until_terminated() +# print('instance terminated') +# +# def start_instance(self): +# print(f'Starting instance {self.instance_id}...') +# response = self.instance.start() +# self.instance.wait_until_running() +# print(f'Instance started at {self.instance.public_ip_address}') +# +# def stop_instance(self): +# print(f'Stopping instance {self.instance_id}...') +# response = self.instance.stop() +# self.instance.wait_until_stopped() +# print('Instance stopped') +# +# def update_storm(self): +# print('Fixin\' up a storm...') +# storm = Storm() +# if storm.is_host_in(self.storm_name, regexp_match=True): +# print('Updating storm profile with latest instance ip') +# storm_update( +# name=self.storm_name, +# connection_uri=f'ubuntu@{self.instance_ip}', +# id_file=self.public_key_path, +# o=['LocalForward=8889 127.0.0.1:8888'] +# ) +# else: +# print('Creating storm profile') +# storm_add( +# name=self.storm_name, +# connection_uri=f'ubuntu@{self.instance_ip}', +# id_file=self.public_key_path, +# o=['LocalForward=8889 127.0.0.1:8888'] +# ) +# print('Storm updated') +# +# def get_instance_id_from_project_name(self): +# filters = [ +# { +# 'Name': 'tag:created_by', +# 'Values': [self.user_name] +# }, +# { +# 'Name': 'tag:project_name', +# 'Values': [self.project_name] +# }, +# { +# 'Name': 'instance-state-name', +# 'Values': ['pending', 'running', 'stopping', 'stopped'] +# } +# ] +# +# response = self.ec2_client.describe_instances(Filters=filters) +# instances = response['Reservations'] +# if instances: +# instance_id = instances[0]['Instances'][0]['InstanceId'] +# return instance_id +# else: +# print('No instances found!') +# return None +# +# def get_public_key(self): +# public_key_name = self.instance.key_name +# public_key_path = os.path.join(os.path.expanduser('~'), '.ssh', f'{public_key_name}.pem') +# public_key = paramiko.RSAKey.from_private_key_file(public_key_path) +# print(f'Using public key at {public_key_path}') +# return public_key_name, public_key_path, public_key +# +# def resize_instance(self, new_size): +# self.stop_if_not_stopped() +# print(f'changing instance to type {new_size}...') +# self.ec2_client.modify_instance_attribute(InstanceId=self.instance_id, Attribute='instanceType', Value=new_size) +# print(f'instance type changed') +# self.start_if_not_started() \ No newline at end of file diff --git a/check_songs.py b/check_songs.py new file mode 100644 index 0000000..7eee91e --- /dev/null +++ b/check_songs.py @@ -0,0 +1,3 @@ +from cena.song_manager import SongManager + +song_manager = SongManager() diff --git a/face_detector.py b/face_detector.py index 8933176..60dac5d 100755 --- a/face_detector.py +++ b/face_detector.py @@ -1,17 +1,13 @@ -import subprocess import cv2 +from requests import post from datetime import datetime -import numpy as np -from requests import request, post from cena.recognition import FaceRecognizer -from cena.settings import RYAN_SONG_PATH, DEV, CASCADE_FILE_PATH, SERVER_URL -from cena.utils import encode_image, decode_image +from cena.settings import DEV, ANNOTATE_FRAME, CASCADE_FILE_PATH, SERVER_URL +from cena.song_manager import SongManager +from cena.utils import encode_image, decode_image, play_mp3 -def play_mp3(path): - process = subprocess.Popen(['mpg123', '-q', path]) - def listen_for_quit(): k = cv2.waitKey(1) @@ -28,10 +24,14 @@ def get_server_response(frame, list_o_faces): 'shape': shape } response = post(SERVER_URL, json=request_json) - frame = decode_image(response.json()['frame'], shape) people_list = response.json()['people_list'] time = response.json()['time'] - return frame, people_list, time + + if ANNOTATE_FRAME and DEV: + frame = decode_image(response.json()['frame'], shape) + return frame, people_list, time + else: + return people_list, time # return response.json()['frame'], response.json()['people_list'], response.json()['time'] @@ -40,8 +40,9 @@ def process_frame(video_capture, face_recognizer=None): return try: now = datetime.now() - ret, frame = video_capture.read() - frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + ret, frame = video_capture.read(cv2.COLOR_BGR2GRAY) + frame_gray = frame + # frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) faces = face_cascade.detectMultiScale( frame_gray, scaleFactor=1.1, @@ -54,15 +55,13 @@ def process_frame(video_capture, face_recognizer=None): list_o_faces = [] for x, y, w, h in faces: list_o_faces.append([int(x), int(y), int(w), int(h)]) - if DEV: + if DEV and ANNOTATE_FRAME: # frame, people_list, time = face_recognizer.recognize_faces(frame, list_o_faces) frame, people_list, time = get_server_response(frame, list_o_faces) - # frame = np.array(frame) - # frame = frame.astype('uint8') + elif DEV: + people_list, time = face_recognizer.recognize_faces(frame, list_o_faces) else: - frame, people_list, time = get_server_response(frame, list_o_faces) - # frame = np.array(frame) - # frame = frame.astype('uint8') + people_list, time = get_server_response(frame, list_o_faces) # play_mp3(RYAN_SONG_PATH) print(people_list, datetime.now() - now) else: @@ -81,9 +80,18 @@ def process_frame(video_capture, face_recognizer=None): print(error) return +song_manager = SongManager() +person_songs = song_manager.person_songs +print('found songs:') +print(person_songs) + face_cascade = cv2.CascadeClassifier(CASCADE_FILE_PATH) if DEV: face_recognizer = FaceRecognizer() + trained_people = face_recognizer.user_list + print('people i recognize but do not have a song for:') + print([i for i in trained_people if i not in person_songs]) + video_capture = cv2.VideoCapture(1) video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, 320) video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 240) diff --git a/feature_server.py b/feature_server.py index e3add6f..0170f70 100644 --- a/feature_server.py +++ b/feature_server.py @@ -5,6 +5,7 @@ from cena.recognition import FaceRecognizer from cena.utils import decode_image, encode_image +from cena.settings import DEV, ANNOTATE_FRAME RECOGNIZER = FaceRecognizer() app = Flask(__name__) @@ -32,19 +33,15 @@ def recognize(): frame = decode_image(encoded_frame, shape) list_o_faces = request.json['list_o_faces'] - # print(type(list_o_faces[0][0])) - frame, people_list, time = RECOGNIZER.recognize_faces(frame, list_o_faces) - # response = { - # 'people_list': people_list, - # 'frame': frame.tolist(), - # 'time': time - # } + response = { 'people_list': people_list, - 'frame': encode_image(frame), 'time': time } + + if DEV and ANNOTATE_FRAME: + response.update({'frame': encode_image(frame)}) return jsonify(response) if __name__ == '__main__':