#!/usr/bin/env python3
|
|
import os
|
|
import random
|
|
|
|
from ultralytics import YOLO
|
|
import cv2
|
|
|
|
# FEN letter for every detector class label: the "w_" (White) labels map to
# upper-case letters and the "b_" (Black) labels to lower-case ones.
class_to_fen = {
    f"{side}_{piece}": letter.upper() if side == "w" else letter
    for side in ("w", "b")
    for piece, letter in (
        ("pawn", "p"),
        ("knight", "n"),
        ("bishop", "b"),
        ("rook", "r"),
        ("queen", "q"),
        ("king", "k"),
    )
}
|
|
|
|
def _grid_index(center, cell_size):
    """Return the 0-7 grid index of a coordinate, clamped to the board."""
    # A box centre lying on (or beyond) the right/bottom edge would otherwise
    # produce index 8 and overflow the 8x8 board.
    return min(7, max(0, int(center / cell_size)))


def _encode_fen_rank(squares):
    """Encode one board row as a FEN rank, collapsing empty runs to digits."""
    parts = []
    empty_run = 0
    for square in squares:
        if square == '':
            empty_run += 1
            continue
        if empty_run:
            parts.append(str(empty_run))
            empty_run = 0
        parts.append(square)
    if empty_run:
        parts.append(str(empty_run))
    return ''.join(parts)


def prediction_to_fen(results, width, height, *, class_map=None):
    """Convert piece detections into a FEN string.

    The centre of every detected box is mapped onto an 8x8 grid laid over a
    ``width`` x ``height`` image. The image is assumed to show exactly the
    board with White at the bottom, so the top image row becomes rank 8,
    which is the first rank listed in FEN.

    Args:
        results: Iterable of detection results. Each item must provide
            ``boxes.xyxy`` and ``boxes.cls`` (parallel sequences whose box
            entries support ``tolist()``) and ``names``, a mapping from
            class index to class label.
        width: Image width in pixels; must be positive.
        height: Image height in pixels; must be positive.
        class_map: Optional mapping from class label to FEN character.
            Defaults to the module-level ``class_to_fen``.

    Returns:
        The piece placement followed by the fixed suffix ``" w KQkq - 0 1"``
        (white to move, all castling rights, no en passant).

    Raises:
        ValueError: If ``width`` or ``height`` is not positive.
    """
    if width <= 0 or height <= 0:
        raise ValueError(
            f"width and height must be positive, got {width}x{height}"
        )

    piece_for = class_to_fen if class_map is None else class_map
    square_w = width / 8
    square_h = height / 8

    # board[0] is the top image row (rank 8); board[7] is the bottom (rank 1).
    board = [['' for _ in range(8)] for _ in range(8)]

    for result in results:
        # Use the label table carried by the result itself instead of relying
        # on a global ``model`` object being defined by the calling script.
        names = result.names
        for box, cls in zip(result.boxes.xyxy, result.boxes.cls):
            class_name = names[int(cls)]
            fen_char = piece_for.get(class_name)
            if not fen_char:
                # Labels without a FEN character are not pieces; skip them.
                continue

            x1, y1, x2, y2 = box.tolist()
            col = _grid_index((x1 + x2) / 2, square_w)
            row = _grid_index((y1 + y2) / 2, square_h)
            # If several detections land on one square, the last one wins.
            board[row][col] = fen_char
            print(f"[{class_name}] {fen_char} {row} {col}")

    placement = '/'.join(_encode_fen_rank(rank) for rank in board)
    # Side to move, castling rights and move counters cannot be read from a
    # single image, so fixed defaults are appended.
    return placement + ' w KQkq - 0 1'
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: run the detector on randomly chosen test images and save an
    # annotated copy of each one into the results folder.
    model_path = "../assets/models/unified-nano-refined.pt"
    img_folder = "../training/datasets/pieces/unified/test/images/"
    save_folder = "./results"
    num_samples = 10

    os.makedirs(save_folder, exist_ok=True)

    test_images = os.listdir(img_folder)
    if not test_images:
        # random.choice would fail on an empty list; exit with a clear message.
        raise SystemExit(f"No files found in {img_folder}")

    # Load the weights once; the model does not change between images.
    model = YOLO(model_path)

    for _ in range(num_samples):
        # Sampling is with replacement, so the same image may be picked twice.
        image_name = random.choice(test_images)
        img_path = os.path.join(img_folder, image_name)
        save_path = os.path.join(save_folder, image_name)

        # cv2.imread returns None for files it cannot decode (for example a
        # stray non-image file in the folder); skip those instead of crashing.
        img = cv2.imread(img_path)
        if img is None:
            print(f"Skipping unreadable image: {img_path}")
            continue

        results = model.predict(source=img_path, conf=0.5)

        # FEN extraction is currently disabled. Note the argument order is
        # (results, width, height).
        # height, width = img.shape[:2]
        # fen = prediction_to_fen(results, width, height)
        # print("Predicted FEN:", fen)

        annotated_image = results[0].plot()
        cv2.imwrite(save_path, annotated_image)

        # Interactive preview is disabled. The wait/destroy calls belong to
        # it: no window is created in this script, so they are kept here
        # with the rest of the display code.
        # cv2.namedWindow("YOLO Predictions", cv2.WINDOW_NORMAL)
        # cv2.imshow("YOLO Predictions", annotated_image)
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()