#!/usr/bin/env python

#import device_patches       # Device specific patches for Jetson Nano (needs to be before importing cv2)

import cv2
import os
import sys, getopt
import signal
import time
import serial
import serial.tools.list_ports
from edge_impulse_linux.image import ImageImpulseRunner
from flask import Response
from flask import Flask
from flask import render_template
from threading import Lock, Thread

class SenderThread(Thread):
    """Background thread that exchanges CRLF-terminated text lines with a serial device.

    The device is located by matching ``port.serial_number`` from pySerial's
    port listing against ``serialNo``. Messages queued with :meth:`send` are
    written one per loop iteration; every non-empty line read back is queued
    for :meth:`receive`. When the device is missing or the port fails, the
    thread keeps searching for it until :meth:`abort` is called.

    Attributes:
        serialNo: serial number of the target device, as reported by pySerial.
        errors: ``None`` after a successful open; ``"nodevice"`` after a search
            that found no matching port.
    """

    def __init__(self, serialNo):
        Thread.__init__(self)
        # Daemon thread: terminated automatically when the program stops.
        self.daemon = True
        self.serialNo = serialNo
        self.errors = None
        # Both queues are FIFO: items are inserted at the front and consumed
        # from the end with list.pop().
        self._replies = []
        self._messages = []
        self._stopping = False

    def check_connection(self):
        """Open and return the serial port whose serial number equals ``serialNo``.

        On success ``errors`` is reset to ``None``. If no matching port is
        listed, ``errors`` is set to ``"nodevice"``, the call waits 0.5 s (so
        the caller's retry loop does not busy-poll) and raises
        :class:`serial.SerialException`. Opening the port may also raise it.
        """
        for port in serial.tools.list_ports.comports():
            if port.serial_number == self.serialNo:
                # (Re)open the port after the device first appears or returns.
                ser = serial.Serial(port.device, timeout=2.5)
                self.errors = None
                return ser
        # No matching device: record the status, back off, and let run() retry.
        self.errors = "nodevice"
        time.sleep(.5)
        raise serial.SerialException()

    def run(self):
        """Thread body: keep a connection open, write queued messages, collect replies.

        Runs until :meth:`abort` is called. Any :class:`serial.SerialException`
        (device absent, unplugged, open failure) closes the port and triggers a
        new search on the next iteration. The port is always closed on exit.
        """
        ser = None
        try:
            # Re-search for the receiver repeatedly, because an unexpected
            # disconnect does not fail on Windows. The loop must still end as
            # soon as abort() is requested, even while no device is connected.
            while not self._stopping:
                try:
                    if ser is None:
                        ser = self.check_connection()
                    if self._messages:
                        # Dequeue only after a successful write, so a message
                        # whose write failed is re-sent after reconnecting
                        # (callers may be blocked waiting for its reply).
                        ser.write(self._messages[-1])
                        self._messages.pop()
                    data = ser.readline()
                    # Stop immediately if abort() arrived during readline().
                    if self._stopping:
                        return
                    if data:
                        # errors="replace": one garbled byte must not kill the thread.
                        self._replies.insert(0, data.decode("utf-8", errors="replace").strip())
                # Device lost / not found: drop the handle and search again.
                except serial.SerialException:
                    if ser is not None:
                        ser.close()
                        ser = None
        finally:
            if ser is not None:
                ser.close()

    def send(self, msg):
        """Queue ``msg`` (formatted with ``%s``) as one CRLF-terminated UTF-8 line."""
        # (msg,) so that tuple payloads are formatted instead of unpacked.
        self._messages.insert(0, bytes('%s\r\n' % (msg,), "utf-8"))

    def receive(self):
        """Return the oldest pending reply line (whitespace-stripped), or ``None``."""
        if self._replies:
            return self._replies.pop()
        return None

    def abort(self):
        """Request the thread to stop.

        The flag is checked after the in-progress ``readline()`` (2.5 s timeout)
        or device search, so the thread stops shortly afterwards.
        """
        self._stopping = True

# Handles created by main(); module-level so sigint_handler can stop them.
runner = None
sender = None

# Latest annotated frame published by the classifier loop and read by the
# stream generator. Access goes through `lock` because several browsers or
# tabs may be viewing the stream at the same time.
outputFrame = None
# Placeholder texts for the web page until main() fills in the real values.
cameraFeedInfo = "(N/A)"
projectName = "(N/A)"
lock = Lock()
app = Flask(__name__)

def now():
    """Return the current wall-clock time as whole milliseconds since the epoch."""
    millis = time.time() * 1000
    return round(millis)

def get_webcams():
    """Probe capture indices 0-4 and return those that open and deliver a frame.

    Each index found is reported on stdout together with the backend name and
    the frame size. NOTE(review): the message lists height before width
    (property 4 before property 3); confirm this ordering is intended.
    """
    found = []
    for index in range(5):
        print("Looking for a camera in port %s:" % index)
        capture = cv2.VideoCapture(index)
        if not capture.isOpened():
            continue
        if capture.read()[0]:
            backend = capture.getBackendName()
            width = capture.get(3)
            height = capture.get(4)
            print("Camera %s (%s x %s) found in port %s " % (backend, height, width, index))
            found.append(index)
        capture.release()
    return found

def sigint_handler(sig, frame):
    """Ctrl+C handler: stop the serial sender and the model runner, then exit with status 0.

    Either handle may still be None if the signal arrives before main() has
    created it, so each one is checked first.
    """
    print('Interrupted')
    sender_thread = sender
    model_runner = runner
    if sender_thread:
        sender_thread.abort()
    if model_runner:
        model_runner.stop()
    sys.exit(0)


# Installed at import time so Ctrl+C is handled from the very start.
signal.signal(signal.SIGINT, sigint_handler)

def help():
    """Print the command-line usage line (note: shadows the builtin ``help``)."""
    usage = 'python classify.py <RECEIVING_DEVICE_SERIAL> <path_to_model.eim> [<CAMERA_PORT_INDEX>]'
    print(usage)

def _select_camera_port(args):
    """Return the video-capture index: args[2] if given, else the single detected webcam.

    Raises:
        Exception: when auto-detection finds no camera, or more than one.
    """
    if len(args) >= 3:
        return int(args[2])
    port_ids = get_webcams()
    if not port_ids:
        raise Exception('Cannot find any webcams')
    if len(port_ids) > 1:
        # The port index is the third positional argument (see help()).
        raise Exception("Multiple cameras found. Add the camera port ID as a third argument to use with this script")
    return int(port_ids[0])


def _describe_camera(port):
    """Open capture device ``port``, read one frame and return "<backend> (<h> x <w>)".

    The device is released on every path, including when the read fails.

    Raises:
        Exception: if no frame can be read from the device.
    """
    camera = cv2.VideoCapture(port)
    try:
        if not camera.read()[0]:
            raise Exception("Couldn't initialize selected camera.")
        backend_name = camera.getBackendName()
        w = camera.get(3)  # CAP_PROP_FRAME_WIDTH
        h = camera.get(4)  # CAP_PROP_FRAME_HEIGHT
        # Height is listed first, matching the get_webcams() output.
        return "%s (%s x %s)" % (backend_name, h, w)
    finally:
        camera.release()


def main(argv):
    """Parse the CLI, start the serial sender and run the classification loop.

    Usage: ``classify.py <RECEIVING_DEVICE_SERIAL> <path_to_model.eim> [<CAMERA_PORT_INDEX>]``.
    The model path is resolved relative to this script's directory.

    For classification models, per-label scores are printed. For object-detection
    models, boxes are drawn onto the frame, the frame is published as
    ``outputFrame`` for the web stream, the box count is sent to the serial
    device, and the loop waits for that device's reply before continuing.
    Frames are processed at most every ~100 ms.

    Args:
        argv: command-line arguments without the program name.
    """
    # Globals written below: sigint_handler and the web routes read them.
    global sender, runner, outputFrame, cameraFeedInfo, projectName

    try:
        # Long option names are listed without their leading "--".
        opts, args = getopt.getopt(argv, "h", ["help"])
    except getopt.GetoptError:
        help()
        sys.exit(2)

    for opt, _ in opts:
        if opt in ('-h', '--help'):
            help()
            sys.exit()

    if len(args) < 2:
        help()
        sys.exit(2)

    sender = SenderThread(args[0])
    sender.start()
    model = args[1]

    dir_path = os.path.dirname(os.path.realpath(__file__))
    modelfile = os.path.join(dir_path, model)

    print('MODEL: ' + modelfile)

    with ImageImpulseRunner(modelfile) as runner:
        try:
            model_info = runner.init()
            projectName = model_info['project']['name']
            print('Loaded runner for "' + model_info['project']['owner'] + ' / ' + model_info['project']['name'] + '"')
            labels = model_info['model_parameters']['labels']

            videoCaptureDeviceId = _select_camera_port(args)
            cameraFeedInfo = _describe_camera(videoCaptureDeviceId)
            print("%s in port %s selected." % (cameraFeedInfo, videoCaptureDeviceId))

            next_frame = 0  # limit to ~10 fps here

            for res, img in runner.classifier(videoCaptureDeviceId):
                # Read the clock once: calling now() twice could yield a
                # negative delay, and time.sleep() rejects negative values.
                delay_ms = next_frame - now()
                if delay_ms > 0:
                    time.sleep(delay_ms / 1000)

                result = res["result"]

                if "classification" in result:
                    print('Result (%d ms.) ' % (res['timing']['dsp'] + res['timing']['classification']), end='')
                    for label in labels:
                        score = result['classification'][label]
                        print('%s: %.2f\t' % (label, score), end='')
                    print('', flush=True)

                elif "bounding_boxes" in result:
                    boxes = result["bounding_boxes"]
                    print('Found %d bounding boxes (%d ms.)' % (len(boxes), res['timing']['dsp'] + res['timing']['classification']))
                    for bb in boxes:
                        print('\t%s (%.2f): x=%d y=%d w=%d h=%d' % (bb['label'], bb['value'], bb['x'], bb['y'], bb['width'], bb['height']))
                        img = cv2.rectangle(img, (bb['x'], bb['y']), (bb['x'] + bb['width'], bb['y'] + bb['height']), (255, 0, 0), 1)

                    # Publish a new array (never mutated afterwards) for the stream.
                    with lock:
                        outputFrame = cv2.resize(cv2.cvtColor(img.copy(), cv2.COLOR_RGB2BGR), (320, 320), interpolation=cv2.INTER_NEAREST)

                    while sender.errors is not None:
                        print('Controlling device error: %s' % sender.errors)
                        time.sleep(.5)
                    sender.send(len(boxes))
                    reply = sender.receive()
                    while reply is None:
                        # Poll briefly instead of spinning a CPU core at 100%.
                        time.sleep(.01)
                        reply = sender.receive()
                    print('Got reply from controlling device: %s' % reply)

                next_frame = now() + 100
        finally:
            if runner:
                runner.stop()

def generate():
    """Yield the latest published frame as an endless multipart MJPEG stream.

    Each yielded chunk is one ``--frame`` part holding a JPEG encoding of the
    module-level ``outputFrame``. Only the reference is taken under ``lock``;
    encoding happens outside it, so the classifier loop is not blocked while
    JPEG compression runs. A frame is sent once; while no new frame is
    available the generator sleeps briefly instead of spinning.
    """
    # outputFrame is only read here; the producer replaces it with a new array.
    global outputFrame, lock

    last_frame = None
    while True:
        # Hold the lock just long enough to grab the current reference.
        with lock:
            frame = outputFrame

        # Nothing (new) to send yet: back off instead of hammering the lock.
        if frame is None or frame is last_frame:
            time.sleep(0.01)
            continue

        # Mark as handled before encoding so a frame that fails to encode
        # is not retried in a tight loop.
        last_frame = frame
        flag, encodedImage = cv2.imencode(".jpg", frame)
        if not flag:
            continue

        # yield the output frame in the byte format
        yield (b'--frame\r\n' b'Content-Type: image/jpeg\r\n\r\n' + bytearray(encodedImage) + b'\r\n')

@app.route("/video_feed")
def video_feed():
    """Stream the frames from generate() as multipart/x-mixed-replace (MJPEG)."""
    stream = generate()
    return Response(stream, mimetype="multipart/x-mixed-replace; boundary=frame")

@app.route("/")
def index():
    """Render index.html with the selected camera description and project name."""
    context = {
        "camera_feed_info": cameraFeedInfo,
        "project_name": projectName,
    }
    return render_template("index.html", **context)

if __name__ == "__main__":
    # The web server runs in a daemon thread so it ends with the classifier
    # loop, which stays in the main thread (where Python runs signal handlers).
    # debug is off: the Werkzeug debugger allows executing arbitrary code from
    # a browser and must never be reachable on all interfaces (0.0.0.0).
    apprunner = Thread(target=app.run, kwargs={
        'host': "0.0.0.0",
        'port': 8080,
        'debug': False,
        'threaded': True,
        'use_reloader': False,
    })
    apprunner.daemon = True
    apprunner.start()
    main(sys.argv[1:])
