Files
board-mate/api-customer/src/models/detection/pieces_manager.py
2026-01-05 16:54:25 +01:00

113 lines
3.4 KiB
Python

import cv2
import numpy as np
from typing import Any
from numpy import ndarray
class PiecesManager:
def __init__(self):
pass
def extract_pieces(self, pieces_pred) -> list[Any]:
result = pieces_pred[0]
detections = []
for box in result.boxes:
# xywh en pixels de l'image originale
x, y, w, h = box.xywh[0].cpu().numpy()
label = result.names[int(box.cls[0])]
detections.append({"label": label, "bbox": (int(x), int(y), int(w), int(h))})
return detections
def pieces_to_board(self, detected_boxes: list, warped_corners: ndarray, matrix: np.ndarray, board_size: tuple[int, int]) -> list[list[str | None]]:
board_array = [[None for _ in range(8)] for _ in range(8)]
board_width, board_height = board_size
tl, tr, br, bl = warped_corners
square_centers = self.__compute_square_centers(tl, tr, br, bl)
for d in detected_boxes:
x, y, w, h = d["bbox"]
points = np.array([
[x + w / 2, y + h * 0.2],
[x + w / 2, y + h / 2],
[x + w / 2, y + h * 0.8]
], dtype=np.float32).reshape(-1, 1, 2)
warped_points = cv2.perspectiveTransform(points, matrix)
wx = np.mean(warped_points[:, 0, 0])
wy = np.percentile(warped_points[:, 0, 1], 25)
best_rank = 0
best_file = 0
min_dist = float("inf")
for r, c, cx, cy in square_centers:
dist = (wx - cx) ** 2 + (wy - cy) ** 2
if dist < min_dist:
min_dist = dist
best_rank = r
best_file = c
max_reasonable_dist = (board_width / 8) ** 2
if min_dist > max_reasonable_dist:
continue
board_array[best_rank][best_file] = d["label"]
return board_array
def board_to_fen(self, board : list[list[str | None]]) -> str:
map_fen = {
"w_pawn": "P", "w_knight": "N", "w_bishop": "B",
"w_rook": "R", "w_queen": "Q", "w_king": "K",
"b_pawn": "p", "b_knight": "n", "b_bishop": "b",
"b_rook": "r", "b_queen": "q", "b_king": "k",
}
rows = []
for rank in board:
empty = 0
row = ""
for sq in rank:
if sq is None:
empty += 1
else:
if empty:
row += str(empty)
empty = 0
row += map_fen[sq]
if empty:
row += str(empty)
rows.append(row)
return "/".join(rows)
def __compute_square_centers(self, tl, tr, br, bl):
centers = []
for line in range(8):
for file in range(8):
u = (file + 0.5) / 8
v = (line + 0.5) / 8
# interpolation bilinéaire
x = (
(1 - u) * (1 - v) * tl[0] +
u * (1 - v) * tr[0] +
u * v * br[0] +
(1 - u) * v * bl[0]
)
y = (
(1 - u) * (1 - v) * tl[1] +
u * (1 - v) * tr[1] +
u * v * br[1] +
(1 - u) * v * bl[1]
)
centers.append((line, file, x, y))
return centers