Move inference onto the API

This commit is contained in:
2026-01-05 16:54:25 +01:00
parent e457fc6be8
commit 9e0d586f6a
15 changed files with 55 additions and 48 deletions

View File

@@ -5,6 +5,7 @@ from flask import Flask
from src.controllers.AuthController import AuthController from src.controllers.AuthController import AuthController
from src.controllers.ClientController import ClientController from src.controllers.ClientController import ClientController
from src.controllers.analyze_controller import AnalyzeController
from src.controllers.message_controller import MessageController from src.controllers.message_controller import MessageController
from src.controllers.mqtt_forwarder import create_forwarder from src.controllers.mqtt_forwarder import create_forwarder
from src.controllers.telemetryController import TelemetryController from src.controllers.telemetryController import TelemetryController
@@ -35,6 +36,7 @@ auth_controller = AuthController(app, auth_data, "https://192.168.15.120:8000")
client_controller = ClientController(app, auth_data, "https://192.168.15.120:8000") client_controller = ClientController(app, auth_data, "https://192.168.15.120:8000")
message_controller = MessageController(app, auth_data, "https://192.168.15.120:8000", database_service) message_controller = MessageController(app, auth_data, "https://192.168.15.120:8000", database_service)
telemetry_controller = TelemetryController(app, database_service) telemetry_controller = TelemetryController(app, database_service)
analyze_controller = AnalyzeController(app)
def handle_login(data): def handle_login(data):
local_broker, api_broker, forwarder = create_forwarder(data, local_broker, api_broker, forwarder = create_forwarder(data,

View File

@@ -0,0 +1,23 @@
from flask import jsonify
from src.services.detection_service import DetectionService
class AnalyzeController:
_detection_service : DetectionService
def __init__(self, app):
self._register_routes(app)
self._detection_service = DetectionService()
def _register_routes(self, app):
app.add_url_rule("/analyze/image", view_func=self.analyze, methods=['POST'])
def analyze(self):
try :
fen = self._detection_service.analyze_single_frame(None)
return jsonify({"success": False, "payload": {"fen" : fen}}), 200
except Exception as e:
print(e)
return jsonify({"success": False, "message": "Failed to analyze image"}), 500

View File

@@ -4,7 +4,6 @@ from pathlib import Path
from ultralytics.engine.results import Results from ultralytics.engine.results import Results
from hardware.camera.camera import Camera
from models.detection.detector import Detector from models.detection.detector import Detector
from models.detection.board_manager import BoardManager from models.detection.board_manager import BoardManager
from models.detection.pieces_manager import PiecesManager from models.detection.pieces_manager import PiecesManager
@@ -20,7 +19,6 @@ class DetectionService:
scale_size : tuple[int, int] scale_size : tuple[int, int]
camera : Camera
def __init__(self): def __init__(self):
current_file = Path(__file__).resolve() current_file = Path(__file__).resolve()
@@ -32,18 +30,8 @@ class DetectionService:
self.pieces_manager = PiecesManager() self.pieces_manager = PiecesManager()
self.board_manager = BoardManager() self.board_manager = BoardManager()
self.scale_size = (800, 800) self.scale_size = (800, 800)
self.camera = Camera()
def start(self):
self.camera.open()
def stop(self):
self.camera.close()
def analyze_single_frame(self) -> tuple[bytes, str | None]:
frame = self.camera.take_photo()
encoded_frame = cv2.imencode('.jpg', frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])[1].tobytes()
def analyze_single_frame(self, frame : np.ndarray) -> str | None:
result = self.__run_complete_detection(frame) result = self.__run_complete_detection(frame)
edges_prediction = result["edges"] edges_prediction = result["edges"]
@@ -51,7 +39,7 @@ class DetectionService:
processed_frame = self.board_manager.process_frame(edges_prediction[0], frame, self.scale_size) processed_frame = self.board_manager.process_frame(edges_prediction[0], frame, self.scale_size)
if processed_frame is None: if processed_frame is None:
return encoded_frame, None return None
warped_corners, matrix = processed_frame warped_corners, matrix = processed_frame
@@ -59,7 +47,7 @@ class DetectionService:
board = self.pieces_manager.pieces_to_board(detections, warped_corners, matrix, self.scale_size) board = self.pieces_manager.pieces_to_board(detections, warped_corners, matrix, self.scale_size)
return encoded_frame, self.pieces_manager.board_to_fen(board) return self.pieces_manager.board_to_fen(board)
def __run_complete_detection(self, frame : np.ndarray, display=False) -> dict[str, list[Results]] : def __run_complete_detection(self, frame : np.ndarray, display=False) -> dict[str, list[Results]] :
pieces_prediction = self.__run_pieces_detection(frame) pieces_prediction = self.__run_pieces_detection(frame)

View File

@@ -12,12 +12,14 @@ from services.mqtt_service import MQTTService
class GameController: class GameController:
_game_service : GameService _game_service : GameService
_api_url : str
_broker_service : MQTTService _broker_service : MQTTService
_has_started : bool _has_started : bool
_auth_token : str _auth_token : str
def __init__(self, app : Flask, broker_service : MQTTService): def __init__(self, app : Flask, api_url : str, broker_service : MQTTService):
self._game_service = GameService() self._game_service = GameService()
self._api_url = api_url
self._game_service.set_on_terminated(self._stop_event) self._game_service.set_on_terminated(self._stop_event)
self._broker_service = broker_service self._broker_service = broker_service
self._register_routes(app) self._register_routes(app)
@@ -61,11 +63,17 @@ class GameController:
if auth_token != "Bearer " + self._auth_token: if auth_token != "Bearer " + self._auth_token:
return jsonify({"status": "error", "message": "Invalid authorization token"}), 401 return jsonify({"status": "error", "message": "Invalid authorization token"}), 401
threading.Thread( img = self._game_service.make_move()
target=self._analyze_move(), b64_img = base64.b64encode(img).decode('utf-8')
daemon=True payload = {
).start() "image": f"data:image/jpeg;base64,{b64_img}"
}
response = requests.post(self._api_url, json=payload, verify=False)
print(response.status_code)
data = response.json()
fen = data.get("fen")
self._game_service.add_move(fen)
return jsonify({"status": "ok"}), 200 return jsonify({"status": "ok"}), 200
except ServiceException as ex: except ServiceException as ex:
@@ -74,22 +82,6 @@ class GameController:
print(ex) print(ex)
return jsonify({"status": "error", "message": f"An error occurred : {ex}"}), 500 return jsonify({"status": "error", "message": f"An error occurred : {ex}"}), 500
def _analyze_move(self):
img, fen = self._game_service.make_move()
self._send_detection_result("https://192.168.15.125:1880/party/image", img, fen)
def _send_detection_result(self, url, img, fen):
try:
b64_img = base64.b64encode(img).decode('utf-8')
payload = {
"fen": fen,
"image": f"data:image/jpeg;base64,{b64_img}"
}
response = requests.post(url, json=payload, verify=False)
print(response.status_code)
except Exception as e:
print(e)
def _stop_event(self, game_data : str): def _stop_event(self, game_data : str):
try : try :
print(f"Exporting game data : {game_data}") print(f"Exporting game data : {game_data}")

View File

@@ -23,13 +23,13 @@ class Camera:
self.cap.release() self.cap.release()
self.cap = None self.cap = None
def take_photo(self) -> np.ndarray: def take_photo(self) -> bytes:
self.open() self.open()
try: try:
ret, frame = self.cap.read() ret, frame = self.cap.read()
if not ret: if not ret:
raise RuntimeError("Failed to capture image") raise RuntimeError("Failed to capture image")
return frame return cv2.imencode('.jpg', frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])[1].tobytes()
finally: finally:
self.close() self.close()

View File

@@ -44,7 +44,7 @@ api_broker = MQTTService(
password=api_password, password=api_password,
) )
game_controller = GameController(app, api_broker) game_controller = GameController(app, "https://192.168.15.125:1880/party/image", api_broker)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -2,29 +2,29 @@ import json
from typing import Callable from typing import Callable
from hardware.buzzer.buzzer import Buzzer from hardware.buzzer.buzzer import Buzzer
from hardware.camera.camera import Camera
from hardware.led.led import Led from hardware.led.led import Led
from models.exceptions.ServiceException import ServiceException from models.exceptions.ServiceException import ServiceException
from models.game import Game from models.game import Game
from services.clock_service import ClockService from services.clock_service import ClockService
from services.detection_service import DetectionService
class GameService: class GameService:
_game : Game _game : Game
_detection_service : DetectionService _camera : Camera
_clock_service : ClockService _clock_service : ClockService
_has_started : bool
_led : Led _led : Led
_buzzer : Buzzer _buzzer : Buzzer
_on_terminated : Callable[[str], None] _on_terminated : Callable[[str], None]
_has_started : bool
def __init__(self): def __init__(self):
self._detection_service = DetectionService() self._camera = Camera()
self._clock_service = ClockService() self._clock_service = ClockService()
self._has_started = False
self._led = Led(7) self._led = Led(7)
self._buzzer = Buzzer(8) self._buzzer = Buzzer(8)
self._has_started = False
def start(self, white_name, back_name, time_control : int, increment : int, timestamp : int) -> None: def start(self, white_name, back_name, time_control : int, increment : int, timestamp : int) -> None:
if self._has_started : if self._has_started :
@@ -46,18 +46,20 @@ class GameService:
self._notify() self._notify()
self._has_started = False self._has_started = False
def make_move(self) -> tuple[bytes, str] | None: def make_move(self) -> bytes:
try : try :
if not self._has_started : if not self._has_started :
raise Exception("Game hasn't started yet.") raise Exception("Game hasn't started yet.")
self._clock_service.switch() self._clock_service.switch()
img, fen = self._detection_service.analyze_single_frame() img = self._camera.take_photo()
self._game.add_move(fen) return img
return img, fen
except Exception as e: except Exception as e:
print(e) print(e)
raise ServiceException(e) raise ServiceException(e)
def add_move(self, fen):
self._game.add_move(fen)
def set_on_terminated(self, callback: Callable[[str], None]): def set_on_terminated(self, callback: Callable[[str], None]):
self._on_terminated = callback self._on_terminated = callback