#!/usr/bin/env python3
"""Run a YOLO chess-piece detector on sample test images and derive FEN strings."""
import os
import random

# Map detector class names to FEN piece characters (uppercase = white, lowercase = black).
class_to_fen = {
    'w_pawn': 'P',
    'w_knight': 'N',
    'w_bishop': 'B',
    'w_rook': 'R',
    'w_queen': 'Q',
    'w_king': 'K',
    'b_pawn': 'p',
    'b_knight': 'n',
    'b_bishop': 'b',
    'b_rook': 'r',
    'b_queen': 'q',
    'b_king': 'k',
}


def _square_index(center, extent):
    """Return the 0-7 board index of pixel coordinate *center* on an axis of length *extent*."""
    # Clamp so a box centred exactly on the far edge (center == extent) stays on the board
    # instead of producing index 8 and an IndexError.
    return min(max(int(center / (extent / 8)), 0), 7)


def _board_to_placement(board):
    """Encode an 8x8 grid of FEN chars ('' = empty) as the FEN piece-placement field.

    ``board[0]`` is emitted first, i.e. it must hold rank 8.
    """
    fen_rows = []
    for row in board:
        parts = []
        empty_count = 0
        for square in row:
            if not square:
                empty_count += 1
                continue
            # Flush the run of empty squares preceding this piece.
            if empty_count:
                parts.append(str(empty_count))
                empty_count = 0
            parts.append(square)
        if empty_count:
            parts.append(str(empty_count))
        fen_rows.append(''.join(parts))
    return '/'.join(fen_rows)


def prediction_to_fen(results, width, height, names=None):
    """Convert YOLO piece detections into a FEN string.

    Each detection is assigned to the square containing its bounding-box centre.
    This assumes the image is an axis-aligned crop of exactly the board, with
    rank 8 at the top of the image and file a on the left (white at the bottom).
    NOTE(review): confirm this orientation for the dataset in use.

    Args:
        results: Iterable of ultralytics ``Results``-like objects exposing
            ``boxes.xyxy``, ``boxes.cls`` and ``names``.
        width: Image width in pixels.
        height: Image height in pixels.
        names: Optional class-id -> class-name mapping. When omitted, each
            result's own ``names`` mapping is used.

    Returns:
        A FEN string. The side-to-move and other game-state fields are fixed
        defaults: white to move, all castling rights, no en passant.
    """
    board = [['' for _ in range(8)] for _ in range(8)]

    for result in results:
        class_names = names if names is not None else result.names
        for box, cls in zip(result.boxes.xyxy, result.boxes.cls):
            x1, y1, x2, y2 = box.tolist()
            class_name = class_names[int(cls)]
            fen_char = class_to_fen.get(class_name)
            if not fen_char:
                continue  # not a piece class (e.g. board / other labels)
            col = _square_index((x1 + x2) / 2, width)
            # FEN lists ranks 8 -> 1, so the top image row maps to board[0].
            row = _square_index((y1 + y2) / 2, height)
            board[row][col] = fen_char
            print(f"[{class_name}] {fen_char} {row} {col}")

    return _board_to_placement(board) + ' w KQkq - 0 1'


def main():
    """Predict on 10 randomly chosen test images and save the annotated outputs."""
    # Third-party imports are local so prediction_to_fen can be imported without them.
    import cv2
    from ultralytics import YOLO

    model_path = "../assets/models/unified-nano-refined.pt"
    img_folder = "../training/datasets/pieces/unified/test/images/"
    save_folder = "./results"
    num_samples = 10
    print_fen = False     # set True to print the FEN derived from each prediction
    show_window = False   # set True to preview each annotated image interactively

    os.makedirs(save_folder, exist_ok=True)
    test_images = os.listdir(img_folder)
    if not test_images:
        print(f"No test images found in {img_folder}")
        return

    # Load the weights once instead of on every iteration.
    model = YOLO(model_path)

    for _ in range(num_samples):
        img_name = random.choice(test_images)
        img_path = os.path.join(img_folder, img_name)
        save_path = os.path.join(save_folder, img_name)

        img = cv2.imread(img_path)
        if img is None:  # cv2.imread returns None for unreadable / non-image files
            print(f"Skipping unreadable image: {img_path}")
            continue

        results = model.predict(source=img_path, conf=0.5)

        if print_fen:
            height, width = img.shape[:2]
            fen = prediction_to_fen(results, width, height)
            print("Predicted FEN:", fen)

        annotated_image = results[0].plot()
        cv2.imwrite(save_path, annotated_image)

        if show_window:
            cv2.namedWindow("YOLO Predictions", cv2.WINDOW_NORMAL)
            cv2.imshow("YOLO Predictions", annotated_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()


if __name__ == "__main__":
    main()