Minor fixes
This commit is contained in:
99
rpi/services/detection_service.py
Normal file
99
rpi/services/detection_service.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
from models.detection.detector import Detector
|
||||
from models.detection.board_manager import BoardManager
|
||||
from models.detection.pieces_manager import PiecesManager
|
||||
|
||||
|
||||
class DetectionService:
|
||||
|
||||
edges_detector : Detector
|
||||
pieces_detector : Detector
|
||||
|
||||
board_manager : BoardManager
|
||||
pieces_manager : PiecesManager
|
||||
|
||||
scale_size : tuple[int, int]
|
||||
|
||||
def __init__(self):
|
||||
self.edges_detector = Detector("../assets/models/edges.pt")
|
||||
self.pieces_detector = Detector("../assets/models/unified-nano-refined.pt")
|
||||
|
||||
self.pieces_manager = PiecesManager()
|
||||
self.board_manager = BoardManager()
|
||||
self.scale_size = (800, 800)
|
||||
|
||||
|
||||
def run_complete_detection(self, frame : np.ndarray, display=False) -> dict[str, list[Results]] :
|
||||
pieces_prediction = self.run_pieces_detection(frame)
|
||||
edges_prediction = self.run_edges_detection(frame)
|
||||
|
||||
if display:
|
||||
edges_annotated_frame = edges_prediction[0].plot()
|
||||
pieces_annotated_frame = pieces_prediction[0].plot(img=edges_annotated_frame)
|
||||
self.__display_frame(pieces_annotated_frame)
|
||||
|
||||
return { "edges" : edges_prediction, "pieces" : pieces_prediction}
|
||||
|
||||
|
||||
def run_pieces_detection(self, frame : np.ndarray, display=False) -> list[Results]:
|
||||
prediction = self.pieces_detector.make_prediction(frame)
|
||||
if display:
|
||||
self.__display_frame(prediction[0].plot())
|
||||
return prediction
|
||||
|
||||
|
||||
def run_edges_detection(self, frame : np.ndarray, display=False) -> list[Results]:
|
||||
prediction = self.edges_detector.make_prediction(frame)
|
||||
if display:
|
||||
self.__display_frame(prediction[0].plot())
|
||||
return prediction
|
||||
|
||||
|
||||
def get_fen(self, frame : np.ndarray) -> str | None:
|
||||
result = self.run_complete_detection(frame)
|
||||
|
||||
edges_prediction = result["edges"]
|
||||
pieces_prediction = result["pieces"]
|
||||
|
||||
warped_corners, matrix = self.board_manager.process_frame(edges_prediction[0], frame, self.scale_size)
|
||||
if matrix is None:
|
||||
return None
|
||||
|
||||
detections = self.pieces_manager.extract_pieces(pieces_prediction)
|
||||
|
||||
board = self.pieces_manager.pieces_to_board(detections, warped_corners, matrix, self.scale_size)
|
||||
|
||||
return self.pieces_manager.board_to_fen(board)
|
||||
|
||||
|
||||
def __display_frame(self, frame : np.ndarray):
|
||||
cv2.namedWindow("Frame", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("Frame", self.scale_size[0], self.scale_size[1])
|
||||
cv2.imshow("Frame", frame)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__" :
|
||||
import os
|
||||
import random
|
||||
|
||||
service = DetectionService()
|
||||
|
||||
img_folder = "../training/datasets/pieces/unified/test/images/"
|
||||
|
||||
test_images = os.listdir(img_folder)
|
||||
|
||||
rnd = random.randint(0, len(test_images) - 1)
|
||||
img_path = os.path.join(img_folder, test_images[rnd])
|
||||
|
||||
image = cv2.imread(img_path)
|
||||
|
||||
fen = service.get_fen(image)
|
||||
print(fen)
|
||||
|
||||
service.run_complete_detection(image, display=True)
|
||||
Reference in New Issue
Block a user