Skip to content

Commit

Permalink
add daily song manager
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-cranfill committed Sep 29, 2017
1 parent 852149a commit ef270a0
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 200 deletions.
41 changes: 0 additions & 41 deletions cena/FaceDetector.py

This file was deleted.

12 changes: 7 additions & 5 deletions cena/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.svm import SVC

from cena.settings import (SHAPE_PREDICTOR_FILE_PATH, CASCADE_FILE_PATH, FEATURE_EXTRACTOR_FILE_PATH,
DEV, LABELS_FILE_PATH, REPS_FILE_PATH, ANNOTATE_FRAME)
DEV, LABELS_FILE_PATH, REPS_FILE_PATH, ANNOTATE_FRAME, TIME_ZONE)


class FaceRecognizer(object):
Expand Down Expand Up @@ -55,8 +55,10 @@ def make_training_set(self, directory='data/img/*', out_dir='data/transformed_im
self.output_training_features(file_path, os.path.join(out_dir + file_path.split('/')[-1]))

def recognize_faces(self, frame, list_o_faces):
start = datetime.now()
pred_names = []
start = datetime.now(TIME_ZONE)
pred_names = {}
if frame.shape[-1] != 3:
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
for x, y, w, h in list_o_faces:
rect = dlib.rectangle(left=x, top=y, right=x+w, bottom=y+h)
aligned_face = self.face_aligner.align(96, frame, rect,
Expand All @@ -68,7 +70,7 @@ def recognize_faces(self, frame, list_o_faces):
highest_prob_index = np.argmax(pred_probs)
pred_name = self.clf.classes_[highest_prob_index]
pred_prob = max(pred_probs)
pred_names.append({pred_name: pred_prob})
pred_names.update({pred_name: pred_prob})

if ANNOTATE_FRAME:
pose_landmarks = self.face_pose_predictor(frame, rect)
Expand All @@ -78,7 +80,7 @@ def recognize_faces(self, frame, list_o_faces):
x, y = point.x, point.y
cv2.circle(frame, (x, y), 5, (0, 255, 0), -1)

end = datetime.now()
end = datetime.now(TIME_ZONE)
return frame, pred_names, (end - start).microseconds / 1000

# if DEV and ANNOTATE_FRAME:
Expand Down
19 changes: 11 additions & 8 deletions cena/settings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os
import pytz
from ast import literal_eval

ENVIRONMENT = 'dev'
# ENVIRONMENT = 'nah dude'
ENVIRONMENT = os.getenv('FACE_ENV', 'lol')
DEV = ENVIRONMENT == 'dev'

YOLO_MODE = True
# YOLO_MODE = False
# YOLO_MODE = True
YOLO_MODE = False

# ANNOTATE_FRAME = True
ANNOTATE_FRAME = False
Expand All @@ -28,7 +28,6 @@
FEATURE_EXTRACTOR_FILE_NAME = 'nn4.small2.v1.t7'
else:
FEATURE_EXTRACTOR_FILE_NAME = 'nn4.small2.v1.t7'
# FEATURE_EXTRACTOR_FILE_NAME = 'nn4.small2.v1.ascii.t7'

CASCADE_FILE_PATH = os.path.join(MODELS_DIR, CASCADE_FILE_NAME)
SHAPE_PREDICTOR_FILE_PATH = os.path.join(MODELS_DIR, SHAPE_PREDICTOR_FILE_NAME)
Expand All @@ -46,12 +45,16 @@
RYAN_FILE_NAME = 'dun_dun_dun.mp3'
RYAN_SONG_PATH = os.path.join(SONGS_DIR, RYAN_FILE_NAME)

if IS_CLIENT:
if not DEV:
from cena.utils import get_api_server_ip_address
# SERVER_URL = 'http://localhost:5000/recognize'
SERVER_IP = get_api_server_ip_address()
# SERVER_URL = 'http://107.20.57.175:5000/recognize'
else:
SERVER_IP = 'localhost'

SERVER_URL = 'http://{}:5000/recognize'.format(SERVER_IP)

TIME_ZONE = pytz.timezone('America/Chicago')

WINDOW_SIZE = 20
MIN_SEEN = 5
PROBA_THRESHOLD = 0.4
62 changes: 60 additions & 2 deletions cena/song_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from subprocess import call
from glob import glob

from cena.settings import SONGS_DIR
from cena.settings import SONGS_DIR, WINDOW_SIZE, PROBA_THRESHOLD, YOLO_MODE, MIN_SEEN
from cena.utils import play_mp3


def get_name_from_path(file_path):
Expand All @@ -11,7 +13,63 @@ def get_name_from_path(file_path):

class SongManager(object):
def __init__(self):
self.song_files = song_files =glob(SONGS_DIR + '*.*')
self.is_blank_slate = True

self.song_files = song_files = glob(SONGS_DIR + '*.*')
self.person_songs = {get_name_from_path(file_name): file_name
for file_name in song_files}
self.people = self.person_songs.keys()

self.person_thresholds = {person: 0. for person in self.people}
self.played_today = self.make_new_slate()
self.window = []

def make_new_slate(self):
return {person: 0 for person in self.people}

def _person_found(self, person):
people_mask = [p == person for p in self.window]
total_found = sum(people_mask)

more_than_half = total_found >= int(WINDOW_SIZE / 2)
half_of_seen = total_found >= int(len(self.window) / 2)
more_than_min = total_found > MIN_SEEN
return more_than_half and more_than_min and half_of_seen

def update_window(self, person, proba):
if proba > PROBA_THRESHOLD:
self.window.append(person)
if len(self.window) > WINDOW_SIZE:
self.window.pop(0)

if self._person_found(person):
# print(self.window)
try:
self.go_song_go(person)
except KeyError as error:
print('oh whoops no song for {}'.format(person))

def go_song_go(self, person):
if self.played_today[person] < 1:
print('playing that funky music for {}'.format(person))
play_mp3(self.person_songs[person])
self.window = []

if not YOLO_MODE:
self.played_today[person] = 1
else:
# print('you\'ve already had your fill today {}'.format(person))
pass

def blank_the_slate(self):
if self.is_blank_slate:
return
self.played_today = self.make_new_slate()
self.window = []
print('oh wow such reset')

# may not be the right place, but don't want to forget
def update_dropbox(self):
# fixme: make this the right command
command = "/home/pi/Dropbox-Uploader/dropbox_uploader.sh download"
call([command], shell=True)
118 changes: 1 addition & 117 deletions cena/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def decode_image(encoded_str, shape):

def play_mp3(path):
process = subprocess.Popen(['mpg123', '-q', path])
process.wait()


def get_api_server_id():
Expand Down Expand Up @@ -70,120 +71,3 @@ def start_instance(instance):
response = instance.start()
instance.wait_until_running()
print('Instance started at {}'.format(instance.public_ip_address))

#
# class EC2Manager(object):
# def __init__(self, config_path=None, start_if_stopped=True, stop_if_started=False):
# self.config = config = Config(config_path)
# self.storm_name, self.storm_enabled = config.storm_name, config.storm_installed
# self.project_name, self.user_name = config.project_name, config.user_name
#
# self.ec2_client = boto3.client('ec2')
# self.ec2_manager = boto3.resource('ec2')
#
# self.instance_id = self.get_instance_id_from_project_name()
# self.instance = instance = self.ec2_manager.Instance(self.instance_id)
# if start_if_stopped:
# self.start_if_not_started()
# elif stop_if_started:
# self.stop_if_not_stopped()
# return
#
# self.instance_ip = instance.public_ip_address
#
# self.public_key_name, self.public_key_path, self.public_key = self.get_public_key()
#
# if self.storm_name and self.storm_enabled:
# self.update_storm()
#
# def start_if_not_started(self):
# if self.instance.state['Name'] in ['stopped', 'stopping']:
# self.start_instance()
# else:
# self.instance_ip = self.instance.public_ip_address
# print(f'instance already started at {self.instance_ip}')
#
# def stop_if_not_stopped(self):
# state = self.instance.state['Name']
# if state in ['pending', 'running']:
# self.stop_instance()
# else:
# print('instance already stopped or is stopping!')
#
# def terminate_instance(self):
# print(f'alrighty then, terminating instance {self.instance_id}...')
# self.instance.terminate()
# self.instance.wait_until_terminated()
# print('instance terminated')
#
# def start_instance(self):
# print(f'Starting instance {self.instance_id}...')
# response = self.instance.start()
# self.instance.wait_until_running()
# print(f'Instance started at {self.instance.public_ip_address}')
#
# def stop_instance(self):
# print(f'Stopping instance {self.instance_id}...')
# response = self.instance.stop()
# self.instance.wait_until_stopped()
# print('Instance stopped')
#
# def update_storm(self):
# print('Fixin\' up a storm...')
# storm = Storm()
# if storm.is_host_in(self.storm_name, regexp_match=True):
# print('Updating storm profile with latest instance ip')
# storm_update(
# name=self.storm_name,
# connection_uri=f'ubuntu@{self.instance_ip}',
# id_file=self.public_key_path,
# o=['LocalForward=8889 127.0.0.1:8888']
# )
# else:
# print('Creating storm profile')
# storm_add(
# name=self.storm_name,
# connection_uri=f'ubuntu@{self.instance_ip}',
# id_file=self.public_key_path,
# o=['LocalForward=8889 127.0.0.1:8888']
# )
# print('Storm updated')
#
# def get_instance_id_from_project_name(self):
# filters = [
# {
# 'Name': 'tag:created_by',
# 'Values': [self.user_name]
# },
# {
# 'Name': 'tag:project_name',
# 'Values': [self.project_name]
# },
# {
# 'Name': 'instance-state-name',
# 'Values': ['pending', 'running', 'stopping', 'stopped']
# }
# ]
#
# response = self.ec2_client.describe_instances(Filters=filters)
# instances = response['Reservations']
# if instances:
# instance_id = instances[0]['Instances'][0]['InstanceId']
# return instance_id
# else:
# print('No instances found!')
# return None
#
# def get_public_key(self):
# public_key_name = self.instance.key_name
# public_key_path = os.path.join(os.path.expanduser('~'), '.ssh', f'{public_key_name}.pem')
# public_key = paramiko.RSAKey.from_private_key_file(public_key_path)
# print(f'Using public key at {public_key_path}')
# return public_key_name, public_key_path, public_key
#
# def resize_instance(self, new_size):
# self.stop_if_not_stopped()
# print(f'changing instance to type {new_size}...')
# self.ec2_client.modify_instance_attribute(InstanceId=self.instance_id, Attribute='instanceType', Value=new_size)
# print(f'instance type changed')
# self.start_if_not_started()
3 changes: 0 additions & 3 deletions check_songs.py

This file was deleted.

Loading

0 comments on commit ef270a0

Please sign in to comment.