--- /dev/null
+from enum import Enum
+
+
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+class ColorSpaceOption(Enum):
+ RGB = "RGB"
+ HSV = "HSV"
+ YCrCb = "YCrCb"
+ Gamma = "Gamma"
+ Composite = "Composite"
+ Grayscale = "Grayscale"
+
+
+class RGBSpace(Enum):
+ R = 0
+ G = 1
+ B = 2
+
+
+class HSVSpace(Enum):
+ H = 0
+ S = 1
+ V = 2
+
+
+class YCrCbSpace(Enum):
+ Y = 0
+ Cr = 1
+ Cb = 2
+
+
+def convert_from_rgb(img: np.ndarray, dst: str) -> np.ndarray:
+ """
+ expects arr to be NHWC
+ RGB, HSV, YCrCb, grayscale $Y = 0.2125 R + 0.7154 G + 0.0721 B$,
+ composite 0.33 * R + 0.33 * G + 0.33 * B
+ gamma correction, luma
+
+ technically we could convert back to RGB then to dst as long as its possible
+ e.g. composite/grayscale is a oneway conversion
+ """
+ if dst == ColorSpaceOption.HSV.value:
+ cvt_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
+ elif dst == ColorSpaceOption.YCrCb.value:
+ cvt_img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
+ else:
+ return np.zeros(img.shape)
+ return cvt_img
+
+
+def convert_to_rgb(img: np.ndarray, src: str) -> np.ndarray:
+ if src == ColorSpaceOption.HSV.value:
+ cvt_img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+ pass
+ elif src == ColorSpaceOption.YCrCb.value:
+ cvt_img = cv2.cvtColor(img, cv2.COLOR_YCrCb2RGB)
+ else:
+ return np.zeros(img.shape)
+ return cvt_img
+
+
+def apply_lut(show=True):
+ def process(res):
+ """
+ img = to_nhwc(img)
+ img = img.squeeze()
+ if is_3d_lut:
+ b, g, r = np.moveaxis(img, -1, 0)
+ return lut[r, g, b]
+ else:
+ return cv2.LUT(img, lut)
+ """
+ clean_lut = res.lut.squeeze()
+ # reconstructed_img = cv2.LUT(res.index_map, res.lut)
+ reconstructed_img = clean_lut[res.index_map].astype(np.uint8)
+ if show:
+ plt.imshow(reconstructed_img)
+ plt.show()
+ return (res, None)
+ return process
+
--- /dev/null
+from typing import List, Tuple
+
+import numpy as np
+from scipy.spatial import cKDTree
+from sklearn.cluster import KMeans
+from lumos.commons.shape_cvt import to_nhwc, to_float64
+from lumos.core import BaseResult, validate_input
+from lumos.feature.octree import Octree
+
+class PrincipalColorResult(BaseResult):
+ def __init__(
+ self,
+ colors:List[Tuple[np.ndarray, float]],
+ lut:np.ndarray,
+ index_map: np.ndarray
+ ):
+ self.colors = colors
+ self.lut = lut
+ self.index_map = index_map
+
+
+def kmeans_extractor(centriods: int, sample_size: int=0):
+ @validate_input(target=np.ndarray)
+ def process(img: np.ndarray):
+ if len(img.shape) == 4:
+ img = img.squeeze(0)
+
+ img_f = to_float64(img)
+ h, w, c = img_f.shape
+ img_f = img_f.reshape((h * w, c)).astype(np.float32)
+
+ if sample_size > 0:
+ idx = np.random.choice(img_f.shape[0], sample_size, replace=False)
+ sample = img_f[idx]
+ else:
+ sample = img_f
+
+ model = KMeans(
+ n_clusters=centriods, n_init="auto", init="k-means++", random_state=1337
+ )
+ model.fit(sample)
+ labels = model.predict(img_f)
+ palette = model.cluster_centers_.astype(np.uint8)
+ color_count = np.bincount(labels, minlength=centriods)
+ total_pixels = np.sum(color_count)
+ if total_pixels > 0:
+ color_frequency = color_count / float(total_pixels)
+ else:
+ color_frequency = np.zeros(centriods)
+
+ colors = []
+
+ for i in range(len(palette)):
+ colors.append((palette[i], color_frequency[i]))
+ lut = np.zeros((256, 1, c), dtype=np.uint8)
+ for i, color in enumerate(palette):
+ lut[i, 0, :] = color
+ index_map = labels.reshape((h, w)).astype(np.uint8)
+
+ return (PrincipalColorResult(colors, lut, index_map), None)
+ return process
+
+
+def median_cut_extractor(n_centroids):
+ @validate_input(target=np.ndarray)
+ def process(img: np.ndarray):
+ if len(img.shape) == 4:
+ img = img.squeeze(0)
+ img_f = to_float64(img)
+ h, w, c = img_f.shape
+ img_f = img_f.reshape((h * w, -1)) # h*w, c
+ boxes = [img_f]
+ while len(boxes) < n_centroids:
+ largest_box_idx = np.argmax([box.shape[0] for box in boxes])
+ c = boxes.pop(largest_box_idx) # h*w, c
+ max_vals = np.max(c, axis=0)
+ min_vals = np.min(c, axis=0)
+ ranges = max_vals - min_vals
+ max_chan = np.argmax(ranges) # (1, )
+ median_val = np.median(c[:, max_chan])
+ # no broadcast like op, c[:, max_chan] > median val results in row-wise
+ # boolean mask [T, F, ...] of shape (h*w)
+ l = c[c[:, max_chan] <= median_val]
+ r = c[c[:, max_chan] > median_val]
+ boxes.append(l)
+ boxes.append(r)
+ ret = []
+ for box in boxes:
+ box: np.ndarray
+ box_mean = np.rint(box.mean(axis=0)).astype(np.uint8)
+ ret.append([box_mean, len(box)])
+
+ centroids = np.asarray([x[0] for x in ret])
+ lut_size = 256
+ r, g, b = np.meshgrid(np.arange(lut_size), np.arange(lut_size), np.arange(lut_size))
+ color_grid = np.stack([r, g, b], axis=-1).reshape(-1, 3)
+ tree = cKDTree(centroids)
+ _, indices = tree.query(color_grid)
+ lut_flat = centroids[indices]
+ lut = lut_flat.reshape((lut_size, lut_size, lut_size, 3))
+
+ return (PrincipalColorResult(ret, lut), None)
+ return process
+
+
+def octree_extractor(palette_size: int):
+ @validate_input(target=np.ndarray)
+ def process(img: np.ndarray):
+ """
+ rgb to 8bit binary
+ R 01000000
+ G 10000000
+ B 00000000
+ 1st bit 010 -> 2 -> first layer 3rd node?
+ 2nd bit 100 -> 4 -> 5th node of first layer 3rd node
+ """
+ pixels_list = img.reshape(-1, 3) # into (H*W, C)
+ unique_colors, inverse, counts = np.unique(
+ pixels_list, axis=0, return_counts=True, return_inverse=True
+ )
+ ot = Octree(palette_size)
+ ot.add_colors(unique_colors, counts)
+ ret = ot.get_primary_colors()
+ lut = ot.get_lut(ret)
+ unique_indices= np.array([
+ ot.get_palette_index(c) for c in unique_colors
+ ])
+ index_map_f = unique_indices[inverse] # broadcast
+ index_map = index_map_f.reshape(img.shape[:2]).astype(np.uint8)
+
+ return(PrincipalColorResult(ret, lut, index_map), None)
+ return process
+
+
+def pca_extractor(components: int):
+ @validate_input(target=np.ndarray)
+ def process(img: np.ndarray):
+ if len(img.shape) == 4:
+ img = img.squeeze(0)
+
+ img_f = to_float64(img)
+ h, w, c = img_f.shape
+ img_f = img_f.reshape((h * w, c))
+
+ model = KMeans(
+ n_clusters=centriods, n_init="auto", init="k-means++", random_state=1337
+ )
+ labels = model.fit_predict(img_f)
+ palette = np.array(model.cluster_centers_, dtype=int)
+ color_count = np.bincount(labels)
+ color_frequency = color_count / float(np.sum(color_count))
+
+ colors = []
+
+ for color, freq in zip(palette, color_frequency):
+ colors.append((color, freq))
+ if c == 3:
+ arr = np.arange(256).reshape(-1, 1).repeat(3, axis=1)
+ else:
+ arr = np.arange(256)
+
+ y = model.predict(arr)
+ lut = palette[y].astype(np.uint8)
+ lut = lut.reshape((1, 256, c))
+ return
+ return process