+from __future__ import annotations
+
+from typing import Dict, List, Optional
+import numpy as np
+
+class Node:
+ def __init__(self, parent: Optional[Node], idx: int):
+ self.idx: int = idx
+ self.parent: Node = parent
+ self.is_leaf: bool = False
+ self.r_sum: int = 0
+ self.g_sum: int = 0
+ self.b_sum: int = 0
+ self.childs: Dict[int, Optional[Node]] = {x: None for x in range(8)}
+ self.pixel_count: int = 0
+ self.palette_index: int = -1
+
+ def update_metadata(self, r:int, g:int, b:int, counts:int):
+ self.pixel_count += counts
+ self.r_sum += (r * counts)
+ self.g_sum += (g * counts)
+ self.b_sum += (b * counts)
+
+ def merge_child(self, child_idx: int):
+ child = self.childs[child_idx]
+ if child is not None:
+ self.pixel_count += child.pixel_count
+ self.r_sum += child.r_sum
+ self.g_sum += child.g_sum
+ self.b_sum += child.b_sum
+ self.childs[child_idx] = None
+ if all(x is None for x in self.childs.values()):
+ self.is_leaf = True
+
+class Octree:
+ def __init__(self, palette_size: int):
+ self.root: Node = Node(None, -1)
+ self.palette_size: int = palette_size
+ self.max_depth:int = 8
+ self.layer_map: List[List[Node]] = [[] for _ in range(8)]
+ self.leaf_count: int = 0
+
+ def prune(self):
+ prune_depth = self.max_depth - 1
+ while self.leaf_count > self.palette_size:
+ if len(self.layer_map[prune_depth]) == 0:
+ prune_depth -= 1
+ if prune_depth < 0:
+ break
+ continue
+ self.layer_map[prune_depth].sort(key=lambda x: x.pixel_count, reverse=True)
+ while len(self.layer_map[prune_depth]) > 0 and self.leaf_count > self.palette_size:
+ pnode = self.layer_map[prune_depth].pop()
+ if pnode.parent is None: # noqa
+ continue
+ parent = pnode.parent
+ self.leaf_count -= 1
+ parent.merge_child(pnode.idx)
+ if parent.is_leaf:
+ self.leaf_count += 1
+
+ def add_colors(self, unique_colors: np.ndarray, counts: np.ndarray):
+ # r: 155, g: 0, b: 100
+ for i, c in enumerate(unique_colors.tolist()):
+ r = format(int(c[0]), "08b")
+ g = format(int(c[1]), "08b")
+ b = format(int(c[2]), "08b")
+ # r: 10011011 , g: 00000000, b: 01100100
+ vals = [int(c[0]), int(c[1]), int(c[2])]
+ count = int(counts[i])
+
+ cur = self.root
+ for d in range(8):
+ bit_idx = int(f"{r[d]}{g[d]}{b[d]}", 2)
+ child = cur.childs[bit_idx]
+ if child is None:
+ child = Node(cur, bit_idx)
+ cur.childs[bit_idx] = child
+ self.layer_map[d].append(child)
+ cur = child
+ cur.is_leaf = True
+ cur.update_metadata(*vals, counts=count)
+ self.leaf_count += 1
+ self.prune()
+ return self
+
+ def get_primary_colors(self) -> List[Tuple[np.ndarray, float]]:
+ palette: List[Tuple[np.ndarray, float]] = []
+ counter = 0
+ for layer in self.layer_map:
+ for node in layer:
+ if node.is_leaf:
+ node.palette_index = counter
+ if node.pixel_count > 0:
+ r = int(node.r_sum / node.pixel_count)
+ g = int(node.g_sum / node.pixel_count)
+ b = int(node.g_sum / node.pixel_count)
+ palette.append((
+ np.array([r, g, b]).astype(np.uint8),
+ node.pixel_count
+ ))
+ else:
+ palette.append((np.array([0, 0, 0]).astype(np.uint8), 0))
+ counter += 1
+ return palette
+
+ def get_palette_index(self, color_arr: np.ndarray) -> int:
+ r = format(int(color_arr[0]), "08b")
+ g = format(int(color_arr[1]), "08b")
+ b = format(int(color_arr[2]), "08b")
+ cur = self.root
+
+ for d in range(8):
+ bit_idx = int(f"{r[d]}{g[d]}{b[d]}", 2)
+ child = cur.childs[bit_idx]
+ if child is None:
+ break
+ cur = child
+ if cur.is_leaf:
+ break
+ return getattr(cur, 'palette_index', -1)
+
+ def get_lut(self, raw_palette):
+ full_palette = np.zeros((256, 1, 3), dtype=np.uint8)
+ for i, color in enumerate(raw_palette):
+ if i >= 256:
+ break
+ full_palette[i, 0, 0] = color[0][0]
+ full_palette[i, 0, 1] = color[0][1]
+ full_palette[i, 0, 2] = color[0][2]
+ return full_palette