Move inference onto the API
This commit is contained in:
@@ -5,6 +5,7 @@ from flask import Flask
|
||||
|
||||
from src.controllers.AuthController import AuthController
|
||||
from src.controllers.ClientController import ClientController
|
||||
from src.controllers.analyze_controller import AnalyzeController
|
||||
from src.controllers.message_controller import MessageController
|
||||
from src.controllers.mqtt_forwarder import create_forwarder
|
||||
from src.controllers.telemetryController import TelemetryController
|
||||
@@ -35,6 +36,7 @@ auth_controller = AuthController(app, auth_data, "https://192.168.15.120:8000")
|
||||
client_controller = ClientController(app, auth_data, "https://192.168.15.120:8000")
|
||||
message_controller = MessageController(app, auth_data, "https://192.168.15.120:8000", database_service)
|
||||
telemetry_controller = TelemetryController(app, database_service)
|
||||
analyze_controller = AnalyzeController(app)
|
||||
|
||||
def handle_login(data):
|
||||
local_broker, api_broker, forwarder = create_forwarder(data,
|
||||
|
||||
23
api-customer/src/controllers/analyze_controller.py
Normal file
23
api-customer/src/controllers/analyze_controller.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from flask import jsonify
|
||||
|
||||
from src.services.detection_service import DetectionService
|
||||
|
||||
|
||||
class AnalyzeController:
|
||||
|
||||
_detection_service : DetectionService
|
||||
|
||||
def __init__(self, app):
|
||||
self._register_routes(app)
|
||||
self._detection_service = DetectionService()
|
||||
|
||||
def _register_routes(self, app):
|
||||
app.add_url_rule("/analyze/image", view_func=self.analyze, methods=['POST'])
|
||||
|
||||
def analyze(self):
|
||||
try :
|
||||
fen = self._detection_service.analyze_single_frame(None)
|
||||
return jsonify({"success": False, "payload": {"fen" : fen}}), 200
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return jsonify({"success": False, "message": "Failed to analyze image"}), 500
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
from hardware.camera.camera import Camera
|
||||
from models.detection.detector import Detector
|
||||
from models.detection.board_manager import BoardManager
|
||||
from models.detection.pieces_manager import PiecesManager
|
||||
@@ -20,7 +19,6 @@ class DetectionService:
|
||||
|
||||
scale_size : tuple[int, int]
|
||||
|
||||
camera : Camera
|
||||
|
||||
def __init__(self):
|
||||
current_file = Path(__file__).resolve()
|
||||
@@ -32,18 +30,8 @@ class DetectionService:
|
||||
self.pieces_manager = PiecesManager()
|
||||
self.board_manager = BoardManager()
|
||||
self.scale_size = (800, 800)
|
||||
self.camera = Camera()
|
||||
|
||||
def start(self):
|
||||
self.camera.open()
|
||||
|
||||
def stop(self):
|
||||
self.camera.close()
|
||||
|
||||
def analyze_single_frame(self) -> tuple[bytes, str | None]:
|
||||
frame = self.camera.take_photo()
|
||||
encoded_frame = cv2.imencode('.jpg', frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])[1].tobytes()
|
||||
|
||||
def analyze_single_frame(self, frame : np.ndarray) -> str | None:
|
||||
result = self.__run_complete_detection(frame)
|
||||
|
||||
edges_prediction = result["edges"]
|
||||
@@ -51,7 +39,7 @@ class DetectionService:
|
||||
|
||||
processed_frame = self.board_manager.process_frame(edges_prediction[0], frame, self.scale_size)
|
||||
if processed_frame is None:
|
||||
return encoded_frame, None
|
||||
return None
|
||||
|
||||
warped_corners, matrix = processed_frame
|
||||
|
||||
@@ -59,7 +47,7 @@ class DetectionService:
|
||||
|
||||
board = self.pieces_manager.pieces_to_board(detections, warped_corners, matrix, self.scale_size)
|
||||
|
||||
return encoded_frame, self.pieces_manager.board_to_fen(board)
|
||||
return self.pieces_manager.board_to_fen(board)
|
||||
|
||||
def __run_complete_detection(self, frame : np.ndarray, display=False) -> dict[str, list[Results]] :
|
||||
pieces_prediction = self.__run_pieces_detection(frame)
|
||||
@@ -12,12 +12,14 @@ from services.mqtt_service import MQTTService
|
||||
class GameController:
|
||||
|
||||
_game_service : GameService
|
||||
_api_url : str
|
||||
_broker_service : MQTTService
|
||||
_has_started : bool
|
||||
_auth_token : str
|
||||
|
||||
def __init__(self, app : Flask, broker_service : MQTTService):
|
||||
def __init__(self, app : Flask, api_url : str, broker_service : MQTTService):
|
||||
self._game_service = GameService()
|
||||
self._api_url = api_url
|
||||
self._game_service.set_on_terminated(self._stop_event)
|
||||
self._broker_service = broker_service
|
||||
self._register_routes(app)
|
||||
@@ -61,11 +63,17 @@ class GameController:
|
||||
if auth_token != "Bearer " + self._auth_token:
|
||||
return jsonify({"status": "error", "message": "Invalid authorization token"}), 401
|
||||
|
||||
threading.Thread(
|
||||
target=self._analyze_move(),
|
||||
daemon=True
|
||||
).start()
|
||||
img = self._game_service.make_move()
|
||||
b64_img = base64.b64encode(img).decode('utf-8')
|
||||
payload = {
|
||||
"image": f"data:image/jpeg;base64,{b64_img}"
|
||||
}
|
||||
response = requests.post(self._api_url, json=payload, verify=False)
|
||||
print(response.status_code)
|
||||
|
||||
data = response.json()
|
||||
fen = data.get("fen")
|
||||
self._game_service.add_move(fen)
|
||||
return jsonify({"status": "ok"}), 200
|
||||
|
||||
except ServiceException as ex:
|
||||
@@ -74,22 +82,6 @@ class GameController:
|
||||
print(ex)
|
||||
return jsonify({"status": "error", "message": f"An error occurred : {ex}"}), 500
|
||||
|
||||
def _analyze_move(self):
|
||||
img, fen = self._game_service.make_move()
|
||||
self._send_detection_result("https://192.168.15.125:1880/party/image", img, fen)
|
||||
|
||||
def _send_detection_result(self, url, img, fen):
|
||||
try:
|
||||
b64_img = base64.b64encode(img).decode('utf-8')
|
||||
payload = {
|
||||
"fen": fen,
|
||||
"image": f"data:image/jpeg;base64,{b64_img}"
|
||||
}
|
||||
response = requests.post(url, json=payload, verify=False)
|
||||
print(response.status_code)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
def _stop_event(self, game_data : str):
|
||||
try :
|
||||
print(f"Exporting game data : {game_data}")
|
||||
|
||||
@@ -23,13 +23,13 @@ class Camera:
|
||||
self.cap.release()
|
||||
self.cap = None
|
||||
|
||||
def take_photo(self) -> np.ndarray:
|
||||
def take_photo(self) -> bytes:
|
||||
self.open()
|
||||
try:
|
||||
ret, frame = self.cap.read()
|
||||
if not ret:
|
||||
raise RuntimeError("Failed to capture image")
|
||||
return frame
|
||||
return cv2.imencode('.jpg', frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])[1].tobytes()
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ api_broker = MQTTService(
|
||||
password=api_password,
|
||||
)
|
||||
|
||||
game_controller = GameController(app, api_broker)
|
||||
game_controller = GameController(app, "https://192.168.15.125:1880/party/image", api_broker)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
@@ -2,29 +2,29 @@ import json
|
||||
from typing import Callable
|
||||
|
||||
from hardware.buzzer.buzzer import Buzzer
|
||||
from hardware.camera.camera import Camera
|
||||
from hardware.led.led import Led
|
||||
from models.exceptions.ServiceException import ServiceException
|
||||
from models.game import Game
|
||||
from services.clock_service import ClockService
|
||||
from services.detection_service import DetectionService
|
||||
|
||||
|
||||
class GameService:
|
||||
|
||||
_game : Game
|
||||
_detection_service : DetectionService
|
||||
_camera : Camera
|
||||
_clock_service : ClockService
|
||||
_has_started : bool
|
||||
_led : Led
|
||||
_buzzer : Buzzer
|
||||
_on_terminated : Callable[[str], None]
|
||||
_has_started : bool
|
||||
|
||||
def __init__(self):
|
||||
self._detection_service = DetectionService()
|
||||
self._camera = Camera()
|
||||
self._clock_service = ClockService()
|
||||
self._has_started = False
|
||||
self._led = Led(7)
|
||||
self._buzzer = Buzzer(8)
|
||||
self._has_started = False
|
||||
|
||||
def start(self, white_name, back_name, time_control : int, increment : int, timestamp : int) -> None:
|
||||
if self._has_started :
|
||||
@@ -46,18 +46,20 @@ class GameService:
|
||||
self._notify()
|
||||
self._has_started = False
|
||||
|
||||
def make_move(self) -> tuple[bytes, str] | None:
|
||||
def make_move(self) -> bytes:
|
||||
try :
|
||||
if not self._has_started :
|
||||
raise Exception("Game hasn't started yet.")
|
||||
self._clock_service.switch()
|
||||
img, fen = self._detection_service.analyze_single_frame()
|
||||
self._game.add_move(fen)
|
||||
return img, fen
|
||||
img = self._camera.take_photo()
|
||||
return img
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise ServiceException(e)
|
||||
|
||||
def add_move(self, fen):
|
||||
self._game.add_move(fen)
|
||||
|
||||
def set_on_terminated(self, callback: Callable[[str], None]):
|
||||
self._on_terminated = callback
|
||||
|
||||
|
||||
Reference in New Issue
Block a user