diff --git a/axis-ptz/camera.py b/axis-ptz/camera.py index 4aedb2e..2459dc0 100755 --- a/axis-ptz/camera.py +++ b/axis-ptz/camera.py @@ -32,11 +32,18 @@ cameraZoom = None cameraMoveSpeed = None cameraDelay = None +object_topic = None +flight_topic = None pan = 0 tilt = 0 actualPan = 0 actualTilt = 0 +follow_x = 0 +follow_y = 0 +actualX = 0 +actualY = 0 currentPlane=None +object_timeout=0 # https://stackoverflow.com/questions/45659723/calculate-the-difference-between-two-compass-headings-python @@ -58,6 +65,14 @@ def getHeadingDiff(h1, h2): else: return absDiff - 360 +def setXY(x,y): + global follow_x + global follow_y + + follow_x = int(x) + follow_y = int(y) + + def setPan(bearing): global pan camera_bearing = args.bearing @@ -109,6 +124,7 @@ def get_jpeg_request(): # 5.2.4.1 """ payload = { 'resolution': "1920x1080", + 'compression': 5, 'camera': 1, } url = 'http://' + args.axis_ip + '/axis-cgi/jpg/image.cgi' @@ -132,25 +148,84 @@ def get_jpeg_request(): # 5.2.4.1 text += str(resp.text) return text +def get_bmp_request(): # 5.2.4.1 + """ + The requests specified in the JPEG/MJPG section are supported by those video products + that use JPEG and MJPG encoding. + Args: + resolution: Resolution of the returned image. Check the product’s Release notes. + camera: Selects the source camera or the quad stream. + square_pixel: Enable/disable square pixel correction. Applies only to video encoders. + compression: Adjusts the compression level of the image. + clock: Shows/hides the time stamp. (0 = hide, 1 = show) + date: Shows/hides the date. (0 = hide, 1 = show) + text: Shows/hides the text. (0 = hide, 1 = show) + text_string: The text shown in the image, the string must be URL encoded. + text_color: The color of the text shown in the image. (black, white) + text_background_color: The color of the text background shown in the image. + (black, white, transparent, semitransparent) + rotation: Rotate the image clockwise. + text_position: The position of the string shown in the image. (top, bottom) + overlay_image: Enable/disable overlay image.(0 = disable, 1 = enable) + overlay_position:The x and y coordinates defining the position of the overlay image. + (x) + Returns: + Success ('image save' and save the image in the file folder) or Failure (Error and + description). + """ + payload = { + 'resolution': "1920x1080", + 'camera': 1, + } + url = 'http://' + args.axis_ip + '/axis-cgi/bitmap/image.bmp' + resp = requests.get(url, auth=HTTPDigestAuth(args.axis_username, args.axis_password), + params=payload) + + if resp.status_code == 200: + captureDir = "capture/{}".format(currentPlane["type"]) + try: + os.makedirs(captureDir) + except OSError as e: + if e.errno != errno.EEXIST: + raise # This was not a "directory exist" error.. 
+ filename = "{}/{}_{}.bmp".format(captureDir, currentPlane["icao24"],datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) + + with open(filename, 'wb') as var: + var.write(resp.content) + return str('Image saved') + + text = str(resp) + text += str(resp.text) + return text def moveCamera(): global actualPan global actualTilt + global actualX + global actualY global camera while True: lockedOn = False - if actualTilt != tilt or actualPan != pan: - logging.info("Moving camera to Tilt: %d & Pan: %d"%(tilt, pan)) - actualTilt = tilt - actualPan = pan - lockedOn = True - camera.absolute_move(pan, tilt, cameraZoom, cameraMoveSpeed) - time.sleep(0.3) - get_jpeg_request() - - + if (object_timeout < time.mktime(time.gmtime())): + if actualTilt != tilt or actualPan != pan: + logging.info("Moving camera to Tilt: %d & Pan: %d"%(tilt, pan)) + actualTilt = tilt + actualPan = pan + lockedOn = True + camera.absolute_move(pan, tilt, cameraZoom, cameraMoveSpeed) + time.sleep(cameraDelay) + get_jpeg_request() + #get_bmp_request() + else: + if actualX != follow_x or actualY != follow_y: + actualX = follow_x + actualY = follow_y + #camera.center_move(actualX, actualY, cameraMoveSpeed) + pan_tilt = str(actualX) + "," + str(actualY) + camera._camera_command({'center': pan_tilt, 'speed': cameraMoveSpeed, 'imagewidth': '1280', 'imageheight': '720'}) + #time.sleep(cameraDelay) #if lockedOn == True: # filename = "capture/{}_{}".format(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), currentPlane) # camera.capture("{}.jpeg".format(filename)) @@ -163,6 +238,8 @@ def moveCamera(): ############################################# def on_message(client, userdata, message): global currentPlane + global object_timeout + command = str(message.payload.decode("utf-8")) #rint(command) try: @@ -179,10 +256,17 @@ def on_message(client, userdata, message): except: print("Caught it!") - logging.info("{}\tBearing: {} \tElevation: {}".format(update["icao24"],update["bearing"],update["elevation"])) - bearingGood = setPan(update["bearing"]) - setTilt(update["elevation"]) - currentPlane = update + if message.topic == object_topic: + logging.info("Got Object Topic") + setXY(update["x"], update["y"]) + object_timeout = time.mktime(time.gmtime()) + 5 + elif message.topic == flight_topic: + logging.info("{}\tBearing: {} \tElevation: {}".format(update["icao24"],update["bearing"],update["elevation"])) + bearingGood = setPan(update["bearing"]) + setTilt(update["elevation"]) + currentPlane = update + else: + logging.info("Message: {} Object: {} Flight: {}".format(message.topic, object_topic, flight_topic)) def main(): global args @@ -194,12 +278,15 @@ def main(): global cameraMoveSpeed global cameraZoom global cameraConfig + global flight_topic + global object_topic parser = argparse.ArgumentParser(description='An MQTT based camera controller') parser.add_argument('-b', '--bearing', help="What bearing is the font of the PI pointed at (0-360)", default=0) parser.add_argument('-m', '--mqtt-host', help="MQTT broker hostname", default='127.0.0.1') - parser.add_argument('-t', '--mqtt-topic', help="MQTT topic to subscribe to", default="SkyScan") + parser.add_argument('-t', '--mqtt-flight-topic', help="MQTT topic to subscribe to", default="skyscan/flight/json") + parser.add_argument( '--mqtt-object-topic', help="MQTT topic to subscribe to", default="skyscan/object/json") parser.add_argument('-u', '--axis-username', help="Username for the Axis camera", required=True) parser.add_argument('-p', '--axis-password', help="Password for the Axis camera", required=True) 
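The `moveCamera()` loop added above arbitrates between two control modes: ADS-B-driven pan/tilt when no recent detection exists, and pixel-based center moves while `object_timeout` (set to now + 5 s in `on_message()`) is still in the future. A minimal sketch of that decision, with a hypothetical `choose_move()` helper standing in for the real loop, which calls `camera.absolute_move()` / `camera._camera_command()` directly:

```
import time

OBJECT_HOLD_SECONDS = 5  # on_message() holds object-tracking mode for 5 seconds

def choose_move(now, object_timeout,
                pan, tilt, actual_pan, actual_tilt,
                follow_x, follow_y, actual_x, actual_y):
    """Illustrative arbitration between flight data and object detections."""
    if object_timeout < now:
        # No recent detection: steer from ADS-B flight data (pan/tilt).
        if (actual_pan, actual_tilt) != (pan, tilt):
            return ('absolute_move', {'pan': pan, 'tilt': tilt})
    else:
        # A detection arrived within the last OBJECT_HOLD_SECONDS:
        # re-center on the pixel coordinates from the object tracker.
        if (actual_x, actual_y) != (follow_x, follow_y):
            return ('center_move', {'x': follow_x, 'y': follow_y})
    return None

# Example: a detection a couple of seconds ago keeps the camera in center-move mode.
now = time.mktime(time.gmtime())
print(choose_move(now, now + 3, pan=10, tilt=5, actual_pan=0, actual_tilt=0,
                  follow_x=700, follow_y=400, actual_x=640, actual_y=360))
```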
parser.add_argument('-a', '--axis-ip', help="IP address for the Axis camera", required=True) @@ -237,14 +324,17 @@ def main(): threading.Thread(target = moveCamera, daemon = True).start() # Sleep for a bit so we're not hammering the HAT with updates time.sleep(0.005) - print("connecting to MQTT broker at "+ args.mqtt_host+", channel '"+args.mqtt_topic+"'") + flight_topic=args.mqtt_flight_topic + object_topic = args.mqtt_object_topic + print("connecting to MQTT broker at "+ args.mqtt_host+", channel '"+flight_topic+"'") client = mqtt.Client("skyscan-axis-ptz-camera-" + ID) #create new instance client.on_message=on_message #attach function to callback client.connect(args.mqtt_host) #connect to broker client.loop_start() #start the loop - client.subscribe(args.mqtt_topic+"/#") + client.subscribe(flight_topic+"/#") + client.subscribe(object_topic+"/#") client.publish("skyscan/registration", "skyscan-axis-ptz-camera-"+ID+" Registration", 0, False) ############################################# diff --git a/docker-compose.yml b/docker-compose.yml index c41ba14..b148c03 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,7 +4,7 @@ services: tracker: build: ./tracker - entrypoint: "./flighttracker.py -m mqtt -l ${LAT} -L ${LONG} -a ${ALT} -P skyscan/planes/json -T skyscan/tracking/json -M ${MIN_ELEVATION} -c ${CAMERA_LEAD}" + entrypoint: "./flighttracker.py -m mqtt -l ${LAT} -L ${LONG} -a ${ALT} -P skyscan/planes/json -T skyscan/flight/json -M ${MIN_ELEVATION} -c ${CAMERA_LEAD}" depends_on: - mqtt restart: unless-stopped @@ -21,7 +21,7 @@ services: pan-tilt-pi: build: ./pan-tilt-pi - entrypoint: "./camera.py -m mqtt -t skyscan/tracking/json" + entrypoint: "./camera.py -m mqtt -t skyscan/flight/json" volumes: - /opt/vc:/opt/vc - ./capture:/app/capture @@ -36,7 +36,7 @@ services: axis-ptz: build: ./axis-ptz - entrypoint: "./camera.py -m mqtt -t skyscan/tracking/json -u ${AXIS_USERNAME} -p ${AXIS_PASSWORD} -a ${AXIS_IP} -z ${CAMERA_ZOOM} -s ${CAMERA_MOVE_SPEED} -d ${CAMERA_DELAY}" + entrypoint: "./camera.py -m mqtt -t skyscan/flight/json -u ${AXIS_USERNAME} -p ${AXIS_PASSWORD} -a ${AXIS_IP} -z ${CAMERA_ZOOM} -s ${CAMERA_MOVE_SPEED} -d ${CAMERA_DELAY}" volumes: - ./capture:/app/capture depends_on: diff --git a/object-tracker/.gitignore b/object-tracker/.gitignore new file mode 100644 index 0000000..9b55d7c --- /dev/null +++ b/object-tracker/.gitignore @@ -0,0 +1,2 @@ +all_models/ +**__pycache__ diff --git a/object-tracker/CONTRIBUTING.md b/object-tracker/CONTRIBUTING.md new file mode 100644 index 0000000..939e534 --- /dev/null +++ b/object-tracker/CONTRIBUTING.md @@ -0,0 +1,28 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. 
Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows [Google's Open Source Community +Guidelines](https://opensource.google.com/conduct/). diff --git a/object-tracker/LICENSE b/object-tracker/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/object-tracker/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/object-tracker/README.md b/object-tracker/README.md new file mode 100644 index 0000000..12c3b31 --- /dev/null +++ b/object-tracker/README.md @@ -0,0 +1,74 @@ +# Edge TPU Object Tracker Example + +This repo contains a collection of examples that use camera streams +together with the [TensorFlow Lite API](https://tensorflow.org/lite) with a +Coral device such as the +[USB Accelerator](https://coral.withgoogle.com/products/accelerator) or +[Dev Board](https://coral.withgoogle.com/products/dev-board) and provides an Object tracker for use with the detected objects. + + +## Installation + +1. First, be sure you have completed the [setup instructions for your Coral + device](https://coral.ai/docs/setup/). If it's been a while, repeat to be sure + you have the latest software. + + Importantly, you should have the latest TensorFlow Lite runtime installed + (as per the [Python quickstart]( + https://www.tensorflow.org/lite/guide/python)). + +2. Clone this Git repo onto your computer: + + ``` + mkdir google-coral && cd google-coral + + git clone https://github.com/google-coral/example-object-tracker.git + + cd example-object-tracker/ + ``` + +3. Download the models: + + ``` + sh download_models.sh + ``` + + These models will be downloaded to a new folder + ```models```. + + +Further requirements may be needed by the different camera libraries, check the +README file for the respective subfolder. + +## Contents + + * __gstreamer__: Python examples using gstreamer to obtain camera streem. These + examples work on Linux using a webcam, Raspberry Pi with + the Raspicam and on the Coral DevBoard using the Coral camera. For the + former two you will also need a Coral USB Accelerator to run the models. + + This demo provides the support of an Object tracker. After following the setup + instructions in README file for the subfolder ```gstreamer```, you can run the tracker demo: + + ``` + cd gstreamer + python3 detect.py --tracker sort + ``` + +## Models + +For the demos in this repository you can change the model and the labels +file by using the flags flags ```--model``` and +```--labels```. Be sure to use the models labeled _edgetpu, as those are +compiled for the accelerator - otherwise the model will run on the CPU and +be much slower. 
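Whether a model actually runs on the accelerator comes down to whether the interpreter is created with the Edge TPU delegate. A minimal sketch of that distinction, following the pattern used in `gstreamer/common.py`; the `use_edgetpu` flag here is only for illustration:

```
import tflite_runtime.interpreter as tflite

EDGETPU_SHARED_LIB = 'libedgetpu.so.1'

def make_interpreter(model_path, use_edgetpu=True):
    # *_edgetpu.tflite models are compiled for the accelerator and need the
    # Edge TPU delegate; without the delegate (e.g. with a non-_edgetpu model)
    # inference runs on the CPU and is much slower.
    delegates = [tflite.load_delegate(EDGETPU_SHARED_LIB)] if use_edgetpu else []
    return tflite.Interpreter(model_path=model_path,
                              experimental_delegates=delegates)
```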
+ + +For detection you need to select one of the SSD detection models +and its corresponding labels file: + +``` +mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite, coco_labels.txt +``` + + diff --git a/object-tracker/download_models.sh b/object-tracker/download_models.sh new file mode 100644 index 0000000..23b8a4a --- /dev/null +++ b/object-tracker/download_models.sh @@ -0,0 +1,19 @@ +#!/bin/sh +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +mkdir -p models +wget https://dl.google.com/coral/canned_models/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite +wget https://dl.google.com/coral/canned_models/coco_labels.txt +mv mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite coco_labels.txt models/ diff --git a/object-tracker/gstreamer/README.md b/object-tracker/gstreamer/README.md new file mode 100755 index 0000000..6c62791 --- /dev/null +++ b/object-tracker/gstreamer/README.md @@ -0,0 +1,40 @@ +# GStreamer based Object Tracking Example + +This folder contains example code using [GStreamer](https://github.com/GStreamer/gstreamer) to +obtain camera images and perform image classification and object detection on the Edge TPU. + +This code works on Linux using a webcam, Raspberry Pi with the Pi Camera, and on the Coral Dev +Board using the Coral Camera or a webcam. For the first two, you also need a Coral +USB/PCIe/M.2 Accelerator. + + +## Set up your device + +1. First, be sure you have completed the [setup instructions for your Coral + device](https://coral.ai/docs/setup/). If it's been a while, repeat to be sure + you have the latest software. + + Importantly, you should have the latest TensorFlow Lite runtime installed + (as per the [Python quickstart]( + https://www.tensorflow.org/lite/guide/python)). You can check which version is installed + using the ```pip3 show tflite_runtime``` command. + +2. Install the GStreamer libraries and Trackers: + + ``` + bash install_requirements.sh + ``` +3. Run the detection model with Sort tracker + ``` + python3 detect.py --tracker sort + ``` + +## Run the detection demo without any tracker (SSD models) + +``` +python3 detect.py +``` +You can change the model and the labels file using ```--model``` and ```--labels```. + +By default, example use the attached Coral Camera. If you want to use a USB camera, +edit the ```gstreamer.py``` file and change ```device=/dev/video0``` to ```device=/dev/video1```. diff --git a/object-tracker/gstreamer/common.py b/object-tracker/gstreamer/common.py new file mode 100755 index 0000000..90b7d52 --- /dev/null +++ b/object-tracker/gstreamer/common.py @@ -0,0 +1,75 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common utilities.""" +import collections +import gi +gi.require_version('Gst', '1.0') +from gi.repository import Gst +import numpy as np +import svgwrite +import tflite_runtime.interpreter as tflite +import time + +EDGETPU_SHARED_LIB = 'libedgetpu.so.1' + +def make_interpreter(model_file): + model_file, *device = model_file.split('@') + return tflite.Interpreter( + model_path=model_file, + experimental_delegates=[ + tflite.load_delegate(EDGETPU_SHARED_LIB, + {'device': device[0]} if device else {}) + ]) + +def input_image_size(interpreter): + """Returns input size as (width, height, channels) tuple.""" + _, height, width, channels = interpreter.get_input_details()[0]['shape'] + return width, height, channels + +def input_tensor(interpreter): + """Returns input tensor view as numpy array of shape (height, width, channels).""" + tensor_index = interpreter.get_input_details()[0]['index'] + return interpreter.tensor(tensor_index)()[0] + +def set_input(interpreter, buf): + """Copies data to input tensor.""" + result, mapinfo = buf.map(Gst.MapFlags.READ) + if result: + np_buffer = np.reshape(np.frombuffer(mapinfo.data, dtype=np.uint8), + interpreter.get_input_details()[0]['shape']) + input_tensor(interpreter)[:, :] = np_buffer + buf.unmap(mapinfo) + +def output_tensor(interpreter, i): + """Returns dequantized output tensor if quantized before.""" + output_details = interpreter.get_output_details()[i] + output_data = np.squeeze(interpreter.tensor(output_details['index'])()) + if 'quantization' not in output_details: + return output_data + scale, zero_point = output_details['quantization'] + if scale == 0: + return output_data - zero_point + return scale * (output_data - zero_point) + +def avg_fps_counter(window_size): + window = collections.deque(maxlen=window_size) + prev = time.monotonic() + yield 0.0 # First fps value. + + while True: + curr = time.monotonic() + window.append(curr - prev) + prev = curr + yield len(window) / sum(window) diff --git a/object-tracker/gstreamer/detect.py b/object-tracker/gstreamer/detect.py new file mode 100755 index 0000000..a9e446e --- /dev/null +++ b/object-tracker/gstreamer/detect.py @@ -0,0 +1,218 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A demo which runs object detection on camera frames using GStreamer. +It also provides support for Object Tracker. + +Run default object detection: +python3 detect.py + +Choose different camera and input encoding +python3 detect.py --videosrc /dev/video1 --videofmt jpeg + +Choose an Object Tracker. 
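A short usage sketch of the `common.py` helpers above; the model filename and the `usb:0` device selector are placeholders, not files or devices referenced by this repo:

```
import common

# An optional '@<device>' suffix after the model path is split off by
# make_interpreter() and passed to the Edge TPU delegate; omit it to use
# the default device.
interpreter = common.make_interpreter('model_edgetpu.tflite@usb:0')
interpreter.allocate_tensors()
width, height, channels = common.input_image_size(interpreter)

fps_counter = common.avg_fps_counter(30)  # generator yielding a running FPS
fps = next(fps_counter)                   # call once per processed frame
```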
Example : To run sort tracker +python3 detect.py --tracker sort + +TEST_DATA=../all_models + +Run coco model: +python3 detect.py \ + --model ${TEST_DATA}/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite \ + --labels ${TEST_DATA}/coco_labels.txt +""" +import argparse +import collections +import common +import gstreamer +import numpy as np +import os +import re +import svgwrite +import time +from tracker import ObjectTracker + + +Object = collections.namedtuple('Object', ['id', 'score', 'bbox']) + + +def load_labels(path): + p = re.compile(r'\s*(\d+)(.+)') + with open(path, 'r', encoding='utf-8') as f: + lines = (p.match(line).groups() for line in f.readlines()) + return {int(num): text.strip() for num, text in lines} + + +def shadow_text(dwg, x, y, text, font_size=20): + dwg.add(dwg.text(text, insert=(x+1, y+1), fill='black', font_size=font_size)) + dwg.add(dwg.text(text, insert=(x, y), fill='white', font_size=font_size)) + + +def generate_svg(src_size, inference_size, inference_box, objs, labels, text_lines, trdata, trackerFlag): + dwg = svgwrite.Drawing('', size=src_size) + src_w, src_h = src_size + inf_w, inf_h = inference_size + box_x, box_y, box_w, box_h = inference_box + scale_x, scale_y = src_w / box_w, src_h / box_h + + for y, line in enumerate(text_lines, start=1): + shadow_text(dwg, 10, y*20, line) + if trackerFlag and (np.array(trdata)).size: + for td in trdata: + x0, y0, x1, y1, trackID = td[0].item(), td[1].item( + ), td[2].item(), td[3].item(), td[4].item() + overlap = 0 + for ob in objs: + dx0, dy0, dx1, dy1 = ob.bbox.xmin.item(), ob.bbox.ymin.item( + ), ob.bbox.xmax.item(), ob.bbox.ymax.item() + area = (min(dx1, x1)-max(dx0, x0))*(min(dy1, y1)-max(dy0, y0)) + if (area > overlap): + overlap = area + obj = ob + + # Relative coordinates. + x, y, w, h = x0, y0, x1 - x0, y1 - y0 + # Absolute coordinates, input tensor space. + x, y, w, h = int(x * inf_w), int(y * + inf_h), int(w * inf_w), int(h * inf_h) + # Subtract boxing offset. + x, y = x - box_x, y - box_y + # Scale to source coordinate space. + x, y, w, h = x * scale_x, y * scale_y, w * scale_x, h * scale_y + percent = int(100 * obj.score) + label = '{}% {} ID:{}'.format( + percent, labels.get(obj.id, obj.id), int(trackID)) + shadow_text(dwg, x, y - 5, label) + dwg.add(dwg.rect(insert=(x, y), size=(w, h), + fill='none', stroke='red', stroke_width='2')) + else: + for obj in objs: + x0, y0, x1, y1 = list(obj.bbox) + # Relative coordinates. + x, y, w, h = x0, y0, x1 - x0, y1 - y0 + # Absolute coordinates, input tensor space. + x, y, w, h = int(x * inf_w), int(y * + inf_h), int(w * inf_w), int(h * inf_h) + # Subtract boxing offset. + x, y = x - box_x, y - box_y + # Scale to source coordinate space. + x, y, w, h = x * scale_x, y * scale_y, w * scale_x, h * scale_y + percent = int(100 * obj.score) + label = '{}% {}'.format(percent, labels.get(obj.id, obj.id)) + shadow_text(dwg, x, y - 5, label) + dwg.add(dwg.rect(insert=(x, y), size=(w, h), + fill='none', stroke='red', stroke_width='2')) + return dwg.tostring() + + +class BBox(collections.namedtuple('BBox', ['xmin', 'ymin', 'xmax', 'ymax'])): + """Bounding box. + Represents a rectangle which sides are either vertical or horizontal, parallel + to the x or y axis. 
+ """ + __slots__ = () + + +def get_output(interpreter, score_threshold, top_k, image_scale=1.0): + """Returns list of detected objects.""" + boxes = common.output_tensor(interpreter, 0) + category_ids = common.output_tensor(interpreter, 1) + scores = common.output_tensor(interpreter, 2) + + def make(i): + ymin, xmin, ymax, xmax = boxes[i] + return Object( + id=int(category_ids[i]), + score=scores[i], + bbox=BBox(xmin=np.maximum(0.0, xmin), + ymin=np.maximum(0.0, ymin), + xmax=np.minimum(1.0, xmax), + ymax=np.minimum(1.0, ymax))) + return [make(i) for i in range(top_k) if scores[i] >= score_threshold] + + +def main(): + default_model_dir = '../models' + default_model = 'mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite' + default_labels = 'coco_labels.txt' + parser = argparse.ArgumentParser() + parser.add_argument('--model', help='.tflite model path', + default=os.path.join(default_model_dir, default_model)) + parser.add_argument('--labels', help='label file path', + default=os.path.join(default_model_dir, default_labels)) + parser.add_argument('--top_k', type=int, default=3, + help='number of categories with highest score to display') + parser.add_argument('--threshold', type=float, default=0.1, + help='classifier score threshold') + parser.add_argument('--videosrc', help='Which video source to use. ', + default='/dev/video0') + parser.add_argument('--videofmt', help='Input video format.', + default='raw', + choices=['raw', 'h264', 'jpeg']) + parser.add_argument('--tracker', help='Name of the Object Tracker To be used.', + default=None, + choices=[None, 'sort']) + args = parser.parse_args() + + print('Loading {} with {} labels.'.format(args.model, args.labels)) + interpreter = common.make_interpreter(args.model) + interpreter.allocate_tensors() + labels = load_labels(args.labels) + + w, h, _ = common.input_image_size(interpreter) + inference_size = (w, h) + # Average fps over last 30 frames. 
+ fps_counter = common.avg_fps_counter(30) + + def user_callback(input_tensor, src_size, inference_box, mot_tracker): + nonlocal fps_counter + start_time = time.monotonic() + common.set_input(interpreter, input_tensor) + interpreter.invoke() + # For larger input image sizes, use the edgetpu.classification.engine for better performance + objs = get_output(interpreter, args.threshold, args.top_k) + end_time = time.monotonic() + detections = [] # np.array([]) + for n in range(0, len(objs)): + element = [] # np.array([]) + element.append(objs[n].bbox.xmin) + element.append(objs[n].bbox.ymin) + element.append(objs[n].bbox.xmax) + element.append(objs[n].bbox.ymax) + element.append(objs[n].score) # print('element= ',element) + detections.append(element) # print('dets: ',dets) + # convert to numpy array # print('npdets: ',dets) + detections = np.array(detections) + trdata = [] + trackerFlag = False + if detections.any(): + if mot_tracker != None: + trdata = mot_tracker.update(detections) + trackerFlag = True + text_lines = [ + 'Inference: {:.2f} ms'.format((end_time - start_time) * 1000), + 'FPS: {} fps'.format(round(next(fps_counter))), ] + if len(objs) != 0: + return generate_svg(src_size, inference_size, inference_box, objs, labels, text_lines, trdata, trackerFlag) + + result = gstreamer.run_pipeline(user_callback, + src_size=(640, 480), + appsink_size=inference_size, + trackerName=args.tracker, + videosrc=args.videosrc, + videofmt=args.videofmt) + + +if __name__ == '__main__': + main() diff --git a/object-tracker/gstreamer/gstreamer.py b/object-tracker/gstreamer/gstreamer.py new file mode 100755 index 0000000..13dcd05 --- /dev/null +++ b/object-tracker/gstreamer/gstreamer.py @@ -0,0 +1,276 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import svgwrite +import threading +from tracker import ObjectTracker + +import gi +gi.require_version('Gst', '1.0') +gi.require_version('GstBase', '1.0') +gi.require_version('Gtk', '3.0') +from gi.repository import GLib, GObject, Gst, GstBase, Gtk + +GObject.threads_init() +Gst.init(None) + +class GstPipeline: + def __init__(self, pipeline, user_function, src_size, mot_tracker): + self.user_function = user_function + self.running = False + self.gstbuffer = None + self.sink_size = None + self.src_size = src_size + self.box = None + self.condition = threading.Condition() + self.mot_tracker = mot_tracker + self.pipeline = Gst.parse_launch(pipeline) + self.overlay = self.pipeline.get_by_name('overlay') + self.overlaysink = self.pipeline.get_by_name('overlaysink') + appsink = self.pipeline.get_by_name('appsink') + appsink.connect('new-sample', self.on_new_sample) + + # Set up a pipeline bus watch to catch errors. + bus = self.pipeline.get_bus() + bus.add_signal_watch() + bus.connect('message', self.on_bus_message) + + # Set up a full screen window on Coral, no-op otherwise. + self.setup_window() + + def run(self): + # Start inference worker. 
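The hand-off between `user_callback()` above and the SORT tracker is easiest to see with concrete shapes. A small sketch of the assumed interface of `mot_tracker.update()` from the abewley/sort package: an N×5 array of `[xmin, ymin, xmax, ymax, score]` in, an M×5 array of `[xmin, ymin, xmax, ymax, track_id]` out (the values below are illustrative only):

```
import numpy as np

# Two detections in normalized coordinates, as built in user_callback():
# [xmin, ymin, xmax, ymax, score]
detections = np.array([
    [0.10, 0.20, 0.30, 0.40, 0.91],
    [0.55, 0.50, 0.70, 0.65, 0.62],
])

# trdata = mot_tracker.update(detections)   # abewley/sort interface (assumed)
# Each returned row keeps the box and carries a track ID in place of the
# score, which is why generate_svg() reads td[4] as trackID:
trdata = np.array([
    [0.10, 0.20, 0.30, 0.40, 1.0],   # box now associated with track 1
    [0.55, 0.50, 0.70, 0.65, 2.0],   # box now associated with track 2
])
```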
+ self.running = True + worker = threading.Thread(target=self.inference_loop) + worker.start() + + # Run pipeline. + self.pipeline.set_state(Gst.State.PLAYING) + try: + Gtk.main() + except: + pass + + # Clean up. + self.pipeline.set_state(Gst.State.NULL) + while GLib.MainContext.default().iteration(False): + pass + with self.condition: + self.running = False + self.condition.notify_all() + worker.join() + + def on_bus_message(self, bus, message): + t = message.type + if t == Gst.MessageType.EOS: + Gtk.main_quit() + elif t == Gst.MessageType.WARNING: + err, debug = message.parse_warning() + sys.stderr.write('Warning: %s: %s\n' % (err, debug)) + elif t == Gst.MessageType.ERROR: + err, debug = message.parse_error() + sys.stderr.write('Error: %s: %s\n' % (err, debug)) + Gtk.main_quit() + return True + + def on_new_sample(self, sink): + sample = sink.emit('pull-sample') + if not self.sink_size: + s = sample.get_caps().get_structure(0) + self.sink_size = (s.get_value('width'), s.get_value('height')) + with self.condition: + self.gstbuffer = sample.get_buffer() + self.condition.notify_all() + return Gst.FlowReturn.OK + + def get_box(self): + if not self.box: + glbox = self.pipeline.get_by_name('glbox') + if glbox: + glbox = glbox.get_by_name('filter') + box = self.pipeline.get_by_name('box') + assert glbox or box + assert self.sink_size + if glbox: + self.box = (glbox.get_property('x'), glbox.get_property('y'), + glbox.get_property('width'), glbox.get_property('height')) + else: + self.box = (-box.get_property('left'), -box.get_property('top'), + self.sink_size[0] + box.get_property('left') + box.get_property('right'), + self.sink_size[1] + box.get_property('top') + box.get_property('bottom')) + return self.box + + def inference_loop(self): + while True: + with self.condition: + while not self.gstbuffer and self.running: + self.condition.wait() + if not self.running: + break + gstbuffer = self.gstbuffer + self.gstbuffer = None + + # Passing Gst.Buffer as input tensor avoids 2 copies of it: + # * Python bindings copies the data when mapping gstbuffer + # * Numpy copies the data when creating ndarray. + # This requires a recent version of the python3-edgetpu package. If this + # raises an exception please make sure dependencies are up to date. + input_tensor = gstbuffer + svg = self.user_function(input_tensor, self.src_size, self.get_box(), self.mot_tracker) + if svg: + if self.overlay: + self.overlay.set_property('data', svg) + if self.overlaysink: + self.overlaysink.set_property('svg', svg) + + def setup_window(self): + # Only set up our own window if we have Coral overlay sink in the pipeline. + if not self.overlaysink: + return + + gi.require_version('GstGL', '1.0') + gi.require_version('GstVideo', '1.0') + from gi.repository import GstGL, GstVideo + + # Needed to commit the wayland sub-surface. + def on_gl_draw(sink, widget): + widget.queue_draw() + + # Needed to account for window chrome etc. + def on_widget_configure(widget, event, overlaysink): + allocation = widget.get_allocation() + overlaysink.set_render_rectangle(allocation.x, allocation.y, + allocation.width, allocation.height) + return False + + window = Gtk.Window(Gtk.WindowType.TOPLEVEL) + window.fullscreen() + + drawing_area = Gtk.DrawingArea() + window.add(drawing_area) + drawing_area.realize() + + self.overlaysink.connect('drawn', on_gl_draw, drawing_area) + + # Wayland window handle. 
+ wl_handle = self.overlaysink.get_wayland_window_handle(drawing_area) + self.overlaysink.set_window_handle(wl_handle) + + # Wayland display context wrapped as a GStreamer context. + wl_display = self.overlaysink.get_default_wayland_display_context() + self.overlaysink.set_context(wl_display) + + drawing_area.connect('configure-event', on_widget_configure, self.overlaysink) + window.connect('delete-event', Gtk.main_quit) + window.show_all() + + # The appsink pipeline branch must use the same GL display as the screen + # rendering so they get the same GL context. This isn't automatically handled + # by GStreamer as we're the ones setting an external display handle. + def on_bus_message_sync(bus, message, overlaysink): + if message.type == Gst.MessageType.NEED_CONTEXT: + _, context_type = message.parse_context_type() + if context_type == GstGL.GL_DISPLAY_CONTEXT_TYPE: + sinkelement = overlaysink.get_by_interface(GstVideo.VideoOverlay) + gl_context = sinkelement.get_property('context') + if gl_context: + display_context = Gst.Context.new(GstGL.GL_DISPLAY_CONTEXT_TYPE, True) + GstGL.context_set_gl_display(display_context, gl_context.get_display()) + message.src.set_context(display_context) + return Gst.BusSyncReply.PASS + + bus = self.pipeline.get_bus() + bus.set_sync_handler(on_bus_message_sync, self.overlaysink) + +def detectCoralDevBoard(): + try: + if 'MX8MQ' in open('/sys/firmware/devicetree/base/model').read(): + print('Detected Edge TPU dev board.') + return True + except: pass + return False + +def run_pipeline(user_function, + src_size, + appsink_size, + trackerName, + videosrc='/dev/video1', + videofmt='raw'): + objectOfTracker = None + if videofmt == 'h264': + SRC_CAPS = 'video/x-h264,width={width},height={height},framerate=30/1' + elif videofmt == 'jpeg': + SRC_CAPS = 'image/jpeg,width={width},height={height},framerate=30/1' + else: + SRC_CAPS = 'video/x-raw,width={width},height={height},framerate=30/1' + if videosrc.startswith('/dev/video'): + PIPELINE = 'v4l2src device=%s ! {src_caps}'%videosrc + elif videosrc.startswith('http'): + PIPELINE = 'souphttpsrc location=%s'%videosrc + elif videosrc.startswith('rtsp'): + PIPELINE = 'rtspsrc location=%s'%videosrc + else: + demux = 'avidemux' if videosrc.endswith('avi') else 'qtdemux' + PIPELINE = """filesrc location=%s ! %s name=demux demux.video_0 + ! queue ! decodebin ! videorate + ! videoconvert n-threads=4 ! videoscale n-threads=4 + ! {src_caps} ! {leaky_q} """ % (videosrc, demux) + ''' Check for the object tracker.''' + if trackerName != None: + if trackerName == 'mediapipe': + if detectCoralDevBoard(): + objectOfTracker = ObjectTracker('mediapipe') + else: + print("Tracker MediaPipe is only available on the Dev Board. Keeping the tracker as None") + trackerName = None + else: + objectOfTracker = ObjectTracker(trackerName) + else: + pass + + if detectCoralDevBoard(): + scale_caps = None + PIPELINE += """ ! decodebin ! glupload ! tee name=t + t. ! queue ! glfilterbin filter=glbox name=glbox ! {sink_caps} ! {sink_element} + t. ! queue ! glsvgoverlaysink name=overlaysink + """ + else: + scale = min(appsink_size[0] / src_size[0], appsink_size[1] / src_size[1]) + scale = tuple(int(x * scale) for x in src_size) + scale_caps = 'video/x-raw,width={width},height={height}'.format(width=scale[0], height=scale[1]) + PIPELINE += """ ! tee name=t + t. ! {leaky_q} ! videoconvert ! videoscale ! {scale_caps} ! videobox name=box autocrop=true + ! {sink_caps} ! {sink_element} + t. ! {leaky_q} ! videoconvert + ! rsvgoverlay name=overlay ! 
videoconvert ! ximagesink sync=false + """ + if objectOfTracker: + mot_tracker = objectOfTracker.trackerObject.mot_tracker + else: + mot_tracker = None + SINK_ELEMENT = 'appsink name=appsink emit-signals=true max-buffers=1 drop=true' + SINK_CAPS = 'video/x-raw,format=RGB,width={width},height={height}' + LEAKY_Q = 'queue max-size-buffers=1 leaky=downstream' + + src_caps = SRC_CAPS.format(width=src_size[0], height=src_size[1]) + sink_caps = SINK_CAPS.format(width=appsink_size[0], height=appsink_size[1]) + pipeline = PIPELINE.format(leaky_q=LEAKY_Q, + src_caps=src_caps, sink_caps=sink_caps, + sink_element=SINK_ELEMENT, scale_caps=scale_caps) + + print('Gstreamer pipeline:\n', pipeline) + + pipeline = GstPipeline(pipeline, user_function, src_size, mot_tracker) + pipeline.run() diff --git a/object-tracker/gstreamer/install_requirements.sh b/object-tracker/gstreamer/install_requirements.sh new file mode 100755 index 0000000..773016d --- /dev/null +++ b/object-tracker/gstreamer/install_requirements.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if grep -s -q "MX8MQ" /sys/firmware/devicetree/base/model; then + echo "Installing DevBoard specific dependencies" + sudo apt-get install -y python3-pip python3-edgetpuvision + sudo python3 -m pip install svgwrite +else + # Install gstreamer + sudo apt-get install -y gstreamer1.0-plugins-bad gstreamer1.0-plugins-good python3-gst-1.0 python3-gi gir1.2-gtk-3.0 + python3 -m pip install svgwrite + + if grep -s -q "Raspberry Pi" /sys/firmware/devicetree/base/model; then + echo "Installing Raspberry Pi specific dependencies" + sudo apt-get install python3-rpi.gpio + # Add v4l2 video module to kernel + if ! grep -q "bcm2835-v4l2" /etc/modules; then + echo bcm2835-v4l2 | sudo tee -a /etc/modules + fi + sudo modprobe bcm2835-v4l2 + fi +fi + +# Verify models are downloaded +if [ ! -d "../models" ] +then + cd .. + echo "Downloading models." + bash download_models.sh + cd - +fi + +# Install Tracker Dependencies +echo +echo "Installing tracker dependencies." +echo +echo "Note that the trackers have their own licensing, many of which +are not Apache. Care should be taken if using a tracker with restrictive +licenses for end applications." + +read -p "Install SORT (GPLv3)? 
" -n 1 -r +if [[ $REPLY =~ ^[Yy]$ ]] +then + wget https://github.com/abewley/sort/archive/master.zip -O sort.zip + unzip sort.zip -d ../third_party + rm sort.zip + sudo apt install python3-skimage + python3 -m pip install -r requirements_for_sort_tracker.txt +fi +echo diff --git a/object-tracker/gstreamer/requirements_for_sort_tracker.txt b/object-tracker/gstreamer/requirements_for_sort_tracker.txt new file mode 100644 index 0000000..5b3bd5a --- /dev/null +++ b/object-tracker/gstreamer/requirements_for_sort_tracker.txt @@ -0,0 +1,2 @@ +filterpy==1.1.0 +lap==0.4.0 diff --git a/object-tracker/gstreamer/tracker.py b/object-tracker/gstreamer/tracker.py new file mode 100755 index 0000000..a68e448 --- /dev/null +++ b/object-tracker/gstreamer/tracker.py @@ -0,0 +1,42 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides the support for Tracker Object. +This creates object for the specific tracker based on the name of the tracker provided +in the command line of the demo. + +To add more trackers here, simply replicate the SortTracker() code and replace it with +the new tracker as required. + +Developer simply needs to instantiate the object of ObjectTracker(trackerObjectName) with a valid +trackerObjectName. + +""" +import os,sys + +class ObjectTracker(object): + def __init__(self, trackerObjectName): + if trackerObjectName == 'sort': # Add more trackers in elif whenever needed + self.trackerObject = SortTracker() + else: + print("Invalid Tracker Name") + self.trackerObject = None + + +class SortTracker(ObjectTracker): + def __init__(self): + sys.path.append(os.path.join(os.path.dirname(__file__), '../third_party', 'sort-master')) + from sort import Sort + self.mot_tracker = Sort() diff --git a/object-tracker/opencv/README.md b/object-tracker/opencv/README.md new file mode 100644 index 0000000..d8d8def --- /dev/null +++ b/object-tracker/opencv/README.md @@ -0,0 +1,82 @@ +# OpenCV camera examples with Coral + +This folder contains example code using [OpenCV](https://github.com/opencv/opencv) to obtain +camera images and perform object detection on the Edge TPU. + +This code works on Linux/macOS/Windows using a webcam, Raspberry Pi with the Pi Camera, and on the Coral Dev +Board using the Coral Camera or a webcam. For all settings other than the Coral Dev Board, you also need a Coral +USB/PCIe/M.2 Accelerator. + + +## Set up your device + +1. First, be sure you have completed the [setup instructions for your Coral + device](https://coral.ai/docs/setup/). If it's been a while, repeat to be sure + you have the latest software. + + Importantly, you should have the latest TensorFlow Lite runtime installed + (as per the [Python quickstart]( + https://www.tensorflow.org/lite/guide/python)). You can check which version is installed + using the ```pip3 show tflite_runtime``` command. + +1.5 Install PyCoral: https://coral.ai/software/#pycoral-api + + +2. 
Clone this Git repo onto your computer or Dev Board: + + ``` + mkdir google-coral && cd google-coral + + git clone https://github.com/google-coral/examples-camera --depth 1 + ``` + +3. Download the models: + + ``` + cd examples-camera + + sh download_models.sh + ``` + +4. Install the OpenCV libraries: + + ``` + cd opencv + + bash install_requirements.sh + ``` + + +## Run the detection model with Sort tracker +``` +python3 detect.py --tracker sort +``` + +## Run the detection demo without any tracker (SSD models) + +``` +python3 detect.py +``` + +## Arguments + +*All of the arguments are optional and provide increasing control over the configuration* + + - **model** path to the model you want to use, defaults to COCO + - **labels** labels for the model you are using, default to COCO labels + - **top_k** number of categories with highest score to display, defaults to 3 + - **threshold** classifier score threshold + - **videosrc** what video source you want to use. Choices are `net` or `dev`. Default is `dev`: + - **dev** a directly connected (dev) camera, can be Coral cam or USB cam or Networked + - **net** network video source, using RTSP. The --netsrc argument must be specified. + - **file** a video file can be used as a source + - **camera_idx** Index of which video source to use. I am not sure how OpenCV enumerates them. Defaults to 0. + - **filesrc** the path to the video file. In the Docker container should be at /app/videos + - **netsrc** If the `videosrc` is `net` then specify the URL. Example: `rtsp://192.168.1.43/mpeg4/media.amp` + - **tracker** Name of the Object Tracker To be used. Choices are `None` or `sort`. + +You can change the model and the labels file using ```--model``` and ```--labels```. + +By default, this uses the ```mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite``` model. + +You can change the model and the labels file using flags ```--model``` and ```--labels```. diff --git a/object-tracker/opencv/detect.py b/object-tracker/opencv/detect.py new file mode 100644 index 0000000..a9cbafa --- /dev/null +++ b/object-tracker/opencv/detect.py @@ -0,0 +1,334 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A demo that runs object detection on camera frames using OpenCV. 
+ +TEST_DATA=../models + +Run face detection model: +python3 detect.py \ + --model ${TEST_DATA}/mobilenet_ssd_v2_face_quant_postprocess_edgetpu.tflite + +Run coco model: +python3 detect.py \ + --model ${TEST_DATA}/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite \ + --labels ${TEST_DATA}/coco_labels.txt + +""" +import argparse +import numpy as np +import cv2 +import os +import random +import time +from PIL import Image + + +from pycoral.adapters import common +from pycoral.adapters.common import input_size +from pycoral.adapters.detect import get_objects +from pycoral.utils.dataset import read_label_file +from pycoral.utils.edgetpu import make_interpreter +from pycoral.utils.edgetpu import run_inference +from tracker import ObjectTracker + +import json +import mqtt_wrapper + +mot_tracker = None +mqtt_bridge = None +mqtt_topic = None +ID = str(random.randint(1,100001)) + + +############################################# +## MQTT Callback Function ## +############################################# +def on_message(client, userdata, message): + global currentPlane + command = str(message.payload.decode("utf-8")) + #rint(command) + try: + update = json.loads(command) + #payload = json.loads(messsage.payload) # you can use json.loads to convert string to json + except JSONDecodeError as e: + # do whatever you want + print(e) + except TypeError as e: + # do whatever you want in this case + print(e) + except ValueError as e: + print(e) + except: + print("Caught it!") + + +def detectCoralDevBoard(): + try: + if 'MX8MQ' in open('/sys/firmware/devicetree/base/model').read(): + print('Detected Edge TPU dev board.') + return True + except: pass + return False + + +Resolution = [1280.0, 720.0] #pixels +Signage = [1.0, -1.0] +GainX = 0.5 +GainY = 0.5 + +def motionControl(x,y): + targetCoordinates = [x,y] + targetCoordinates[0] = (float(targetCoordinates[0]) - (Resolution[0]/2.0))*Signage[0] # X: Convert frame coordinate to center coordinate + targetCoordinates[1] = (float(targetCoordinates[1]) - (Resolution[1]/2.0))*Signage[1] # Y: Convert frame coordinate to center coordinate + targetCoordinates[0]*=GainX # Apply Control Gain in X direction + targetCoordinates[1]*=GainY # Apply Control Gain in Y direction + + targetCoordinates[0] = (float(targetCoordinates[0]) + (Resolution[0]/2.0))*Signage[0] # X: Convert center coordinate to frame coordinate + targetCoordinates[1] = (float(targetCoordinates[1]) - (Resolution[1]/2.0))*Signage[1] # Y: Convert center coordinate to frame coordinate + return targetCoordinates + + +def main(): + global mot_tracker + global mqtt_bridge + global mqtt_topic + + camera_width=1280 + camera_height=720 + + default_model_dir = '../models' + default_model = 'mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite' + default_labels = 'coco_labels.txt' + parser = argparse.ArgumentParser() + parser.add_argument('--model', help='.tflite model path', + default=os.path.join(default_model_dir,default_model)) + parser.add_argument('--labels', help='label file path', + default=os.path.join(default_model_dir, default_labels)) + parser.add_argument('--top_k', type=int, default=3, + help='number of categories with highest score to display') + parser.add_argument('--camera_idx', type=int, help='Index of which video source to use. 
', default = 0) + parser.add_argument('--threshold', type=float, default=0.1, + help='classifier score threshold') + parser.add_argument('--tracker', help='Name of the Object Tracker To be used.', + default=None, + choices=[None, 'sort']) + parser.add_argument('--videosrc', help='Directly connected (dev) or Networked (net) video source. ', choices=['dev','net','file'], + default='dev') + parser.add_argument('--display', help='Is a display attached', + default='False', + choices=['True', 'False']) + parser.add_argument('--netsrc', help="Networked video source, example format: rtsp://192.168.1.43/mpeg4/media.amp",) + parser.add_argument('--filesrc', help="Video file source. The videos subdirectory gets mapped into the Docker container, so place your files there.",) + parser.add_argument('--modelInt8', help="Model expects input tensors to be Int8, not UInt8", default='False', choices=['True', 'False']) + parser.add_argument( '--mqtt-host', help="MQTT broker hostname", default='127.0.0.1') + parser.add_argument( '--mqtt-port', type=int, help="MQTT broker port number (default 1883)", default=1883) + parser.add_argument( '--mqtt-topic', dest='mqtt_topic', help="MQTT Object Tracking topic", default="skyscan/object/json") + + args = parser.parse_args() + + trackerName=args.tracker + ''' Check for the object tracker.''' + if trackerName != None: + if trackerName == 'mediapipe': + if detectCoralDevBoard(): + objectOfTracker = ObjectTracker('mediapipe') + else: + print("Tracker MediaPipe is only available on the Dev Board. Keeping the tracker as None") + trackerName = None + else: + objectOfTracker = ObjectTracker(trackerName) + else: + pass + + if trackerName != None and objectOfTracker: + mot_tracker = objectOfTracker.trackerObject.mot_tracker + else: + mot_tracker = None + mqtt_topic = args.mqtt_topic + mqtt_bridge = mqtt_wrapper.bridge(host = args.mqtt_host, port = args.mqtt_port, client_id = "skyscan-object-tracker-%s" % (ID)) + mqtt_bridge.publish("skyscan/registration", "skyscan-adsb-mqtt-"+ID+" Registration", 0, False) + + print('Loading {} with {} labels.'.format(args.model, args.labels)) + interpreter = make_interpreter(args.model) + interpreter.allocate_tensors() + labels = read_label_file(args.labels) + inference_size = input_size(interpreter) + if args.modelInt8=='True': + model_int8 = True + else: + model_int8 = False + + if args.videosrc=='dev': + cap = cv2.VideoCapture(args.camera_idx) + elif args.videosrc=='file': + cap = cv2.VideoCapture(args.filesrc) + else: + if args.netsrc==None: + print("--videosrc was set to net but --netsrc was not specified") + sys.exit() + cap = cv2.VideoCapture(args.netsrc) + + cap.set(cv2.CAP_PROP_BUFFERSIZE, 0) + timeHeartbeat = 0 + while cap.isOpened(): + if timeHeartbeat < time.mktime(time.gmtime()): + timeHeartbeat = time.mktime(time.gmtime()) + 10 + mqtt_bridge.publish("skyscan/heartbeat", "skyscan-object-tracker-" +ID+" Heartbeat", 0, False) + start_time = time.monotonic() + ret, frame = cap.read() + if not ret: + if args.videosrc=='file': + cap = cv2.VideoCapture(args.filesrc) + continue + else: + break + cv2_im = frame + + cv2_im_rgb = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB) + cv2_im_rgb = cv2.resize(cv2_im_rgb, inference_size) + + if model_int8: + im_pil = Image.fromarray(cv2_im_rgb) + input_type = common.input_details(interpreter, 'dtype') + img = (input_type(cv2_im_rgb)- 127.5) / 128.0 + + run_inference(interpreter, img.flatten()) + else: + run_inference(interpreter, cv2_im_rgb.tobytes()) + + objs = get_objects(interpreter, 
args.threshold)[:args.top_k] + height, width, channels = cv2_im.shape + scale_x, scale_y = width / inference_size[0], height / inference_size[1] + detections = [] # np.array([]) + for obj in objs: + bbox = obj.bbox.scale(scale_x, scale_y) + element = [] # np.array([]) + element.append(bbox.xmin) + element.append(bbox.ymin) + element.append(bbox.xmax) + element.append(bbox.ymax) + element.append(obj.score) # print('element= ',element) + element.append(obj.id) + detections.append(element) # print('dets: ',dets) + # convert to numpy array # print('npdets: ',dets) + detections = np.array(detections) + trdata = [] + trackerFlag = False + if detections.any(): + if mot_tracker != None: + trdata = mot_tracker.update(detections) + trackerFlag = True + + cv2_im = append_objs_to_img(cv2_im, detections, labels, trdata, trackerFlag) + follow_x, follow_y = object_to_follow(detections, labels, trdata, trackerFlag) + if args.display == 'True': + cv2.imshow('frame', cv2_im) + + if follow_x != None: + follow_x = int(follow_x * (camera_height/height)) + follow_y = int(follow_y * (camera_width/width)) + coordinates = motionControl(follow_x, follow_y) + follow = { + "x": coordinates[0], + "y": coordinates[1] + } + follow_json = json.dumps(follow) + end_time = time.monotonic() + print("x: {} y:{} new_x: {} new_y: {} Inference: {:.2f} ms".format(follow_x,follow_y, coordinates[0], coordinates[1],(end_time - start_time) * 1000)) + mqtt_bridge.publish(mqtt_topic, follow_json, 0, False) + + if cv2.waitKey(1) & 0xFF == ord('q'): + break + + cap.release() + cv2.destroyAllWindows() + +def object_to_follow( objs, labels, trdata, trackerFlag): + best_score=0 + follow_x, follow_y = None,None + if trackerFlag and (np.array(trdata)).size: + for td in trdata: + x0, y0, x1, y1, trackID = int(td[0].item()), int(td[1].item()), int(td[2].item()), int(td[3].item()), td[4].item() + overlap = 0 + for ob in objs: + dx0, dy0, dx1, dy1 = int(ob[0].item()), int(ob[1].item()), int(ob[2].item()), int(ob[3].item()) + area = (min(dx1, x1)-max(dx0, x0))*(min(dy1, y1)-max(dy0, y0)) + if (area > overlap): + overlap = area + obj = ob + + obj_score = obj[4].item() + + if obj_score > best_score: + best_score = obj_score + + obj_id = int(obj[5].item()) + #print("Tracking - x0: {} y0: {} x1: {} y1: {}".format(x0,y0,x1,y1)) + follow_x = x0 + ((x1 - x0)/2) + follow_y = y0 + ((y1 - y0)/2) + else: + for obj in objs: + x0, y0, x1, y1 = int(obj[0].item()), int(obj[1].item()), int(obj[2].item()), int(obj[3].item()) + obj_score = obj[4].item() + + if obj_score > best_score: + best_score = obj_score + + obj_id = int(obj[5].item()) + #print("Detect - x0: {} y0: {} x1: {} y1: {}".format(x0,y0,x1,y1)) + + follow_x = x0 + ((x1 - x0)/2) + follow_y = y0 + ((y1 - y0)/2) + return follow_x, follow_y + + +def append_objs_to_img(cv2_im, objs, labels, trdata, trackerFlag): + + if trackerFlag and (np.array(trdata)).size: + for td in trdata: + x0, y0, x1, y1, trackID = int(td[0].item()), int(td[1].item()), int(td[2].item()), int(td[3].item()), td[4].item() + overlap = 0 + for ob in objs: + dx0, dy0, dx1, dy1 = int(ob[0].item()), int(ob[1].item()), int(ob[2].item()), int(ob[3].item()) + area = (min(dx1, x1)-max(dx0, x0))*(min(dy1, y1)-max(dy0, y0)) + if (area > overlap): + overlap = area + obj = ob + + obj_score = obj[4].item() + obj_id = int(obj[5].item()) + percent = int(100 * obj_score) + label = '{}% {} ID:{}'.format( + percent, labels.get(obj_id, obj_id), int(trackID)) + cv2_im = cv2.rectangle(cv2_im, (x0, y0), (x1, y1), (0, 255, 0), 2) + cv2_im = 
cv2.putText(cv2_im, label, (x0, y0+30), + cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0), 2) + + else: + for obj in objs: + x0, y0, x1, y1 = int(obj[0].item()), int(obj[1].item()), int(obj[2].item()), int(obj[3].item()) + obj_score = obj[4].item() + obj_id = int(obj[5].item()) + + percent = int(100 * obj_score) + label = '{}% {}'.format(percent, labels.get(obj_id, obj_id)) + + cv2_im = cv2.rectangle(cv2_im, (x0, y0), (x1, y1), (0, 255, 0), 2) + cv2_im = cv2.putText(cv2_im, label, (x0, y0+30), + cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0), 2) + return cv2_im + +if __name__ == '__main__': + main() diff --git a/object-tracker/opencv/install_requirements.sh b/object-tracker/opencv/install_requirements.sh new file mode 100755 index 0000000..ddcf4e1 --- /dev/null +++ b/object-tracker/opencv/install_requirements.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if grep -s -q "Mendel" /etc/os-release; then + MENDEL_VER="$(cat /etc/mendel_version)" + if [[ "$MENDEL_VER" == "1.0" || "$MENDEL_VER" == "2.0" || "$MENDEL_VER" == "3.0" ]]; then + echo "Your version of Mendel is not compatible with OpenCV." + echo "You must upgrade to Mendel 4.0 or higher." + exit 1 + fi + sudo apt install python3-opencv + sudo pip3 install paho-mqtt +elif grep -s -q "Raspberry Pi" /sys/firmware/devicetree/base/model; then + RASPBIAN=$(grep VERSION_ID /etc/os-release | sed 's/VERSION_ID="\([0-9]\+\)"/\1/') + echo "Raspbian Version: $RASPBIAN" + if [[ "$RASPBIAN" -ge "10" ]]; then + # Lock to version due to bug: https://github.com/piwheels/packages/issues/59 + sudo pip3 install opencv-contrib-python==4.1.0.25 paho-mqtt + sudo apt-get -y install libjasper1 libhdf5-1* libqtgui4 libatlas-base-dev libqt4-test + else + echo "For Raspbian versions older than Buster (10) you have to build OpenCV yourself" + echo "or install the unofficial opencv-contrib-python package." + exit 1 + fi +else + sudo apt install python3-opencv +fi + +# Verify models are downloaded +if [ ! -d "../models" ] +then + cd .. + echo "Downloading models." + bash download_models.sh + cd - +fi + +# Install Tracker Dependencies +echo +echo "Installing tracker dependencies." +echo +echo "Note that the trackers have their own licensing, many of which +are not Apache. Care should be taken if using a tracker with restrictive +licenses for end applications." + +read -p "Install SORT (GPLv3)? 
" -n 1 -r +if [[ $REPLY =~ ^[Yy]$ ]] +then + wget https://github.com/abewley/sort/archive/master.zip -O sort.zip + unzip sort.zip -d ../third_party + rm sort.zip + sudo apt install python3-skimage + python3 -m pip install -r requirements_for_sort_tracker.txt +fi +echo diff --git a/object-tracker/opencv/mqtt_wrapper/__init__.py b/object-tracker/opencv/mqtt_wrapper/__init__.py new file mode 100644 index 0000000..4822aba --- /dev/null +++ b/object-tracker/opencv/mqtt_wrapper/__init__.py @@ -0,0 +1 @@ +from .bridge import bridge \ No newline at end of file diff --git a/object-tracker/opencv/mqtt_wrapper/bridge.py b/object-tracker/opencv/mqtt_wrapper/bridge.py new file mode 100644 index 0000000..069db8b --- /dev/null +++ b/object-tracker/opencv/mqtt_wrapper/bridge.py @@ -0,0 +1,102 @@ +#!/usr/bin/python +import paho.mqtt.client as mqtt +import time +import traceback + +class bridge: + + def __init__(self, mqtt_topic = None, client_id = "bridge", user_id = None, password = None, host = "127.0.0.1", port = 1883, keepalive = 60): + self.mqtt_topic = mqtt_topic + self.client_id = client_id + self.user_id = user_id + self.password = password + self.host = host + self.port = port + self.keepalive = keepalive + + self.disconnect_flag = False + self.rc = 1 + self.timeout = 0 + + self.client = mqtt.Client(self.client_id, clean_session = True) + if self.user_id and self.password: + self.client.username_pw_set(self.user_id, self.password) + + self.client.on_connect = self.on_connect + self.client.on_disconnect = self.on_disconnect + self.client.on_message = self.on_message + self.client.on_unsubscribe = self.on_unsubscribe + self.client.on_subscribe = self.on_subscribe + self.client.on_publish = self.on_publish + + self.connect() + + def connect(self): + while self.rc != 0: + try: + self.rc = self.client.connect(self.host, self.port, self.keepalive) + except Exception as e: + print("connection failed") + time.sleep(2) + self.timeout = self.timeout + 2 + + def msg_process(self, msg): + pass + + def looping(self, loop_timeout = .1): + self.client.loop(loop_timeout) + + def on_connect(self, client, userdata, flags, rc): + print("Connected with result code "+str(rc)) + if self.mqtt_topic: + self.client.subscribe(self.mqtt_topic) + self.timeout = 0 + + def on_disconnect(self, client, userdata, rc): + if rc != 0: + if not self.disconnect_flag: + print("Unexpected disconnection.") + print("Trying reconnection") + self.rc = rc + self.connect() + + def on_message(self, client, userdata, msg): + try: + self.msg_process(msg) + except Exception as e: + print(traceback.format_exc()) + + def unsubscribe(self): + print(" unsubscribing") + self.client.unsubscribe(self.mqtt_topic) + + def disconnect(self): + print(" disconnecting") + self.disconnect_flag = True + self.client.disconnect() + + def on_unsubscribe(self, client, userdata, mid): + if (self.mqtt_topic == '#'): + print("Unsubscribed to all the topics" ) + else: + print("Unsubscribed to '%s'" % self.mqtt_topic) + + def on_subscribe(self, client, userdata, mid, granted_qos): + if (self.mqtt_topic == '#'): + print("Subscribed to all the topics" ) + else: + print("Subscribed to '%s'" % self.mqtt_topic) + + def on_publish(self, client, userdata, mid): + pass + + def hook(self): + self.unsubscribe() + self.disconnect() + print(" shutting down") + + def get_timeout(self): + return self.timeout + + def publish(self, topic, payload = None, qos = 0, retain = False): + self.client.publish(topic, payload, qos, retain) diff --git 
a/object-tracker/opencv/requirements_for_sort_tracker.txt b/object-tracker/opencv/requirements_for_sort_tracker.txt new file mode 100644 index 0000000..5b3bd5a --- /dev/null +++ b/object-tracker/opencv/requirements_for_sort_tracker.txt @@ -0,0 +1,2 @@ +filterpy==1.1.0 +lap==0.4.0 diff --git a/object-tracker/opencv/tracker.py b/object-tracker/opencv/tracker.py new file mode 100755 index 0000000..a68e448 --- /dev/null +++ b/object-tracker/opencv/tracker.py @@ -0,0 +1,42 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides the support for Tracker Object. +This creates object for the specific tracker based on the name of the tracker provided +in the command line of the demo. + +To add more trackers here, simply replicate the SortTracker() code and replace it with +the new tracker as required. + +Developer simply needs to instantiate the object of ObjectTracker(trackerObjectName) with a valid +trackerObjectName. + +""" +import os,sys + +class ObjectTracker(object): + def __init__(self, trackerObjectName): + if trackerObjectName == 'sort': # Add more trackers in elif whenever needed + self.trackerObject = SortTracker() + else: + print("Invalid Tracker Name") + self.trackerObject = None + + +class SortTracker(ObjectTracker): + def __init__(self): + sys.path.append(os.path.join(os.path.dirname(__file__), '../third_party', 'sort-master')) + from sort import Sort + self.mot_tracker = Sort() diff --git a/object-tracker/third_party/README.md b/object-tracker/third_party/README.md new file mode 100644 index 0000000..1f610fd --- /dev/null +++ b/object-tracker/third_party/README.md @@ -0,0 +1,7 @@ +# Third Party Trackers +This directory contains third party trackers. + +While the overall project has a permissive Apache 2.0 license, certain trackers +(especially those intended for research) have more restrictive licenses. These +trackers will be downloaded in install_requirements scripts and will notify of +the license used. Care should be taken if restrictive licenses are an issue. 
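For reference, here is a minimal sketch of how detect.py uses the tracker wrapper added above: ObjectTracker('sort') pulls in the third-party SORT implementation, and each frame's detections are fed to mot_tracker.update(). This assumes it is run from object-tracker/opencv after install_requirements.sh has unpacked abewley/sort into ../third_party/sort-master and installed filterpy/lap; the box values and class id below are made up for illustration.

```
import numpy as np
from tracker import ObjectTracker

# Wrap the third-party SORT tracker, as detect.py does when --tracker sort is passed
mot_tracker = ObjectTracker('sort').trackerObject.mot_tracker

# One frame's detections, in the same layout detect.py builds:
# [xmin, ymin, xmax, ymax, score, class_id]  (illustrative values)
detections = np.array([
    [200.0, 120.0, 260.0, 160.0, 0.82, 4.0],
    [600.0, 400.0, 640.0, 430.0, 0.35, 4.0],
])

# SORT returns one row per track: [xmin, ymin, xmax, ymax, track_id].
# Depending on SORT's min_hits setting this can be empty for the first few frames.
tracks = mot_tracker.update(detections)
for x0, y0, x1, y1, track_id in tracks:
    print("track {}: ({:.0f},{:.0f})-({:.0f},{:.0f})".format(int(track_id), x0, y0, x1, y1))
```

detect.py then matches each returned track back to its best-overlapping detection, picks the highest-scoring one, runs its center through motionControl(), and publishes the result as the {"x", "y"} payload on the object topic.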
diff --git a/tracker/flighttracker.py b/tracker/flighttracker.py index 1970d40..493b12b 100755 --- a/tracker/flighttracker.py +++ b/tracker/flighttracker.py @@ -296,7 +296,7 @@ class FlightTracker(object): __mqtt_broker: str = "" __mqtt_port: int = 0 __plane_topic: str = None - __tracking_topic: str = None + __flight_topic: str = None __client = None __observations: Dict[str, str] = {} __tracking_icao24: str = None @@ -304,7 +304,7 @@ class FlightTracker(object): __next_clean: datetime = None __has_nagged: bool = False - def __init__(self, mqtt_broker: str, plane_topic: str, tracking_topic: str, mqtt_port: int = 1883, ): + def __init__(self, mqtt_broker: str, plane_topic: str, flight_topic: str, mqtt_port: int = 1883, ): """Initialize the flight tracker Arguments: @@ -313,7 +313,7 @@ def __init__(self, mqtt_broker: str, plane_topic: str, tracking_topic: str, mqt latitude {float} -- Latitude of receiver longitude {float} -- Longitude of receiver plane_topic {str} -- MQTT topic for plane reports - tracking_topic {str} -- MQTT topic for current tracking report + flight_topic {str} -- MQTT topic for current tracking report Keyword Arguments: dump1090_port {int} -- Override the dump1090 raw port (default: {30003}) @@ -326,7 +326,7 @@ def __init__(self, mqtt_broker: str, plane_topic: str, tracking_topic: str, mqt self.__observations = {} self.__next_clean = datetime.utcnow() + timedelta(seconds=OBSERVATION_CLEAN_INTERVAL) self.__plane_topic = plane_topic - self.__tracking_topic = tracking_topic + self.__flight_topic = flight_topic def __publish_thread(self): @@ -356,7 +356,7 @@ def __publish_thread(self): elevation = utils.elevation(distance, cur.getAltitude(), camera_altitude) # we need to convert to feet because the altitude is in feet retain = False - self.__client.publish(self.__tracking_topic, cur.json(bearing, distance, elevation), 0, retain) + self.__client.publish(self.__flight_topic, cur.json(bearing, distance, elevation), 0, retain) logging.info("%s at %5d brg %3d alt %5d trk %3d spd %3d %s" % (cur.getIcao24(), distance, bearing, cur.getAltitude(), cur.getTrack(), cur.getGroundSpeed(), cur.getType())) if distance < 3000: @@ -377,7 +377,7 @@ def updateTrackingDistance(self): def run(self): """Run the flight tracker. 
""" - print("connecting to MQTT broker at "+ self.__mqtt_broker +", subcribing on channel '"+ self.__plane_topic+"'publising on: " + self.__tracking_topic) + print("connecting to MQTT broker at "+ self.__mqtt_broker +", subcribing on channel '"+ self.__plane_topic+"'publising on: " + self.__flight_topic) self.__client = mqtt.Client("skyscan-tracker-" + ID) #create new instance self.__client.on_message = on_message #attach function to callback @@ -481,7 +481,7 @@ def main(): parser.add_argument('-m', '--mqtt-host', help="MQTT broker hostname", default='127.0.0.1') parser.add_argument('-p', '--mqtt-port', type=int, help="MQTT broker port number (default 1883)", default=1883) parser.add_argument('-P', '--plane-topic', dest='plane_topic', help="MQTT plane topic", default="skyscan/planes/json") - parser.add_argument('-T', '--tracking-topic', dest='tracking_topic', help="MQTT tracking topic", default="skyscan/tracking/json") + parser.add_argument('-T', '--flight-topic', dest='flight_topic', help="MQTT flight tracking topic", default="skyscan/flight/json") parser.add_argument('-v', '--verbose', action="store_true", help="Verbose output") args = parser.parse_args() @@ -515,7 +515,7 @@ def main(): logging.info("---[ Starting %s ]---------------------------------------------" % sys.argv[0]) - tracker = FlightTracker( args.mqtt_host, args.plane_topic, args.tracking_topic, mqtt_port = args.mqtt_port) + tracker = FlightTracker( args.mqtt_host, args.plane_topic, args.flight_topic, mqtt_port = args.mqtt_port) tracker.run() # Never returns