--- /dev/null
+from __future__ import annotations
+
+from typing import Dict, List, Optional
+import numpy as np
+
+class Node:
+ def __init__(self, parent: Optional[Node], idx: int):
+ self.idx: int = idx
+ self.parent: Node = parent
+ self.is_leaf: bool = False
+ self.r_sum: int = 0
+ self.g_sum: int = 0
+ self.b_sum: int = 0
+ self.childs: Dict[int, Optional[Node]] = {x: None for x in range(8)}
+ self.pixel_count: int = 0
+ self.palette_index: int = -1
+
+ def update_metadata(self, r:int, g:int, b:int, counts:int):
+ self.pixel_count += counts
+ self.r_sum += (r * counts)
+ self.g_sum += (g * counts)
+ self.b_sum += (b * counts)
+
+ def merge_child(self, child_idx: int):
+ child = self.childs[child_idx]
+ if child is not None:
+ self.pixel_count += child.pixel_count
+ self.r_sum += child.r_sum
+ self.g_sum += child.g_sum
+ self.b_sum += child.b_sum
+ self.childs[child_idx] = None
+ if all(x is None for x in self.childs.values()):
+ self.is_leaf = True
+
+class Octree:
+ def __init__(self, palette_size: int):
+ self.root: Node = Node(None, -1)
+ self.palette_size: int = palette_size
+ self.max_depth:int = 8
+ self.layer_map: List[List[Node]] = [[] for _ in range(8)]
+ self.leaf_count: int = 0
+
+ def prune(self):
+ prune_depth = self.max_depth - 1
+ while self.leaf_count > self.palette_size:
+ if len(self.layer_map[prune_depth]) == 0:
+ prune_depth -= 1
+ if prune_depth < 0:
+ break
+ continue
+ self.layer_map[prune_depth].sort(key=lambda x: x.pixel_count, reverse=True)
+ while len(self.layer_map[prune_depth]) > 0 and self.leaf_count > self.palette_size:
+ pnode = self.layer_map[prune_depth].pop()
+ if pnode.parent is None: # noqa
+ continue
+ parent = pnode.parent
+ self.leaf_count -= 1
+ parent.merge_child(pnode.idx)
+ if parent.is_leaf:
+ self.leaf_count += 1
+
+ def add_colors(self, unique_colors: np.ndarray, counts: np.ndarray):
+ # r: 155, g: 0, b: 100
+ for i, c in enumerate(unique_colors.tolist()):
+ r = format(int(c[0]), "08b")
+ g = format(int(c[1]), "08b")
+ b = format(int(c[2]), "08b")
+ # r: 10011011 , g: 00000000, b: 01100100
+ vals = [int(c[0]), int(c[1]), int(c[2])]
+ count = int(counts[i])
+
+ cur = self.root
+ for d in range(8):
+ bit_idx = int(f"{r[d]}{g[d]}{b[d]}", 2)
+ child = cur.childs[bit_idx]
+ if child is None:
+ child = Node(cur, bit_idx)
+ cur.childs[bit_idx] = child
+ self.layer_map[d].append(child)
+ cur = child
+ cur.is_leaf = True
+ cur.update_metadata(*vals, counts=count)
+ self.leaf_count += 1
+ self.prune()
+ return self
+
+ def get_primary_colors(self) -> List[Tuple[np.ndarray, float]]:
+ palette: List[Tuple[np.ndarray, float]] = []
+ counter = 0
+ for layer in self.layer_map:
+ for node in layer:
+ if node.is_leaf:
+ node.palette_index = counter
+ if node.pixel_count > 0:
+ r = int(node.r_sum / node.pixel_count)
+ g = int(node.g_sum / node.pixel_count)
+ b = int(node.g_sum / node.pixel_count)
+ palette.append((
+ np.array([r, g, b]).astype(np.uint8),
+ node.pixel_count
+ ))
+ else:
+ palette.append((np.array([0, 0, 0]).astype(np.uint8), 0))
+ counter += 1
+ return palette
+
+ def get_palette_index(self, color_arr: np.ndarray) -> int:
+ r = format(int(color_arr[0]), "08b")
+ g = format(int(color_arr[1]), "08b")
+ b = format(int(color_arr[2]), "08b")
+ cur = self.root
+
+ for d in range(8):
+ bit_idx = int(f"{r[d]}{g[d]}{b[d]}", 2)
+ child = cur.childs[bit_idx]
+ if child is None:
+ break
+ cur = child
+ if cur.is_leaf:
+ break
+ return getattr(cur, 'palette_index', -1)
+
+ def get_lut(self, raw_palette):
+ full_palette = np.zeros((256, 1, 3), dtype=np.uint8)
+ for i, color in enumerate(raw_palette):
+ if i >= 256:
+ break
+ full_palette[i, 0, 0] = color[0][0]
+ full_palette[i, 0, 1] = color[0][1]
+ full_palette[i, 0, 2] = color[0][2]
+ return full_palette
boxes = [img_f]
while len(boxes) < n_centroids:
largest_box_idx = np.argmax([box.shape[0] for box in boxes])
- c = boxes.pop(largest_box_idx) # h*w, c
- max_vals = np.max(c, axis=0)
- min_vals = np.min(c, axis=0)
+ box = boxes.pop(largest_box_idx) # h*w, c
+ if box.shape[0] == 0:
+ continue # skip empty box
+ max_vals = np.max(box, axis=0)
+ min_vals = np.min(box, axis=0)
ranges = max_vals - min_vals
max_chan = np.argmax(ranges) # (1, )
- median_val = np.median(c[:, max_chan])
+ median_val = np.median(box[:, max_chan])
# no broadcast like op, c[:, max_chan] > median val results in row-wise
# boolean mask [T, F, ...] of shape (h*w)
- l = c[c[:, max_chan] <= median_val]
- r = c[c[:, max_chan] > median_val]
+ l = box[box[:, max_chan] <= median_val]
+ r = box[box[:, max_chan] > median_val]
boxes.append(l)
boxes.append(r)
- ret = []
+ if len(boxes) > n_centroids * 2:
+ break
+ palette = []
for box in boxes:
- box: np.ndarray
- box_mean = np.rint(box.mean(axis=0)).astype(np.uint8)
- ret.append([box_mean, len(box)])
-
- centroids = np.asarray([x[0] for x in ret])
- lut_size = 256
- r, g, b = np.meshgrid(np.arange(lut_size), np.arange(lut_size), np.arange(lut_size))
- color_grid = np.stack([r, g, b], axis=-1).reshape(-1, 3)
- tree = cKDTree(centroids)
- _, indices = tree.query(color_grid)
- lut_flat = centroids[indices]
- lut = lut_flat.reshape((lut_size, lut_size, lut_size, 3))
-
- return (PrincipalColorResult(ret, lut), None)
+ if len(box) > 0:
+ box: np.ndarray
+ box_mean = np.rint(box.mean(axis=0)).astype(np.uint8)
+ palette.append(box_mean)
+ else:
+ palette.append(np.array([0, 0, 0]))
+ while len(palette) < n_centroids:
+ palette.append(np.array([0, 0, 0]))
+ palette_arr = np.array(palette)
+ tree = cKDTree(palette_arr)
+ _, labels= tree.query(img_f)
+ index_map = labels.reshape((h, w)).astype(np.uint8)
+
+ counts = np.bincount(labels, minlength=n_centroids)
+ total = np.sum(counts)
+ freqs = counts / total if total > 0 else np.zeros(n_centroids)
+ colors = []
+ for i, color in enumerate(palette):
+ colors.append((color.astype(np.uint8), freqs[i]))
+ lut = np.zeros((256, 1, c), dtype=np.uint8)
+ for i, color in enumerate(palette):
+ lut[i, 0, :] = color
+
+ return (PrincipalColorResult(colors, lut, index_map), None)
return process