Updates
This commit is contained in:
71
rpi/training/utils/sort_labels.py
Normal file
71
rpi/training/utils/sort_labels.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
def copy_images(src, dest):
|
||||
image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"]
|
||||
|
||||
for filename in os.listdir(src):
|
||||
if any(filename.lower().endswith(ext) for ext in image_extensions):
|
||||
src_path = os.path.join(src, filename)
|
||||
dst_path = os.path.join(dest, filename)
|
||||
shutil.copy2(src_path, dst_path)
|
||||
|
||||
|
||||
def remap_labels(src, dest):
|
||||
count = 0
|
||||
for filename in os.listdir(src):
|
||||
if filename.endswith(".txt"):
|
||||
input_path = os.path.join(src, filename)
|
||||
output_path = os.path.join(dest, filename)
|
||||
|
||||
with open(input_path, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
new_lines = []
|
||||
for line in lines:
|
||||
parts = line.strip().split()
|
||||
old_idx = int(parts[0])
|
||||
new_idx = index_map[old_idx]
|
||||
new_lines.append(" ".join([str(new_idx)] + parts[1:]))
|
||||
|
||||
with open(output_path, "w") as f:
|
||||
f.write("\n".join(new_lines))
|
||||
|
||||
if count%100 == 0:
|
||||
print(count)
|
||||
count += 1
|
||||
|
||||
print(f"All labels remapped and saved to '{dest}'")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
src_dir = "../datasets/pieces/visualizan/"
|
||||
dest_dir = "../datasets/pieces/unified/"
|
||||
|
||||
reference_classes = [
|
||||
'w_pawn', 'w_knight', 'w_bishop', 'w_rook', 'w_queen', 'w_king',
|
||||
'b_pawn', 'b_knight', 'b_bishop', 'b_rook', 'b_queen', 'b_king'
|
||||
]
|
||||
|
||||
current_classes = ['b_bishop', 'b_king', 'b_knight', 'b_pawn', 'b_queen', 'b_rook',
|
||||
'w_bishop', 'w_king', 'w_knight', 'w_pawn', 'w_queen', 'w_rook']
|
||||
|
||||
index_map = {current_classes.index(cls): reference_classes.index(cls) for cls in current_classes}
|
||||
|
||||
sub_elements = os.listdir(src_dir)
|
||||
for sub in sub_elements:
|
||||
src_full_path = os.path.normpath(os.path.join(src_dir, sub))
|
||||
dest_full_path = os.path.normpath(os.path.join(dest_dir, sub))
|
||||
|
||||
if not os.path.isdir(src_full_path): continue
|
||||
|
||||
src_image_folder = os.path.normpath(os.path.join(src_full_path, "images"))
|
||||
src_labels_folder = os.path.normpath(os.path.join(src_full_path, "labels"))
|
||||
|
||||
dst_image_folder = os.path.normpath(os.path.join(dest_full_path, "images"))
|
||||
dst_labels_folder = os.path.normpath(os.path.join(dest_full_path, "labels"))
|
||||
|
||||
copy_images(src_image_folder, dst_image_folder)
|
||||
remap_labels(src_labels_folder, dst_labels_folder)
|
||||
Reference in New Issue
Block a user