class ColorSpaceOption(Enum):
    """
    names of the color spaces referenced by this module.

    each member's value is its own name as a string. `convert_from_rgb`
    and `convert_to_rgb` compare their `dst` / `src` argument against
    these values; in this module only HSV and YCrCb are matched there.
    """
    RGB = "RGB"
    HSV = "HSV"
    YCrCb = "YCrCb"
    # the members below are declared but not matched by the converters
    # in this module
    Gamma = "Gamma"
    Composite = "Composite"
    Grayscale = "Grayscale"
def convert_to_rgb(img: np.ndarray, src: str) -> np.ndarray:
    """
    convert `img` from the color space named by `src` to RGB.

    `src` is compared with the values of `ColorSpaceOption.HSV` and
    `ColorSpaceOption.YCrCb`; a match is converted with `cv2.cvtColor`.
    any other `src` yields an all-zero array with the shape of `img`,
    created with `np.zeros` and its default dtype.
    """
    if src == ColorSpaceOption.HSV.value:
        return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
    if src == ColorSpaceOption.YCrCb.value:
        return cv2.cvtColor(img, cv2.COLOR_YCrCb2RGB)
    # unsupported source space: placeholder output instead of an error
    return np.zeros(img.shape)
def show_image():
    """
    build a step that displays an image and hands it on unchanged.

    the returned callable draws `img` on a fresh single-axes figure,
    shows it, closes the figure and returns `(img, None)`.
    """
    def process(img: np.ndarray):
        _, axes = plt.subplots(nrows=1, ncols=1)
        axes.imshow(img)
        plt.show()
        plt.close()
        passthrough = (img, None)
        return passthrough

    return process
def median_cut_extractor(n_centroids):
    """
    build a step that quantizes an image with median cut.

    the pixel set is split repeatedly: the most populated box is cut along
    the channel with the widest value range, at that channel's median,
    until `n_centroids` boxes exist or every remaining box holds a single
    color. each box is represented by its rounded mean color.

    the step returns `(PrincipalColorResult, None)` using the same layout
    as `kmeans_extractor`:
      colors    - list of (uint8 color, fraction of pixels in that box)
      lut       - (256, 1, C) uint8 palette; rows past the last box are 0
      index_map - (H, W) uint8 palette index of every pixel

    raises ValueError when `n_centroids` exceeds 256 (palette indices are
    stored as uint8) and, from the step, when the image has no pixels.
    """
    lut_size = 256
    if n_centroids > lut_size:
        raise ValueError(
            f"n_centroids must be <= {lut_size}, got {n_centroids}"
        )

    @validate_input(target=np.ndarray)
    def process(img: np.ndarray):
        if len(img.shape) == 4:
            img = img.squeeze(0)  # drop a leading singleton batch axis
        # NOTE(review): assumes to_float64 only casts and keeps the 0..255
        # value range, as kmeans_extractor also relies on - confirm
        img_f = to_float64(img)
        h, w, channels = img_f.shape
        pixels = img_f.reshape((h * w, channels))
        total = pixels.shape[0]
        if total == 0:
            raise ValueError("cannot extract colors from an empty image")

        # a box is an array of row indices into `pixels`; keeping indices
        # (not pixel copies) lets us label every pixel with its box later
        splittable = [np.arange(total)]
        finished = []  # boxes holding one single color, cannot be cut
        while splittable and len(splittable) + len(finished) < n_centroids:
            largest = int(np.argmax([len(box) for box in splittable]))
            box = splittable.pop(largest)
            values = pixels[box]
            spread = values.max(axis=0) - values.min(axis=0)
            channel = int(np.argmax(spread))
            if spread[channel] == 0:
                finished.append(box)
                continue

            channel_values = values[:, channel]
            median_val = np.median(channel_values)
            lower = channel_values <= median_val
            if lower.all():
                # median equals the maximum: cut strictly below it so that
                # neither half ends up empty
                lower = channel_values < median_val
            splittable.append(box[lower])
            splittable.append(box[~lower])

        colors = []
        lut = np.zeros((lut_size, 1, channels), dtype=np.uint8)
        labels = np.empty(total, dtype=np.uint8)
        for label, box in enumerate(finished + splittable):
            mean_color = np.rint(pixels[box].mean(axis=0)).astype(np.uint8)
            colors.append((mean_color, len(box) / float(total)))
            lut[label, 0, :] = mean_color
            labels[box] = label
        index_map = labels.reshape((h, w))

        return (PrincipalColorResult(colors, lut, index_map), None)
    return process
def show_quantized_pallete(n:int):
    """
    build a step that draws the extracted palette as color swatches.

    `n` is the number of swatches stacked in one column of the picture.
    when the palette size is not a multiple of `n` (or `n` is not
    positive) all swatches are drawn in a single column instead, so the
    step no longer fails on palettes of other sizes. an empty palette is
    passed through without drawing.

    the returned callable takes a result object whose `colors` entries
    carry the color at index 0 and returns `(inp, None)`.
    """
    def process(inp: PrincipalColorResult):
        palette = np.asarray([x[0] for x in inp.colors])
        count = len(palette)
        if count == 0:
            return (inp, None)

        # honor `n` when it tiles the palette exactly, else one column
        per_column = n if n > 0 and count % n == 0 else count
        swatches = palette.reshape((-1, per_column, palette.shape[-1]))

        fig, ax = plt.subplots(1, 1)
        # transpose so each group of `per_column` colors runs top to bottom
        ax.imshow(swatches.transpose(1, 0, 2))
        plt.show()
        plt.close()

        return (inp, None)
    return process