add principal color extract
authormushcatshiro <yingjieho@hotmail.com>
Sun, 23 Nov 2025 13:18:03 +0000 (21:18 +0800)
committermushcatshiro <yingjieho@hotmail.com>
Sun, 23 Nov 2025 13:18:03 +0000 (21:18 +0800)
application/color_extraction.py [new file with mode: 0644]
lumos/commons/color_space.py [new file with mode: 0644]
lumos/commons/show.py [new file with mode: 0644]
lumos/feature/principal_color.py [new file with mode: 0644]
lumos/visualize/quantize.py [new file with mode: 0644]
lumos/visualize/treemap.py [new file with mode: 0644]

diff --git a/application/color_extraction.py b/application/color_extraction.py
new file mode 100644 (file)
index 0000000..34c578e
--- /dev/null
@@ -0,0 +1,23 @@
+from pathlib import Path
+import os
+
+from lumos.core import run_simple_pipeline
+from lumos.commons.show import show_image
+from lumos.visualize.quantize import show_quantized_pallete
+from lumos.visualize.treemap import show_treemap
+from lumos.commons.color_space import apply_lut
+from lumos.feature.principal_color import (
+  kmeans_extractor, median_cut_extractor, octree_extractor
+)
+
+steps = [
+  show_image(),
+  # kmeans_extractor(6),
+  # median_cut_extractor(6),
+  octree_extractor(6),
+  show_quantized_pallete(6),
+  apply_lut(),
+  show_treemap(),
+]
+
+run_simple_pipeline(Path(os.environ["IMGPATH"]), steps)
diff --git a/lumos/commons/color_space.py b/lumos/commons/color_space.py
new file mode 100644 (file)
index 0000000..283e31f
--- /dev/null
@@ -0,0 +1,85 @@
+from enum import Enum
+
+
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+class ColorSpaceOption(Enum):
+  RGB = "RGB"
+  HSV = "HSV"
+  YCrCb = "YCrCb"
+  Gamma = "Gamma"
+  Composite = "Composite"
+  Grayscale = "Grayscale"
+
+
+class RGBSpace(Enum):
+  R = 0
+  G = 1
+  B = 2
+
+
+class HSVSpace(Enum):
+  H = 0
+  S = 1
+  V = 2
+
+
+class YCrCbSpace(Enum):
+  Y = 0
+  Cr = 1
+  Cb = 2
+
+
+def convert_from_rgb(img: np.ndarray, dst: str) -> np.ndarray:
+  """
+  expects arr to be NHWC
+  RGB, HSV, YCrCb, grayscale $Y = 0.2125 R + 0.7154 G + 0.0721 B$,
+  composite 0.33 * R + 0.33 * G + 0.33 * B
+  gamma correction, luma
+
+  technically we could convert back to RGB then to dst as long as its possible
+  e.g. composite/grayscale is a oneway conversion
+  """
+  if dst == ColorSpaceOption.HSV.value:
+    cvt_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
+  elif dst == ColorSpaceOption.YCrCb.value:
+    cvt_img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
+  else:
+    return np.zeros(img.shape)
+  return cvt_img
+
+
+def convert_to_rgb(img: np.ndarray, src: str) -> np.ndarray:
+  if src == ColorSpaceOption.HSV.value:
+    cvt_img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+    pass
+  elif src == ColorSpaceOption.YCrCb.value:
+    cvt_img = cv2.cvtColor(img, cv2.COLOR_YCrCb2RGB)
+  else:
+    return np.zeros(img.shape)
+  return cvt_img
+
+
+def apply_lut(show=True):
+  def process(res):
+    """
+    img = to_nhwc(img)
+    img = img.squeeze()
+    if is_3d_lut:
+      b, g, r = np.moveaxis(img, -1, 0)
+      return lut[r, g, b]
+    else:
+      return cv2.LUT(img, lut)
+    """
+    clean_lut = res.lut.squeeze()
+    # reconstructed_img = cv2.LUT(res.index_map, res.lut)
+    reconstructed_img = clean_lut[res.index_map].astype(np.uint8)
+    if show:
+      plt.imshow(reconstructed_img)
+      plt.show()
+    return (res, None)
+  return process
+
diff --git a/lumos/commons/show.py b/lumos/commons/show.py
new file mode 100644 (file)
index 0000000..6f0f9e2
--- /dev/null
@@ -0,0 +1,33 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from pathlib import Path
+
+from lumos.core import ImageProcessor
+
+
+def sample_and_show_image(sample: int) -> ImageProcessor:
+  def process(arr: np.ndarray):
+    # confirm that there is a txt file without .zarr suffix for title.
+    # build idx iterator from N, if sample is true then sample from 0th dimension
+    idxs = np.arange(arr.shape[0])
+    if sample > 0 and sample <= len(idxs):
+      idxs = np.random.choice(idxs, size=sample)
+    for idx in idxs:
+      fig, ax = plt.subplots(1, 1)
+      ax.imshow(arr[idx, :, :, :])
+      plt.show()
+      plt.close()
+    return arr, None
+
+  return process
+
+
+def show_image():
+  def process(img: np.ndarray):
+    fig, ax = plt.subplots(1, 1)
+    ax.imshow(img)
+    plt.show()
+    plt.close()
+    return img, None
+
+  return process
diff --git a/lumos/feature/principal_color.py b/lumos/feature/principal_color.py
new file mode 100644 (file)
index 0000000..4e1e5de
--- /dev/null
@@ -0,0 +1,166 @@
+from typing import List, Tuple
+
+import numpy as np
+from scipy.spatial import cKDTree
+from sklearn.cluster import KMeans
+from lumos.commons.shape_cvt import to_nhwc, to_float64
+from lumos.core import BaseResult, validate_input
+from lumos.feature.octree import Octree
+
+class PrincipalColorResult(BaseResult):
+  def __init__(
+      self,
+      colors:List[Tuple[np.ndarray, float]],
+      lut:np.ndarray,
+      index_map: np.ndarray
+  ):
+    self.colors = colors
+    self.lut = lut
+    self.index_map = index_map
+
+
+def kmeans_extractor(centriods: int, sample_size: int=0):
+  @validate_input(target=np.ndarray)
+  def process(img: np.ndarray):
+    if len(img.shape) == 4:
+      img = img.squeeze(0)
+
+    img_f = to_float64(img)
+    h, w, c = img_f.shape
+    img_f = img_f.reshape((h * w, c)).astype(np.float32)
+
+    if sample_size > 0:
+      idx = np.random.choice(img_f.shape[0], sample_size, replace=False)
+      sample = img_f[idx]
+    else:
+      sample = img_f
+
+    model = KMeans(
+      n_clusters=centriods, n_init="auto", init="k-means++", random_state=1337
+    )
+    model.fit(sample)
+    labels = model.predict(img_f)
+    palette = model.cluster_centers_.astype(np.uint8)
+    color_count = np.bincount(labels, minlength=centriods)
+    total_pixels = np.sum(color_count)
+    if total_pixels > 0:
+      color_frequency = color_count / float(total_pixels)
+    else:
+      color_frequency = np.zeros(centriods)
+
+    colors = []
+
+    for i in range(len(palette)):
+      colors.append((palette[i], color_frequency[i]))
+    lut = np.zeros((256, 1, c), dtype=np.uint8)
+    for i, color in enumerate(palette):
+      lut[i, 0, :] = color
+    index_map = labels.reshape((h, w)).astype(np.uint8)
+
+    return (PrincipalColorResult(colors, lut, index_map), None)
+  return process
+
+
+def median_cut_extractor(n_centroids):
+  @validate_input(target=np.ndarray)
+  def process(img: np.ndarray):
+    if len(img.shape) == 4:
+      img = img.squeeze(0)
+    img_f = to_float64(img)
+    h, w, c = img_f.shape
+    img_f = img_f.reshape((h * w, -1))  # h*w, c
+    boxes = [img_f]
+    while len(boxes) < n_centroids:
+      largest_box_idx = np.argmax([box.shape[0] for box in boxes])
+      c = boxes.pop(largest_box_idx)  # h*w, c
+      max_vals = np.max(c, axis=0)
+      min_vals = np.min(c, axis=0)
+      ranges = max_vals - min_vals
+      max_chan = np.argmax(ranges)  # (1, )
+      median_val = np.median(c[:, max_chan])
+      # no broadcast like op, c[:, max_chan] > median val results in row-wise
+      # boolean mask [T, F, ...] of shape (h*w)
+      l = c[c[:, max_chan] <= median_val]
+      r = c[c[:, max_chan] > median_val]
+      boxes.append(l)
+      boxes.append(r)
+    ret = []
+    for box in boxes:
+      box: np.ndarray
+      box_mean = np.rint(box.mean(axis=0)).astype(np.uint8)
+      ret.append([box_mean, len(box)])
+
+    centroids = np.asarray([x[0] for x in ret])
+    lut_size = 256
+    r, g, b = np.meshgrid(np.arange(lut_size), np.arange(lut_size), np.arange(lut_size))
+    color_grid = np.stack([r, g, b], axis=-1).reshape(-1, 3)
+    tree = cKDTree(centroids)
+    _, indices = tree.query(color_grid)
+    lut_flat = centroids[indices]
+    lut = lut_flat.reshape((lut_size, lut_size, lut_size, 3))
+
+    return (PrincipalColorResult(ret, lut), None)
+  return process
+
+
+def octree_extractor(palette_size: int):
+  @validate_input(target=np.ndarray)
+  def process(img: np.ndarray):
+    """
+    rgb to 8bit binary
+    R 01000000
+    G 10000000
+    B 00000000
+    1st bit 010 -> 2 -> first layer 3rd node?
+    2nd bit 100 -> 4 -> 5th node of first layer 3rd node
+    """
+    pixels_list = img.reshape(-1, 3)  # into (H*W, C)
+    unique_colors, inverse, counts = np.unique(
+      pixels_list, axis=0, return_counts=True, return_inverse=True
+    )
+    ot = Octree(palette_size)
+    ot.add_colors(unique_colors, counts)
+    ret = ot.get_primary_colors()
+    lut = ot.get_lut(ret)
+    unique_indices= np.array([
+      ot.get_palette_index(c) for c in unique_colors
+    ])
+    index_map_f = unique_indices[inverse]  # broadcast
+    index_map = index_map_f.reshape(img.shape[:2]).astype(np.uint8)
+
+    return(PrincipalColorResult(ret, lut, index_map), None)
+  return process
+
+
+def pca_extractor(components: int):
+  @validate_input(target=np.ndarray)
+  def process(img: np.ndarray):
+    if len(img.shape) == 4:
+      img = img.squeeze(0)
+
+    img_f = to_float64(img)
+    h, w, c = img_f.shape
+    img_f = img_f.reshape((h * w, c))
+
+    model = KMeans(
+      n_clusters=centriods, n_init="auto", init="k-means++", random_state=1337
+    )
+    labels = model.fit_predict(img_f)
+    palette = np.array(model.cluster_centers_, dtype=int)
+    color_count = np.bincount(labels)
+    color_frequency = color_count / float(np.sum(color_count))
+
+    colors = []
+
+    for color, freq in zip(palette, color_frequency):
+      colors.append((color, freq))
+    if c == 3:
+      arr = np.arange(256).reshape(-1, 1).repeat(3, axis=1)
+    else:
+      arr = np.arange(256)
+
+    y = model.predict(arr)
+    lut = palette[y].astype(np.uint8)
+    lut = lut.reshape((1, 256, c))
+    return
+  return process
diff --git a/lumos/visualize/quantize.py b/lumos/visualize/quantize.py
new file mode 100644 (file)
index 0000000..d46e711
--- /dev/null
@@ -0,0 +1,18 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+from lumos.feature.principal_color import PrincipalColorResult
+
+def show_quantized_pallete(n:int):
+  def process(inp: PrincipalColorResult):
+    carr = np.asarray([x[0] for x in inp.colors])
+    carr = carr.reshape((-1, 6, 3))
+    farr = [x[1] for x in inp.colors]
+
+    fig, ax = plt.subplots(1, 1)
+    ax.imshow(carr.transpose(1, 0, 2))
+    plt.show()
+    plt.close()
+
+    return (inp, None)
+  return process
diff --git a/lumos/visualize/treemap.py b/lumos/visualize/treemap.py
new file mode 100644 (file)
index 0000000..84c15f9
--- /dev/null
@@ -0,0 +1,20 @@
+import squarify
+import numpy as np
+import matplotlib.pyplot as plt
+
+from lumos.core import validate_input
+from lumos.feature.principal_color import PrincipalColorResult
+
+def show_treemap():
+  @validate_input(target=PrincipalColorResult)
+  def process(inp: PrincipalColorResult):
+    """
+    expects sizes, labels, colors
+    """
+    sizes = [x[1] for x in inp.colors]
+    colors = [tuple(np.round(x[0]/255., decimals=2)) for x in inp.colors]
+    squarify.plot(sizes=sizes, label=colors, alpha=0.6, color=colors)
+    plt.axis('off') # Turn off axes for a cleaner look
+    plt.show()
+    return (inp, None)
+  return process