Add conversion to FEN
This commit is contained in:
@@ -1,93 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
from detector import Detector
|
||||
from board_manager import BoardManager
|
||||
|
||||
from ultralytics import YOLO
|
||||
# -------------------- Pièces --------------------
|
||||
def extract_pieces(pieces_pred):
|
||||
"""Extrait les pièces avec leur bbox, sans remapping inutile"""
|
||||
result = pieces_pred[0]
|
||||
detections = []
|
||||
|
||||
for box in result.boxes:
|
||||
# xywh en pixels de l'image originale
|
||||
x, y, w, h = box.xywh[0].cpu().numpy()
|
||||
label = result.names[int(box.cls[0])]
|
||||
detections.append({"label": label, "bbox": (int(x), int(y), int(w), int(h))})
|
||||
|
||||
return detections
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
# Map class names to FEN characters
|
||||
class_to_fen = {
|
||||
'w_pawn': 'P',
|
||||
'w_knight': 'N',
|
||||
'w_bishop': 'B',
|
||||
'w_rook': 'R',
|
||||
'w_queen': 'Q',
|
||||
'w_king': 'K',
|
||||
'b_pawn': 'p',
|
||||
'b_knight': 'n',
|
||||
'b_bishop': 'b',
|
||||
'b_rook': 'r',
|
||||
'b_queen': 'q',
|
||||
'b_king': 'k',
|
||||
}
|
||||
def pieces_to_board(detected_boxes, matrix, board_size=800):
|
||||
board_array = [[None for _ in range(8)] for _ in range(8)]
|
||||
|
||||
def prediction_to_fen(results, width, height):
|
||||
for d in detected_boxes:
|
||||
x, y, w, h = d["bbox"]
|
||||
|
||||
# Initialize empty board
|
||||
board = [['' for _ in range(8)] for _ in range(8)]
|
||||
# Points multiples sur la pièce pour stabilité
|
||||
points = np.array([
|
||||
[x + w/2, y + h*0.2], # haut
|
||||
[x + w/2, y + h/2], # centre
|
||||
[x + w/2, y + h*0.8] # bas
|
||||
], dtype=np.float32).reshape(-1,1,2)
|
||||
|
||||
# Iterate through predictions
|
||||
for result in results:
|
||||
for box, cls in zip(result.boxes.xyxy, result.boxes.cls):
|
||||
x1, y1, x2, y2 = box.tolist()
|
||||
class_name = model.names[int(cls)]
|
||||
fen_char = class_to_fen.get(class_name)
|
||||
# Transformation perspective
|
||||
warped_points = cv2.perspectiveTransform(points, matrix)
|
||||
wy_values = warped_points[:,0,1] # coordonnées y après warp
|
||||
|
||||
if fen_char:
|
||||
# Compute board square
|
||||
col = int((x1 + x2) / 2 / (width / 8))
|
||||
row = 7 - int((y1 + y2) / 2 / (height / 8))
|
||||
board[row][col] = fen_char
|
||||
print(f"[{class_name}] {fen_char} {row} {col}")
|
||||
# Prendre le percentile haut (25%) pour éviter décalage
|
||||
wy_percentile = np.percentile(wy_values, 25)
|
||||
|
||||
# Convert board to FEN
|
||||
fen_rows = []
|
||||
for row in board:
|
||||
fen_row = ''
|
||||
empty_count = 0
|
||||
for square in row:
|
||||
if square == '':
|
||||
empty_count += 1
|
||||
# Normaliser et calculer rank/file
|
||||
nx = np.clip(np.mean(warped_points[:,0,0]) / board_size, 0, 0.999)
|
||||
ny = np.clip(wy_percentile / board_size, 0, 0.999)
|
||||
|
||||
file = min(max(int(nx * 8), 0), 7)
|
||||
rank = min(max(int(ny * 8), 0), 7)
|
||||
|
||||
board_array[rank][file] = d["label"]
|
||||
|
||||
return board_array
|
||||
|
||||
|
||||
def board_to_fen(board):
|
||||
map_fen = {
|
||||
"w_pawn": "P", "w_knight": "N", "w_bishop": "B",
|
||||
"w_rook": "R", "w_queen": "Q", "w_king": "K",
|
||||
"b_pawn": "p", "b_knight": "n", "b_bishop": "b",
|
||||
"b_rook": "r", "b_queen": "q", "b_king": "k",
|
||||
}
|
||||
rows = []
|
||||
for rank in board:
|
||||
empty = 0
|
||||
row = ""
|
||||
for sq in rank:
|
||||
if sq is None:
|
||||
empty += 1
|
||||
else:
|
||||
if empty_count > 0:
|
||||
fen_row += str(empty_count)
|
||||
empty_count = 0
|
||||
fen_row += square
|
||||
if empty_count > 0:
|
||||
fen_row += str(empty_count)
|
||||
fen_rows.append(fen_row)
|
||||
|
||||
# Join rows into a FEN string (default: white to move, all castling rights, no en passant)
|
||||
fen_string = '/'.join(fen_rows) + ' w KQkq - 0 1'
|
||||
return fen_string
|
||||
|
||||
if empty:
|
||||
row += str(empty)
|
||||
empty = 0
|
||||
row += map_fen[sq]
|
||||
if empty:
|
||||
row += str(empty)
|
||||
rows.append(row)
|
||||
return "/".join(rows)
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_path = "../assets/models/unified-nano-refined.pt"
|
||||
img_folder = "../training/datasets/pieces/unified/test/images/"
|
||||
save_folder = "./results"
|
||||
os.makedirs(save_folder, exist_ok=True)
|
||||
edges_detector = Detector("../assets/models/edges.pt")
|
||||
pieces_detector = Detector("../assets/models/unified-nano-refined.pt")
|
||||
#image_path = "./test/1.png"
|
||||
image_path = "../training/datasets/pieces/unified/test/images/659_jpg.rf.0009cadea8df487a76d6960a28b9d811.jpg"
|
||||
image = cv2.imread(image_path)
|
||||
|
||||
test_images = os.listdir(img_folder)
|
||||
edges_pred = edges_detector.make_prediction(image_path)
|
||||
pieces_pred = pieces_detector.make_prediction(image_path)
|
||||
|
||||
for i in range(0, 10):
|
||||
rnd = random.randint(0, len(test_images) - 1)
|
||||
img_path = os.path.join(img_folder, test_images[rnd])
|
||||
save_path = os.path.join(save_folder, test_images[rnd])
|
||||
remap_width = 800
|
||||
remap_height = 800
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
height, width = img.shape[:2]
|
||||
board_manager = BoardManager(image)
|
||||
corners, matrix = board_manager.extract_corners(edges_pred[0], (remap_width, remap_height))
|
||||
|
||||
model = YOLO(model_path)
|
||||
results = model.predict(source=img_path, conf=0.5)
|
||||
detections = extract_pieces(pieces_pred)
|
||||
|
||||
#fen = prediction_to_fen(results, height, width)
|
||||
#print("Predicted FEN:", fen)
|
||||
board = pieces_to_board(detections, matrix, remap_width)
|
||||
|
||||
annotated_image = results[0].plot()
|
||||
cv2.imwrite(save_path, annotated_image)
|
||||
#cv2.namedWindow("YOLO Predictions", cv2.WINDOW_NORMAL)
|
||||
#cv2.imshow("YOLO Predictions", annotated_image)
|
||||
# FEN
|
||||
fen = board_to_fen(board)
|
||||
print("FEN:", fen)
|
||||
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
frame = pieces_pred[0].plot()
|
||||
cv2.namedWindow("Pred", cv2.WINDOW_NORMAL)
|
||||
cv2.imshow("Pred", frame)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
Reference in New Issue
Block a user