from typing import Any

import cv2
import numpy as np
from numpy import ndarray


class PiecesManager:
    """Turn chess-piece detections into an 8x8 board and a FEN placement string."""

    # Side length of the board grid, in squares.
    _BOARD_DIM = 8

    # Detector label -> FEN piece letter (uppercase = white, lowercase = black).
    _FEN_SYMBOLS = {
        "w_pawn": "P", "w_knight": "N", "w_bishop": "B",
        "w_rook": "R", "w_queen": "Q", "w_king": "K",
        "b_pawn": "p", "b_knight": "n", "b_bishop": "b",
        "b_rook": "r", "b_queen": "q", "b_king": "k",
    }

    def __init__(self):
        pass

    def extract_pieces(self, pieces_pred) -> list[dict[str, Any]]:
        """Convert raw detector output into a list of labelled boxes.

        Args:
            pieces_pred: Sequence of prediction results. Only the first one is
                read; it must expose ``boxes`` (each with ``xywh`` and ``cls``)
                and ``names`` (class index -> label).
                NOTE(review): this matches the Ultralytics ``Results`` API, where
                ``Boxes.xywh`` is (center_x, center_y, width, height) — confirm.

        Returns:
            One dict per detection, ``{"label": str, "bbox": (left, top, w, h)}``,
            in integer pixels of the original image. ``bbox`` is anchored at the
            top-left corner because that is the layout ``pieces_to_board``
            samples from. Returns an empty list when ``pieces_pred`` is empty.
        """
        if len(pieces_pred) == 0:
            return []

        result = pieces_pred[0]
        detections = []
        for box in result.boxes:
            # xywh is centre-based; shift to the top-left corner so the sample
            # points in pieces_to_board fall inside the box instead of half a
            # box to the right of and below it.
            cx, cy, w, h = box.xywh[0].cpu().numpy()
            left = cx - w / 2
            top = cy - h / 2
            label = result.names[int(box.cls[0])]
            detections.append({
                "label": label,
                "bbox": (int(round(left)), int(round(top)), int(round(w)), int(round(h))),
            })
        return detections

    def pieces_to_board(self, detected_boxes: list, warped_corners: ndarray, matrix: np.ndarray,
                        board_size: tuple[int, int]) -> list[list[str | None]]:
        """Place each detected piece on the nearest square of an 8x8 grid.

        Args:
            detected_boxes: Dicts with ``"label"`` and ``"bbox"`` =
                (left, top, w, h) in source-image pixels, as produced by
                ``extract_pieces``.
            warped_corners: The four board corners in the order
                (top-left, top-right, bottom-right, bottom-left), expressed in
                the coordinate frame that ``matrix`` maps into.
            matrix: 3x3 perspective matrix passed to ``cv2.perspectiveTransform``
                to map source-image points into that warped frame.
            board_size: (width, height) of the warped board. Only the width is
                used, to derive the rejection distance. NOTE(review): assumed to
                be in the same units as ``warped_corners`` — confirm.

        Returns:
            An 8x8 list indexed ``[row][column]``. Row 0 runs along the
            top-left -> top-right edge and column 0 along the top-left ->
            bottom-left edge. Squares without a piece are ``None``. A detection
            whose nearest square centre is farther than one square width
            (``board_width / 8``) is discarded. If several detections land on the
            same square, the one closest to the square centre is kept.
        """
        dim = self._BOARD_DIM
        board_array: list[list[str | None]] = [[None] * dim for _ in range(dim)]
        # Squared distance of the detection currently occupying each square.
        occupant_dist = [[float("inf")] * dim for _ in range(dim)]

        board_width, _ = board_size
        tl, tr, br, bl = warped_corners
        square_centers = self.__compute_square_centers(tl, tr, br, bl)
        # Squared distance, compared against squared distances below.
        max_reasonable_dist = (board_width / dim) ** 2

        for d in detected_boxes:
            x, y, w, h = d["bbox"]
            # Three samples down the vertical centre line of the box
            # (20 %, 50 % and 80 % of its height).
            points = np.array([
                [x + w / 2, y + h * 0.2],
                [x + w / 2, y + h / 2],
                [x + w / 2, y + h * 0.8],
            ], dtype=np.float32).reshape(-1, 1, 2)
            warped_points = cv2.perspectiveTransform(points, matrix)

            wx = float(np.mean(warped_points[:, 0, 0]))
            # NOTE(review): the 25th percentile favours the samples with the
            # smallest warped y — presumably tuned for camera perspective; confirm.
            wy = float(np.percentile(warped_points[:, 0, 1], 25))

            # Nearest square centre. Ties resolve to the lowest (row, column),
            # i.e. the first centre in row-major order.
            dist, rank, file = min(
                ((wx - cx) ** 2 + (wy - cy) ** 2, r, c)
                for r, c, cx, cy in square_centers
            )

            # Written as "not <=" so that a NaN distance is also rejected.
            if not dist <= max_reasonable_dist:
                continue
            if dist < occupant_dist[rank][file]:
                occupant_dist[rank][file] = dist
                board_array[rank][file] = d["label"]

        return board_array

    def board_to_fen(self, board: list[list[str | None]]) -> str:
        """Encode a board matrix as the piece-placement field of a FEN string.

        Only the placement field is produced: ranks separated by ``/``, runs of
        empty squares collapsed to digits. No side-to-move, castling, en-passant
        or move counters are appended. Rows are emitted in the order given.
        NOTE(review): FEN expects rank 8 first, so row 0 must be rank 8 — this
        depends on how the board corners were oriented; confirm.

        Raises:
            KeyError: if a square holds a label missing from ``_FEN_SYMBOLS``.
        """
        rows = []
        for rank in board:
            row = ""
            empty = 0
            for sq in rank:
                if sq is None:
                    empty += 1
                    continue
                if empty:
                    row += str(empty)
                    empty = 0
                row += self._FEN_SYMBOLS[sq]
            if empty:
                row += str(empty)
            rows.append(row)
        return "/".join(rows)

    def __compute_square_centers(self, tl, tr, br, bl):
        """Return ``(row, column, x, y)`` for the centre of each of the 64 squares.

        Each centre is obtained by bilinear interpolation between the four
        corners (top-left, top-right, bottom-right, bottom-left) at the
        normalised square-centre coordinates ``u`` (along the columns) and
        ``v`` (along the rows). Output is in row-major order.
        """
        dim = self._BOARD_DIM
        centers = []
        for row in range(dim):
            v = (row + 0.5) / dim
            for col in range(dim):
                u = (col + 0.5) / dim
                # Bilinear weights of the four corners.
                w_tl = (1 - u) * (1 - v)
                w_tr = u * (1 - v)
                w_br = u * v
                w_bl = (1 - u) * v
                x = w_tl * tl[0] + w_tr * tr[0] + w_br * br[0] + w_bl * bl[0]
                y = w_tl * tl[1] + w_tr * tr[1] + w_br * br[1] + w_bl * bl[1]
                centers.append((row, col, x, y))
        return centers