From: mushcatshiro Date: Sun, 23 Nov 2025 14:19:59 +0000 (+0800) Subject: align median cut X-Git-Url: https://repo.mushcatshiro.com/?a=commitdiff_plain;ds=sidebyside;p=lumos.git align median cut --- diff --git a/lumos/feature/octree.py b/lumos/feature/octree.py new file mode 100644 index 0000000..10221dd --- /dev/null +++ b/lumos/feature/octree.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from typing import Dict, List, Optional +import numpy as np + +class Node: + def __init__(self, parent: Optional[Node], idx: int): + self.idx: int = idx + self.parent: Node = parent + self.is_leaf: bool = False + self.r_sum: int = 0 + self.g_sum: int = 0 + self.b_sum: int = 0 + self.childs: Dict[int, Optional[Node]] = {x: None for x in range(8)} + self.pixel_count: int = 0 + self.palette_index: int = -1 + + def update_metadata(self, r:int, g:int, b:int, counts:int): + self.pixel_count += counts + self.r_sum += (r * counts) + self.g_sum += (g * counts) + self.b_sum += (b * counts) + + def merge_child(self, child_idx: int): + child = self.childs[child_idx] + if child is not None: + self.pixel_count += child.pixel_count + self.r_sum += child.r_sum + self.g_sum += child.g_sum + self.b_sum += child.b_sum + self.childs[child_idx] = None + if all(x is None for x in self.childs.values()): + self.is_leaf = True + +class Octree: + def __init__(self, palette_size: int): + self.root: Node = Node(None, -1) + self.palette_size: int = palette_size + self.max_depth:int = 8 + self.layer_map: List[List[Node]] = [[] for _ in range(8)] + self.leaf_count: int = 0 + + def prune(self): + prune_depth = self.max_depth - 1 + while self.leaf_count > self.palette_size: + if len(self.layer_map[prune_depth]) == 0: + prune_depth -= 1 + if prune_depth < 0: + break + continue + self.layer_map[prune_depth].sort(key=lambda x: x.pixel_count, reverse=True) + while len(self.layer_map[prune_depth]) > 0 and self.leaf_count > self.palette_size: + pnode = self.layer_map[prune_depth].pop() + if pnode.parent is None: # noqa + continue + parent = pnode.parent + self.leaf_count -= 1 + parent.merge_child(pnode.idx) + if parent.is_leaf: + self.leaf_count += 1 + + def add_colors(self, unique_colors: np.ndarray, counts: np.ndarray): + # r: 155, g: 0, b: 100 + for i, c in enumerate(unique_colors.tolist()): + r = format(int(c[0]), "08b") + g = format(int(c[1]), "08b") + b = format(int(c[2]), "08b") + # r: 10011011 , g: 00000000, b: 01100100 + vals = [int(c[0]), int(c[1]), int(c[2])] + count = int(counts[i]) + + cur = self.root + for d in range(8): + bit_idx = int(f"{r[d]}{g[d]}{b[d]}", 2) + child = cur.childs[bit_idx] + if child is None: + child = Node(cur, bit_idx) + cur.childs[bit_idx] = child + self.layer_map[d].append(child) + cur = child + cur.is_leaf = True + cur.update_metadata(*vals, counts=count) + self.leaf_count += 1 + self.prune() + return self + + def get_primary_colors(self) -> List[Tuple[np.ndarray, float]]: + palette: List[Tuple[np.ndarray, float]] = [] + counter = 0 + for layer in self.layer_map: + for node in layer: + if node.is_leaf: + node.palette_index = counter + if node.pixel_count > 0: + r = int(node.r_sum / node.pixel_count) + g = int(node.g_sum / node.pixel_count) + b = int(node.g_sum / node.pixel_count) + palette.append(( + np.array([r, g, b]).astype(np.uint8), + node.pixel_count + )) + else: + palette.append((np.array([0, 0, 0]).astype(np.uint8), 0)) + counter += 1 + return palette + + def get_palette_index(self, color_arr: np.ndarray) -> int: + r = format(int(color_arr[0]), "08b") + g = format(int(color_arr[1]), "08b") + b = format(int(color_arr[2]), "08b") + cur = self.root + + for d in range(8): + bit_idx = int(f"{r[d]}{g[d]}{b[d]}", 2) + child = cur.childs[bit_idx] + if child is None: + break + cur = child + if cur.is_leaf: + break + return getattr(cur, 'palette_index', -1) + + def get_lut(self, raw_palette): + full_palette = np.zeros((256, 1, 3), dtype=np.uint8) + for i, color in enumerate(raw_palette): + if i >= 256: + break + full_palette[i, 0, 0] = color[0][0] + full_palette[i, 0, 1] = color[0][1] + full_palette[i, 0, 2] = color[0][2] + return full_palette diff --git a/lumos/feature/principal_color.py b/lumos/feature/principal_color.py index 4e1e5de..9c78b16 100644 --- a/lumos/feature/principal_color.py +++ b/lumos/feature/principal_color.py @@ -72,34 +72,48 @@ def median_cut_extractor(n_centroids): boxes = [img_f] while len(boxes) < n_centroids: largest_box_idx = np.argmax([box.shape[0] for box in boxes]) - c = boxes.pop(largest_box_idx) # h*w, c - max_vals = np.max(c, axis=0) - min_vals = np.min(c, axis=0) + box = boxes.pop(largest_box_idx) # h*w, c + if box.shape[0] == 0: + continue # skip empty box + max_vals = np.max(box, axis=0) + min_vals = np.min(box, axis=0) ranges = max_vals - min_vals max_chan = np.argmax(ranges) # (1, ) - median_val = np.median(c[:, max_chan]) + median_val = np.median(box[:, max_chan]) # no broadcast like op, c[:, max_chan] > median val results in row-wise # boolean mask [T, F, ...] of shape (h*w) - l = c[c[:, max_chan] <= median_val] - r = c[c[:, max_chan] > median_val] + l = box[box[:, max_chan] <= median_val] + r = box[box[:, max_chan] > median_val] boxes.append(l) boxes.append(r) - ret = [] + if len(boxes) > n_centroids * 2: + break + palette = [] for box in boxes: - box: np.ndarray - box_mean = np.rint(box.mean(axis=0)).astype(np.uint8) - ret.append([box_mean, len(box)]) - - centroids = np.asarray([x[0] for x in ret]) - lut_size = 256 - r, g, b = np.meshgrid(np.arange(lut_size), np.arange(lut_size), np.arange(lut_size)) - color_grid = np.stack([r, g, b], axis=-1).reshape(-1, 3) - tree = cKDTree(centroids) - _, indices = tree.query(color_grid) - lut_flat = centroids[indices] - lut = lut_flat.reshape((lut_size, lut_size, lut_size, 3)) - - return (PrincipalColorResult(ret, lut), None) + if len(box) > 0: + box: np.ndarray + box_mean = np.rint(box.mean(axis=0)).astype(np.uint8) + palette.append(box_mean) + else: + palette.append(np.array([0, 0, 0])) + while len(palette) < n_centroids: + palette.append(np.array([0, 0, 0])) + palette_arr = np.array(palette) + tree = cKDTree(palette_arr) + _, labels= tree.query(img_f) + index_map = labels.reshape((h, w)).astype(np.uint8) + + counts = np.bincount(labels, minlength=n_centroids) + total = np.sum(counts) + freqs = counts / total if total > 0 else np.zeros(n_centroids) + colors = [] + for i, color in enumerate(palette): + colors.append((color.astype(np.uint8), freqs[i])) + lut = np.zeros((256, 1, c), dtype=np.uint8) + for i, color in enumerate(palette): + lut[i, 0, :] = color + + return (PrincipalColorResult(colors, lut, index_map), None) return process