align median cut main
author mushcatshiro <yingjieho@hotmail.com>
Sun, 23 Nov 2025 14:19:59 +0000 (22:19 +0800)
committer mushcatshiro <yingjieho@hotmail.com>
Sun, 23 Nov 2025 14:19:59 +0000 (22:19 +0800)
lumos/feature/octree.py [new file with mode: 0644]
lumos/feature/principal_color.py

diff --git a/lumos/feature/octree.py b/lumos/feature/octree.py
new file mode 100644 (file)
index 0000000..10221dd
--- /dev/null
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+
+class Node:
+  """One octree cell: running colour totals plus up to eight child cells."""
+
+  def __init__(self, parent: Optional[Node], idx: int):
+    # Slot this node occupies in its parent's `childs` table.
+    self.idx: int = idx
+    # Owning node; None is passed for a node that has no parent.
+    self.parent: Optional[Node] = parent
+    self.is_leaf: bool = False
+    # Count-weighted channel totals and the total weight seen so far.
+    self.pixel_count: int = 0
+    self.r_sum: int = 0
+    self.g_sum: int = 0
+    self.b_sum: int = 0
+    # Eight slots keyed 0..7, all empty to begin with.
+    self.childs: Dict[int, Optional[Node]] = dict.fromkeys(range(8))
+    # -1 means "no palette entry assigned".
+    self.palette_index: int = -1
+
+  def update_metadata(self, r:int, g:int, b:int, counts:int):
+    """Add `counts` occurrences of the colour (r, g, b) to the totals."""
+    self.r_sum += r * counts
+    self.g_sum += g * counts
+    self.b_sum += b * counts
+    self.pixel_count += counts
+
+  def merge_child(self, child_idx: int):
+    """Fold the child in slot `child_idx` into this node and clear the slot.
+
+    Does nothing when the slot is already empty.  Once every slot is empty
+    the node is flagged as a leaf.
+    """
+    absorbed = self.childs[child_idx]
+    if absorbed is None:
+      return
+    self.r_sum += absorbed.r_sum
+    self.g_sum += absorbed.g_sum
+    self.b_sum += absorbed.b_sum
+    self.pixel_count += absorbed.pixel_count
+    self.childs[child_idx] = None
+    if not any(slot is not None for slot in self.childs.values()):
+      self.is_leaf = True
+
+class Octree:
+  """Octree colour quantiser for 8-bit RGB colours.
+
+  Each colour is inserted along a path of `max_depth` nodes; the child taken
+  at depth `d` is selected by bit `max_depth - 1 - d` of the red, green and
+  blue channels (most significant bit first).  After insertion the tree is
+  pruned towards at most `palette_size` leaves, and every remaining leaf
+  becomes one palette entry.
+  """
+
+  def __init__(self, palette_size: int):
+    self.root: Node = Node(None, -1)
+    self.palette_size: int = palette_size
+    self.max_depth:int = 8
+    # layer_map[d] lists every node created at depth d (root's children are
+    # depth 0).  The root itself is not listed.
+    self.layer_map: List[List[Node]] = [[] for _ in range(self.max_depth)]
+    self.leaf_count: int = 0
+
+  def _child_index(self, r: int, g: int, b: int, depth: int) -> int:
+    """Return the child slot (0..7) of an 8-bit colour at tree depth `depth`.
+
+    The slot is the 3-bit number built from one bit of each channel, red
+    being the most significant.  Example: r=155 (10011011), g=0, b=100
+    (01100100) gives slot 0b100 at depth 0 and 0b001 at depth 1.
+    Assumes each channel is in 0..255.
+    """
+    shift = self.max_depth - 1 - depth
+    r_bit = (r >> shift) & 1
+    g_bit = (g >> shift) & 1
+    b_bit = (b >> shift) & 1
+    return (r_bit << 2) | (g_bit << 1) | b_bit
+
+  def prune(self):
+    """Merge low-weight nodes into their parents until the leaf budget is met.
+
+    Layers are processed deepest first.  Within a layer, nodes are removed in
+    ascending `pixel_count` order; each removed node's totals are folded into
+    its parent, which becomes a leaf once all of its children are gone.
+    """
+    # NOTE(review): if the budget is met part-way through a layer, a parent
+    # may have absorbed some children while still holding others.  Such a
+    # parent is not a leaf, so get_primary_colors() skips the weight it
+    # absorbed and get_palette_index() returns its unassigned index (-1) for
+    # the absorbed colours.  Merging all siblings at once would avoid this.
+    prune_depth = self.max_depth - 1
+    while self.leaf_count > self.palette_size and prune_depth >= 0:
+      layer = self.layer_map[prune_depth]
+      if not layer:
+        prune_depth -= 1
+        continue
+      # Descending sort so that pop() takes the lightest node first.
+      layer.sort(key=lambda node: node.pixel_count, reverse=True)
+      while layer and self.leaf_count > self.palette_size:
+        pnode = layer.pop()
+        parent = pnode.parent
+        if parent is None:
+          continue
+        self.leaf_count -= 1
+        parent.merge_child(pnode.idx)
+        if parent.is_leaf:
+          self.leaf_count += 1
+
+  def add_colors(self, unique_colors: np.ndarray, counts: np.ndarray):
+    """Insert colours with their occurrence counts, then prune the tree.
+
+    Args:
+      unique_colors: array whose rows start with r, g, b.  Channels are
+        assumed to be 8-bit values (0..255).
+      counts: counts[i] is the weight given to unique_colors[i].
+
+    Returns:
+      self, so calls can be chained.
+    """
+    for i, row in enumerate(unique_colors.tolist()):
+      r, g, b = int(row[0]), int(row[1]), int(row[2])
+      cur = self.root
+      for d in range(self.max_depth):
+        slot = self._child_index(r, g, b, d)
+        child = cur.childs[slot]
+        if child is None:
+          child = Node(cur, slot)
+          cur.childs[slot] = child
+          self.layer_map[d].append(child)
+        cur = child
+      # Count a leaf only once, even if the same colour is inserted again.
+      if not cur.is_leaf:
+        cur.is_leaf = True
+        self.leaf_count += 1
+      cur.update_metadata(r, g, b, counts=int(counts[i]))
+    self.prune()
+    return self
+
+  def get_primary_colors(self) -> List[Tuple[np.ndarray, float]]:
+    """Number the leaves and return the palette they define.
+
+    Leaves are visited in `layer_map` order (shallowest layer first) and
+    each leaf's `palette_index` is set to its position in the returned list.
+
+    Returns:
+      A list of (colour, pixel_count) pairs.  `colour` is a uint8 array with
+      the truncated mean r, g, b of the leaf; a leaf without pixels yields
+      black with a count of 0.
+    """
+    palette: List[Tuple[np.ndarray, float]] = []
+    for layer in self.layer_map:
+      for node in layer:
+        if not node.is_leaf:
+          continue
+        node.palette_index = len(palette)
+        total = node.pixel_count
+        if total > 0:
+          # Each channel is averaged from its own sum.
+          mean = [node.r_sum // total, node.g_sum // total, node.b_sum // total]
+          weight = total
+        else:
+          mean = [0, 0, 0]
+          weight = 0
+        palette.append((np.array(mean).astype(np.uint8), weight))
+    return palette
+
+  def get_palette_index(self, color_arr: np.ndarray) -> int:
+    """Return the palette index of the node that covers `color_arr`.
+
+    Walks from the root along the colour's bit path and stops at the first
+    leaf, or at the last existing node when the path runs out.  The result
+    is that node's `palette_index`, which is -1 for a node that
+    get_primary_colors() has not numbered.
+    """
+    r, g, b = int(color_arr[0]), int(color_arr[1]), int(color_arr[2])
+    cur = self.root
+    for d in range(self.max_depth):
+      child = cur.childs[self._child_index(r, g, b, d)]
+      if child is None:
+        break
+      cur = child
+      if cur.is_leaf:
+        break
+    return cur.palette_index
+
+  def get_lut(self, raw_palette):
+    """Pack a palette into a fixed (256, 1, 3) uint8 lookup table.
+
+    Args:
+      raw_palette: iterable of entries whose first element is an (r, g, b)
+        colour, e.g. the list returned by get_primary_colors().
+
+    Returns:
+      Array whose row i holds colour i.  Rows beyond the palette stay zero
+      and entries past index 255 are dropped.
+    """
+    full_palette = np.zeros((256, 1, 3), dtype=np.uint8)
+    for i, entry in enumerate(raw_palette):
+      if i >= 256:
+        break
+      color = entry[0]
+      full_palette[i, 0] = (color[0], color[1], color[2])
+    return full_palette
index 4e1e5ded9f19e06f81832265de24f9733ede6443..9c78b1630146061f365adb44dd057bbabd69af05 100644 (file)
@@ -72,34 +72,48 @@ def median_cut_extractor(n_centroids):
     boxes = [img_f]
     while len(boxes) < n_centroids:
       largest_box_idx = np.argmax([box.shape[0] for box in boxes])
     boxes = [img_f]
     while len(boxes) < n_centroids:
       largest_box_idx = np.argmax([box.shape[0] for box in boxes])
-      c = boxes.pop(largest_box_idx)  # h*w, c
-      max_vals = np.max(c, axis=0)
-      min_vals = np.min(c, axis=0)
+      box = boxes.pop(largest_box_idx)  # h*w, c
+      if box.shape[0] == 0:
+        continue # skip empty box
+      max_vals = np.max(box, axis=0)
+      min_vals = np.min(box, axis=0)
       ranges = max_vals - min_vals
       max_chan = np.argmax(ranges)  # (1, )
       ranges = max_vals - min_vals
       max_chan = np.argmax(ranges)  # (1, )
-      median_val = np.median(c[:, max_chan])
+      median_val = np.median(box[:, max_chan])
       # no broadcast like op, c[:, max_chan] > median val results in row-wise
       # boolean mask [T, F, ...] of shape (h*w)
       # no broadcast like op, c[:, max_chan] > median val results in row-wise
       # boolean mask [T, F, ...] of shape (h*w)
-      l = c[c[:, max_chan] <= median_val]
-      r = c[c[:, max_chan] > median_val]
+      l = box[box[:, max_chan] <= median_val]
+      r = box[box[:, max_chan] > median_val]
       boxes.append(l)
       boxes.append(r)
       boxes.append(l)
       boxes.append(r)
-    ret = []
+      if len(boxes) > n_centroids * 2:
+        break
+    palette = []
     for box in boxes:
     for box in boxes:
-      box: np.ndarray
-      box_mean = np.rint(box.mean(axis=0)).astype(np.uint8)
-      ret.append([box_mean, len(box)])
-
-    centroids = np.asarray([x[0] for x in ret])
-    lut_size = 256
-    r, g, b = np.meshgrid(np.arange(lut_size), np.arange(lut_size), np.arange(lut_size))
-    color_grid = np.stack([r, g, b], axis=-1).reshape(-1, 3)
-    tree = cKDTree(centroids)
-    _, indices = tree.query(color_grid)
-    lut_flat = centroids[indices]
-    lut = lut_flat.reshape((lut_size, lut_size, lut_size, 3))
-
-    return (PrincipalColorResult(ret, lut), None)
+      if len(box) > 0:
+        box: np.ndarray
+        box_mean = np.rint(box.mean(axis=0)).astype(np.uint8)
+        palette.append(box_mean)
+      else:
+        palette.append(np.array([0, 0, 0]))
+    while len(palette) < n_centroids:
+      palette.append(np.array([0, 0, 0]))
+    palette_arr = np.array(palette)
+    tree = cKDTree(palette_arr)
+    _, labels= tree.query(img_f)
+    index_map = labels.reshape((h, w)).astype(np.uint8)
+
+    counts = np.bincount(labels, minlength=n_centroids)
+    total = np.sum(counts)
+    freqs = counts / total if total > 0 else np.zeros(n_centroids)
+    colors = []
+    for i, color in enumerate(palette):
+      colors.append((color.astype(np.uint8), freqs[i]))
+    lut = np.zeros((256, 1, c), dtype=np.uint8)
+    for i, color in enumerate(palette):
+      lut[i, 0, :] = color
+
+    return (PrincipalColorResult(colors, lut, index_map), None)
   return process
 
 
   return process