Skip to content

Commit

Permalink
optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-cranfill committed Sep 28, 2017
1 parent 7d051c5 commit a3f5a63
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 39 deletions.
25 changes: 16 additions & 9 deletions cena/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import pandas as pd
from sklearn.svm import SVC

from cena.settings import SHAPE_PREDICTOR_FILE_PATH, CASCADE_FILE_PATH, FEATURE_EXTRACTOR_FILE_PATH, DEV, LABELS_FILE_PATH, REPS_FILE_PATH
from cena.settings import (SHAPE_PREDICTOR_FILE_PATH, CASCADE_FILE_PATH, FEATURE_EXTRACTOR_FILE_PATH,
DEV, LABELS_FILE_PATH, REPS_FILE_PATH, ANNOTATE_FRAME)


class FaceRecognizer(object):
Expand All @@ -21,20 +22,21 @@ def __init__(self):

self.face_cascade = cv2.CascadeClassifier(CASCADE_FILE_PATH)
print('loaded face cascade')
self.clf = self.train_model()
print('classifier trained')
self.clf, self.user_list = self.train_model()
self.net = openface.TorchNeuralNet(FEATURE_EXTRACTOR_FILE_PATH)
print('loaded torch nn')

def train_model(self):
labels = pd.read_csv(LABELS_FILE_PATH, header=None).rename(columns={0:'label', 1:'user'})
labels = pd.read_csv(LABELS_FILE_PATH, header=None).rename(columns={0: 'label', 1: 'user'})
labels.user = labels.user.apply(lambda x: x.split('/')[-2])
user_list = labels.user.unique()
x = pd.read_csv(REPS_FILE_PATH, header=None)

clf = SVC(C=1, kernel='linear', probability=True)
clf.fit(x, labels.user)
print('classifier trained')
return clf
print('users found:', user_list)
return clf, user_list

def output_training_features(self, inpath, outpath):
frame = cv2.imread(inpath)
Expand Down Expand Up @@ -66,14 +68,19 @@ def recognize_faces(self, frame, list_o_faces):
highest_prob_index = np.argmax(pred_probs)
pred_name = self.clf.classes_[highest_prob_index]
pred_prob = max(pred_probs)
pred_names.append({pred_name:pred_prob})
pred_names.append({pred_name: pred_prob})

if DEV:
if DEV and ANNOTATE_FRAME:
pose_landmarks = self.face_pose_predictor(frame, rect)
cv2.putText(frame, '{}: {}'.format(pred_name, pred_prob), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1,
(102, 204, 102), thickness=2)
for point in pose_landmarks.parts():
x, y = point.x, point.y
cv2.circle(frame, (x, y), 5, (0, 255, 0), -1)
end = datetime.now()
return frame, pred_names, (end - start).microseconds / 1000

end = datetime.now()
return frame, pred_names, (end - start).microseconds / 1000
else:
end = datetime.now()
return pred_names, (end - start).microseconds / 1000

12 changes: 9 additions & 3 deletions cena/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

ENVIRONMENT = 'dev'
# ENVIRONMENT = 'nah dude'

DEV = ENVIRONMENT == 'dev'
YOLO_MODE = True
ANNOTATE_FRAME = False

API_SERVER_NAME = 'face-api'

MODELS_DIR = 'data/models/'
SONGS_DIR = 'data/songs/'


CASCADE_FILE_NAME = 'haarcascade_frontalface_default.xml'
SHAPE_PREDICTOR_FILE_NAME = 'shape_predictor_68_face_landmarks.dat'
if DEV:
Expand All @@ -33,5 +35,9 @@
RYAN_FILE_NAME = 'dun_dun_dun.mp3'
RYAN_SONG_PATH = os.path.join(SONGS_DIR, RYAN_FILE_NAME)

from cena.utils import get_api_server_ip_address
# SERVER_URL = 'http://localhost:5000/recognize'
SERVER_URL = 'http://52.207.182.58:5000/recognize'
# SERVER_IP = 'localhost'
SERVER_IP = get_api_server_ip_address()
SERVER_URL = 'http://{}:5000/recognize'.format(SERVER_IP)
# SERVER_URL = 'http://107.20.57.175:5000/recognize'
17 changes: 17 additions & 0 deletions cena/song_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from glob import glob

from cena.settings import SONGS_DIR


def get_name_from_path(file_path):
file_name = file_path.split('/')[-1]
person_name = file_name.split('.')[0]
return person_name


class SongManager(object):
def __init__(self):
self.song_files = song_files =glob(SONGS_DIR + '*.*')
self.person_songs = {get_name_from_path(file_name): file_name
for file_name in song_files}

179 changes: 178 additions & 1 deletion cena/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy as np
import subprocess
import base64
import boto3
import numpy as np

from cena.settings import API_SERVER_NAME


def encode_image(image):
Expand All @@ -10,3 +14,176 @@ def encode_image(image):
def decode_image(encoded_str, shape):
decoded_arr = np.fromstring(base64.b64decode(encoded_str), dtype=np.uint8)
return decoded_arr.reshape(shape)


def play_mp3(path):
process = subprocess.Popen(['mpg123', '-q', path])


def get_api_server_id():
filters = [
{
'Name': 'tag:Name',
'Values': [API_SERVER_NAME]
}
]

ec2_client = boto3.client('ec2')

response = ec2_client.describe_instances(Filters=filters)
instances = response['Reservations']
if instances:
instance_id = instances[0]['Instances'][0]['InstanceId']
return instance_id
else:
raise ValueError('No api server instances found!')


def get_api_server_ip_address():
api_server_id = get_api_server_id()
ec2_manager = boto3.resource('ec2')

instance = ec2_manager.Instance(api_server_id)

if instance.state['Name'] in ['stopped', 'stopping']:
start_instance(instance)
instance_ip = instance.public_ip_address
else:
instance_ip = instance.public_ip_address
print(f'instance already running at {instance_ip}')

# start_if_not_started(instance)
# instance_ip = instance.public_ip_address
return instance_ip


def start_if_not_started(instance):
if instance.state['Name'] in ['stopped', 'stopping']:
start_instance(instance)
else:
instance_ip = instance.public_ip_address
print(f'instance already running at {instance_ip}')


def start_instance(instance):
print(f'Starting instance {instance}...')
response = instance.start()
instance.wait_until_running()
print(f'Instance started at {instance.public_ip_address}')

#
# class EC2Manager(object):
# def __init__(self, config_path=None, start_if_stopped=True, stop_if_started=False):
# self.config = config = Config(config_path)
# self.storm_name, self.storm_enabled = config.storm_name, config.storm_installed
# self.project_name, self.user_name = config.project_name, config.user_name
#
# self.ec2_client = boto3.client('ec2')
# self.ec2_manager = boto3.resource('ec2')
#
# self.instance_id = self.get_instance_id_from_project_name()
# self.instance = instance = self.ec2_manager.Instance(self.instance_id)
# if start_if_stopped:
# self.start_if_not_started()
# elif stop_if_started:
# self.stop_if_not_stopped()
# return
#
# self.instance_ip = instance.public_ip_address
#
# self.public_key_name, self.public_key_path, self.public_key = self.get_public_key()
#
# if self.storm_name and self.storm_enabled:
# self.update_storm()
#
# def start_if_not_started(self):
# if self.instance.state['Name'] in ['stopped', 'stopping']:
# self.start_instance()
# else:
# self.instance_ip = self.instance.public_ip_address
# print(f'instance already started at {self.instance_ip}')
#
# def stop_if_not_stopped(self):
# state = self.instance.state['Name']
# if state in ['pending', 'running']:
# self.stop_instance()
# else:
# print('instance already stopped or is stopping!')
#
# def terminate_instance(self):
# print(f'alrighty then, terminating instance {self.instance_id}...')
# self.instance.terminate()
# self.instance.wait_until_terminated()
# print('instance terminated')
#
# def start_instance(self):
# print(f'Starting instance {self.instance_id}...')
# response = self.instance.start()
# self.instance.wait_until_running()
# print(f'Instance started at {self.instance.public_ip_address}')
#
# def stop_instance(self):
# print(f'Stopping instance {self.instance_id}...')
# response = self.instance.stop()
# self.instance.wait_until_stopped()
# print('Instance stopped')
#
# def update_storm(self):
# print('Fixin\' up a storm...')
# storm = Storm()
# if storm.is_host_in(self.storm_name, regexp_match=True):
# print('Updating storm profile with latest instance ip')
# storm_update(
# name=self.storm_name,
# connection_uri=f'ubuntu@{self.instance_ip}',
# id_file=self.public_key_path,
# o=['LocalForward=8889 127.0.0.1:8888']
# )
# else:
# print('Creating storm profile')
# storm_add(
# name=self.storm_name,
# connection_uri=f'ubuntu@{self.instance_ip}',
# id_file=self.public_key_path,
# o=['LocalForward=8889 127.0.0.1:8888']
# )
# print('Storm updated')
#
# def get_instance_id_from_project_name(self):
# filters = [
# {
# 'Name': 'tag:created_by',
# 'Values': [self.user_name]
# },
# {
# 'Name': 'tag:project_name',
# 'Values': [self.project_name]
# },
# {
# 'Name': 'instance-state-name',
# 'Values': ['pending', 'running', 'stopping', 'stopped']
# }
# ]
#
# response = self.ec2_client.describe_instances(Filters=filters)
# instances = response['Reservations']
# if instances:
# instance_id = instances[0]['Instances'][0]['InstanceId']
# return instance_id
# else:
# print('No instances found!')
# return None
#
# def get_public_key(self):
# public_key_name = self.instance.key_name
# public_key_path = os.path.join(os.path.expanduser('~'), '.ssh', f'{public_key_name}.pem')
# public_key = paramiko.RSAKey.from_private_key_file(public_key_path)
# print(f'Using public key at {public_key_path}')
# return public_key_name, public_key_path, public_key
#
# def resize_instance(self, new_size):
# self.stop_if_not_stopped()
# print(f'changing instance to type {new_size}...')
# self.ec2_client.modify_instance_attribute(InstanceId=self.instance_id, Attribute='instanceType', Value=new_size)
# print(f'instance type changed')
# self.start_if_not_started()
3 changes: 3 additions & 0 deletions check_songs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from cena.song_manager import SongManager

song_manager = SongManager()
44 changes: 26 additions & 18 deletions face_detector.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import subprocess
import cv2
from requests import post
from datetime import datetime
import numpy as np
from requests import request, post

from cena.recognition import FaceRecognizer
from cena.settings import RYAN_SONG_PATH, DEV, CASCADE_FILE_PATH, SERVER_URL
from cena.utils import encode_image, decode_image
from cena.settings import DEV, ANNOTATE_FRAME, CASCADE_FILE_PATH, SERVER_URL
from cena.song_manager import SongManager
from cena.utils import encode_image, decode_image, play_mp3


def play_mp3(path):
process = subprocess.Popen(['mpg123', '-q', path])


def listen_for_quit():
k = cv2.waitKey(1)
Expand All @@ -28,10 +24,14 @@ def get_server_response(frame, list_o_faces):
'shape': shape
}
response = post(SERVER_URL, json=request_json)
frame = decode_image(response.json()['frame'], shape)
people_list = response.json()['people_list']
time = response.json()['time']
return frame, people_list, time

if ANNOTATE_FRAME and DEV:
frame = decode_image(response.json()['frame'], shape)
return frame, people_list, time
else:
return people_list, time
# return response.json()['frame'], response.json()['people_list'], response.json()['time']


Expand All @@ -40,8 +40,9 @@ def process_frame(video_capture, face_recognizer=None):
return
try:
now = datetime.now()
ret, frame = video_capture.read()
frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
ret, frame = video_capture.read(cv2.COLOR_BGR2GRAY)
frame_gray = frame
# frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
faces = face_cascade.detectMultiScale(
frame_gray,
scaleFactor=1.1,
Expand All @@ -54,15 +55,13 @@ def process_frame(video_capture, face_recognizer=None):
list_o_faces = []
for x, y, w, h in faces:
list_o_faces.append([int(x), int(y), int(w), int(h)])
if DEV:
if DEV and ANNOTATE_FRAME:
# frame, people_list, time = face_recognizer.recognize_faces(frame, list_o_faces)
frame, people_list, time = get_server_response(frame, list_o_faces)
# frame = np.array(frame)
# frame = frame.astype('uint8')
elif DEV:
people_list, time = face_recognizer.recognize_faces(frame, list_o_faces)
else:
frame, people_list, time = get_server_response(frame, list_o_faces)
# frame = np.array(frame)
# frame = frame.astype('uint8')
people_list, time = get_server_response(frame, list_o_faces)
# play_mp3(RYAN_SONG_PATH)
print(people_list, datetime.now() - now)
else:
Expand All @@ -81,9 +80,18 @@ def process_frame(video_capture, face_recognizer=None):
print(error)
return

song_manager = SongManager()
person_songs = song_manager.person_songs
print('found songs:')
print(person_songs)

face_cascade = cv2.CascadeClassifier(CASCADE_FILE_PATH)
if DEV:
face_recognizer = FaceRecognizer()
trained_people = face_recognizer.user_list
print('people i recognize but do not have a song for:')
print([i for i in trained_people if i not in person_songs])

video_capture = cv2.VideoCapture(1)
video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, 320)
video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 240)
Expand Down
Loading

0 comments on commit a3f5a63

Please sign in to comment.