113 lines
3.4 KiB
Python
113 lines
3.4 KiB
Python
import cv2
|
|
import numpy as np
|
|
from typing import Any
|
|
|
|
from numpy import ndarray
|
|
|
|
|
|
class PiecesManager:
|
|
|
|
def __init__(self):
|
|
pass
|
|
|
|
def extract_pieces(self, pieces_pred) -> list[Any]:
|
|
result = pieces_pred[0]
|
|
detections = []
|
|
|
|
for box in result.boxes:
|
|
# xywh en pixels de l'image originale
|
|
x, y, w, h = box.xywh[0].cpu().numpy()
|
|
label = result.names[int(box.cls[0])]
|
|
detections.append({"label": label, "bbox": (int(x), int(y), int(w), int(h))})
|
|
|
|
return detections
|
|
|
|
def pieces_to_board(self, detected_boxes: list, warped_corners: ndarray, matrix: np.ndarray, board_size: tuple[int, int]) -> list[list[str | None]]:
|
|
|
|
board_array = [[None for _ in range(8)] for _ in range(8)]
|
|
board_width, board_height = board_size
|
|
|
|
tl, tr, br, bl = warped_corners
|
|
square_centers = self.__compute_square_centers(tl, tr, br, bl)
|
|
|
|
for d in detected_boxes:
|
|
x, y, w, h = d["bbox"]
|
|
|
|
points = np.array([
|
|
[x + w / 2, y + h * 0.2],
|
|
[x + w / 2, y + h / 2],
|
|
[x + w / 2, y + h * 0.8]
|
|
], dtype=np.float32).reshape(-1, 1, 2)
|
|
|
|
warped_points = cv2.perspectiveTransform(points, matrix)
|
|
|
|
wx = np.mean(warped_points[:, 0, 0])
|
|
wy = np.percentile(warped_points[:, 0, 1], 25)
|
|
|
|
best_rank = 0
|
|
best_file = 0
|
|
min_dist = float("inf")
|
|
|
|
for r, c, cx, cy in square_centers:
|
|
dist = (wx - cx) ** 2 + (wy - cy) ** 2
|
|
if dist < min_dist:
|
|
min_dist = dist
|
|
best_rank = r
|
|
best_file = c
|
|
|
|
max_reasonable_dist = (board_width / 8) ** 2
|
|
if min_dist > max_reasonable_dist:
|
|
continue
|
|
|
|
board_array[best_rank][best_file] = d["label"]
|
|
|
|
return board_array
|
|
|
|
def board_to_fen(self, board : list[list[str | None]]) -> str:
|
|
map_fen = {
|
|
"w_pawn": "P", "w_knight": "N", "w_bishop": "B",
|
|
"w_rook": "R", "w_queen": "Q", "w_king": "K",
|
|
"b_pawn": "p", "b_knight": "n", "b_bishop": "b",
|
|
"b_rook": "r", "b_queen": "q", "b_king": "k",
|
|
}
|
|
rows = []
|
|
for rank in board:
|
|
empty = 0
|
|
row = ""
|
|
for sq in rank:
|
|
if sq is None:
|
|
empty += 1
|
|
else:
|
|
if empty:
|
|
row += str(empty)
|
|
empty = 0
|
|
row += map_fen[sq]
|
|
if empty:
|
|
row += str(empty)
|
|
rows.append(row)
|
|
return "/".join(rows)
|
|
|
|
def __compute_square_centers(self, tl, tr, br, bl):
|
|
centers = []
|
|
for line in range(8):
|
|
for file in range(8):
|
|
u = (file + 0.5) / 8
|
|
v = (line + 0.5) / 8
|
|
|
|
# interpolation bilinéaire
|
|
x = (
|
|
(1 - u) * (1 - v) * tl[0] +
|
|
u * (1 - v) * tr[0] +
|
|
u * v * br[0] +
|
|
(1 - u) * v * bl[0]
|
|
)
|
|
y = (
|
|
(1 - u) * (1 - v) * tl[1] +
|
|
u * (1 - v) * tr[1] +
|
|
u * v * br[1] +
|
|
(1 - u) * v * bl[1]
|
|
)
|
|
|
|
centers.append((line, file, x, y))
|
|
return centers
|