diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 43530efca28..e2739d181a6 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -2083,6 +2083,24 @@
"pullBboxIntoLayerError": "Problem Pulling BBox Into Layer",
"pullBboxIntoReferenceImageOk": "Bbox Pulled Into ReferenceImage",
"pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage",
+ "addAdjustments": "Add Adjustments",
+ "removeAdjustments": "Remove Adjustments",
+ "adjustments": {
+ "simple": "Simple",
+ "curves": "Curves",
+ "heading": "Adjustments",
+ "expand": "Expand adjustments",
+ "collapse": "Collapse adjustments",
+ "brightness": "Brightness",
+ "contrast": "Contrast",
+ "saturation": "Saturation",
+ "temperature": "Temperature",
+ "tint": "Tint",
+ "sharpness": "Sharpness",
+ "finish": "Finish",
+ "reset": "Reset",
+ "master": "Master"
+ },
"regionIsEmpty": "Selected region is empty",
"mergeVisible": "Merge Visible",
"mergeDown": "Merge Down",
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx
index ddaefb1073e..13dc30dea20 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx
@@ -4,6 +4,7 @@ import { CanvasEntityHeader } from 'features/controlLayers/components/common/Can
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
+import { RasterLayerAdjustmentsPanel } from 'features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel';
import { CanvasEntityStateGate } from 'features/controlLayers/contexts/CanvasEntityStateGate';
import { RasterLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -39,6 +40,7 @@ export const RasterLayer = memo(({ id }: Props) => {
+
{
+ const { t } = useTranslation();
+ const dispatch = useAppDispatch();
+ const entityIdentifier = useEntityIdentifierContext<'raster_layer'>();
+ const canvasManager = useCanvasManager();
+
+ const selectHasAdjustments = useMemo(() => {
+ return createSelector(selectCanvasSlice, (canvas) => Boolean(selectEntity(canvas, entityIdentifier)?.adjustments));
+ }, [entityIdentifier]);
+
+ const hasAdjustments = useAppSelector(selectHasAdjustments);
+
+ const selectMode = useMemo(() => {
+ return createSelector(
+ selectCanvasSlice,
+ (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.mode ?? 'simple'
+ );
+ }, [entityIdentifier]);
+ const mode = useAppSelector(selectMode);
+
+ const selectEnabled = useMemo(() => {
+ return createSelector(
+ selectCanvasSlice,
+ (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.enabled ?? false
+ );
+ }, [entityIdentifier]);
+ const enabled = useAppSelector(selectEnabled);
+
+ const selectCollapsed = useMemo(() => {
+ return createSelector(
+ selectCanvasSlice,
+ (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.collapsed ?? false
+ );
+ }, [entityIdentifier]);
+ const collapsed = useAppSelector(selectCollapsed);
+
+ const onToggleEnabled = useCallback(() => {
+ dispatch(rasterLayerAdjustmentsEnabledToggled({ entityIdentifier }));
+ }, [dispatch, entityIdentifier]);
+
+ const onReset = useCallback(() => {
+ // Reset values to defaults but keep adjustments present; preserve enabled/collapsed/mode
+ dispatch(rasterLayerAdjustmentsReset({ entityIdentifier }));
+ }, [dispatch, entityIdentifier]);
+
+ const onCancel = useCallback(() => {
+ // Clear out adjustments entirely
+ dispatch(rasterLayerAdjustmentsCancel({ entityIdentifier }));
+ }, [dispatch, entityIdentifier]);
+
+ const onToggleCollapsed = useCallback(() => {
+ dispatch(rasterLayerAdjustmentsCollapsedToggled({ entityIdentifier }));
+ }, [dispatch, entityIdentifier]);
+
+ const onClickModeSimple = useCallback(
+ () => dispatch(rasterLayerAdjustmentsModeChanged({ entityIdentifier, mode: 'simple' })),
+ [dispatch, entityIdentifier]
+ );
+
+ const onClickModeCurves = useCallback(
+ () => dispatch(rasterLayerAdjustmentsModeChanged({ entityIdentifier, mode: 'curves' })),
+ [dispatch, entityIdentifier]
+ );
+
+ const onFinish = useCallback(async () => {
+ // Bake current visual into layer pixels, then clear adjustments
+ const adapter = canvasManager.getAdapter(entityIdentifier);
+ if (!adapter || adapter.type !== 'raster_layer_adapter') {
+ return;
+ }
+ const rect = adapter.transformer.getRelativeRect();
+ try {
+ await adapter.renderer.rasterize({ rect, replaceObjects: true });
+ // Clear adjustments after baking
+ dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: null }));
+ } catch {
+ // no-op; leave state unchanged on failure
+ }
+ }, [canvasManager, entityIdentifier, dispatch]);
+
+ // Hide the panel entirely until adjustments are added via context menu
+ if (!hasAdjustments) {
+ return null;
+ }
+
+ return (
+ <>
+
+
+ }
+ />
+
+ Adjustments
+
+
+
+
+
+
+ }
+ variant="ghost"
+ />
+ }
+ variant="ghost"
+ />
+ }
+ variant="ghost"
+ />
+
+
+ {!collapsed && mode === 'simple' && }
+
+ {!collapsed && mode === 'curves' && }
+ >
+ );
+});
+
+RasterLayerAdjustmentsPanel.displayName = 'RasterLayerAdjustmentsPanel';
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsEditor.tsx
new file mode 100644
index 00000000000..9610927e016
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsEditor.tsx
@@ -0,0 +1,179 @@
+import { Box, Flex } from '@invoke-ai/ui-library';
+import { useStore } from '@nanostores/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext';
+import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
+import { rasterLayerAdjustmentsCurvesUpdated } from 'features/controlLayers/store/canvasSlice';
+import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
+import type { ChannelName, ChannelPoints, CurvesAdjustmentsConfig } from 'features/controlLayers/store/types';
+import { memo, useCallback, useEffect, useMemo, useState } from 'react';
+import { useTranslation } from 'react-i18next';
+
+import { RasterLayerCurvesAdjustmentsGraph } from './RasterLayerCurvesAdjustmentsGraph';
+
+const DEFAULT_POINTS: ChannelPoints = [
+ [0, 0],
+ [255, 255],
+];
+
+const DEFAULT_CURVES: CurvesAdjustmentsConfig = {
+ master: DEFAULT_POINTS,
+ r: DEFAULT_POINTS,
+ g: DEFAULT_POINTS,
+ b: DEFAULT_POINTS,
+};
+
+type ChannelHistograms = Record;
+
+const calculateHistogramsFromImageData = (imageData: ImageData): ChannelHistograms | null => {
+ try {
+ const data = imageData.data;
+ const len = data.length / 4;
+ const master = new Array(256).fill(0);
+ const r = new Array(256).fill(0);
+ const g = new Array(256).fill(0);
+ const b = new Array(256).fill(0);
+ // sample every 4th pixel to lighten work
+ for (let i = 0; i < len; i += 4) {
+ const idx = i * 4;
+ const rv = data[idx] as number;
+ const gv = data[idx + 1] as number;
+ const bv = data[idx + 2] as number;
+ const m = Math.round(0.2126 * rv + 0.7152 * gv + 0.0722 * bv);
+ if (m >= 0 && m < 256) {
+ master[m] = (master[m] ?? 0) + 1;
+ }
+ if (rv >= 0 && rv < 256) {
+ r[rv] = (r[rv] ?? 0) + 1;
+ }
+ if (gv >= 0 && gv < 256) {
+ g[gv] = (g[gv] ?? 0) + 1;
+ }
+ if (bv >= 0 && bv < 256) {
+ b[bv] = (b[bv] ?? 0) + 1;
+ }
+ }
+ return {
+ master,
+ r,
+ g,
+ b,
+ };
+ } catch {
+ return null;
+ }
+};
+
+export const RasterLayerCurvesAdjustmentsEditor = memo(() => {
+ const dispatch = useAppDispatch();
+ const entityIdentifier = useEntityIdentifierContext<'raster_layer'>();
+ const adapter = useEntityAdapterContext<'raster_layer'>('raster_layer');
+ const { t } = useTranslation();
+ const selectCurves = useMemo(() => {
+ return createSelector(
+ selectCanvasSlice,
+ (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.curves ?? DEFAULT_CURVES
+ );
+ }, [entityIdentifier]);
+ const curves = useAppSelector(selectCurves);
+
+ const selectIsDisabled = useMemo(() => {
+ return createSelector(
+ selectCanvasSlice,
+ (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.enabled !== true
+ );
+ }, [entityIdentifier]);
+ const isDisabled = useAppSelector(selectIsDisabled);
+ // The canvas cache for the layer serves as a proxy for when the layer changes and can be used to trigger histo recalc
+ const canvasCache = useStore(adapter.$canvasCache);
+
+ const [histMaster, setHistMaster] = useState(null);
+ const [histR, setHistR] = useState(null);
+ const [histG, setHistG] = useState(null);
+ const [histB, setHistB] = useState(null);
+
+ const recalcHistogram = useCallback(() => {
+ try {
+ const rect = adapter.transformer.getRelativeRect();
+ if (rect.width === 0 || rect.height === 0) {
+ setHistMaster(Array(256).fill(0));
+ setHistR(Array(256).fill(0));
+ setHistG(Array(256).fill(0));
+ setHistB(Array(256).fill(0));
+ return;
+ }
+ const imageData = adapter.renderer.getImageData({ rect });
+ const h = calculateHistogramsFromImageData(imageData);
+ if (h) {
+ setHistMaster(h.master);
+ setHistR(h.r);
+ setHistG(h.g);
+ setHistB(h.b);
+ }
+ } catch {
+ // ignore
+ }
+ }, [adapter]);
+
+ useEffect(() => {
+ recalcHistogram();
+ }, [canvasCache, recalcHistogram]);
+
+ const onChangePoints = useCallback(
+ (channel: ChannelName, pts: ChannelPoints) => {
+ dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel, points: pts }));
+ },
+ [dispatch, entityIdentifier]
+ );
+
+ // Memoize per-channel change handlers to avoid inline lambdas in JSX
+ const onChangeMaster = useCallback((pts: ChannelPoints) => onChangePoints('master', pts), [onChangePoints]);
+ const onChangeR = useCallback((pts: ChannelPoints) => onChangePoints('r', pts), [onChangePoints]);
+ const onChangeG = useCallback((pts: ChannelPoints) => onChangePoints('g', pts), [onChangePoints]);
+ const onChangeB = useCallback((pts: ChannelPoints) => onChangePoints('b', pts), [onChangePoints]);
+
+ return (
+
+
+
+
+
+
+
+
+ );
+});
+
+RasterLayerCurvesAdjustmentsEditor.displayName = 'RasterLayerCurvesEditor';
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsGraph.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsGraph.tsx
new file mode 100644
index 00000000000..d8166ef686f
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsGraph.tsx
@@ -0,0 +1,432 @@
+import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
+import type { ChannelName, ChannelPoints } from 'features/controlLayers/store/types';
+import React, { memo, useCallback, useEffect, useRef, useState } from 'react';
+import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
+
+const DEFAULT_POINTS: ChannelPoints = [
+ [0, 0],
+ [255, 255],
+];
+
+const channelColor: Record = {
+ master: '#888',
+ r: '#e53e3e',
+ g: '#38a169',
+ b: '#3182ce',
+};
+
+const clamp = (v: number, min: number, max: number) => (v < min ? min : v > max ? max : v);
+
+const sortPoints = (pts: ChannelPoints) =>
+ [...pts]
+ .sort((a, b) => {
+ const xDiff = a[0] - b[0];
+ if (xDiff) {
+ return xDiff;
+ }
+ if (a[0] === 0 || a[0] === 255) {
+ return a[1] - b[1];
+ }
+ return 0;
+ })
+ // Finally, clamp to valid range and round to integers
+ .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] satisfies [number, number]);
+
+// Base canvas logical coordinate system (used for aspect ratio & initial sizing)
+const CANVAS_WIDTH = 256;
+const CANVAS_HEIGHT = 160;
+const MARGIN_LEFT = 8;
+const MARGIN_RIGHT = 8;
+const MARGIN_TOP = 8;
+const MARGIN_BOTTOM = 10;
+
+const CANVAS_STYLE: React.CSSProperties = {
+ width: '100%',
+ // Maintain aspect ratio while allowing responsive width. Height is set automatically via aspect-ratio.
+ aspectRatio: `${CANVAS_WIDTH} / ${CANVAS_HEIGHT}`,
+ height: 'auto',
+ touchAction: 'none',
+ borderRadius: 4,
+ background: '#111',
+ display: 'block',
+};
+
+type CurveGraphProps = {
+ title: string;
+ channel: ChannelName;
+ points: ChannelPoints | undefined;
+ histogram: number[] | null;
+ onChange: (pts: ChannelPoints) => void;
+};
+
+const drawHistogram = (
+ c: HTMLCanvasElement,
+ channel: ChannelName,
+ histogram: number[] | null,
+ points: ChannelPoints
+) => {
+ // Use device pixel ratio for crisp rendering on HiDPI displays.
+ const dpr = window.devicePixelRatio || 1;
+ const cssWidth = c.clientWidth || CANVAS_WIDTH; // CSS pixels
+ const cssHeight = (cssWidth * CANVAS_HEIGHT) / CANVAS_WIDTH; // maintain aspect ratio
+
+ // Ensure the backing store matches current display size * dpr (only if changed).
+ const targetWidth = Math.round(cssWidth * dpr);
+ const targetHeight = Math.round(cssHeight * dpr);
+ if (c.width !== targetWidth || c.height !== targetHeight) {
+ c.width = targetWidth;
+ c.height = targetHeight;
+ }
+ // Guarantee the CSS height stays synced (width is 100%).
+ if (c.style.height !== `${cssHeight}px`) {
+ c.style.height = `${cssHeight}px`;
+ }
+
+ const ctx = c.getContext('2d');
+ if (!ctx) {
+ return;
+ }
+
+ // Reset transform then scale for dpr so we can draw in CSS pixel coordinates.
+ ctx.setTransform(1, 0, 0, 1, 0, 0);
+ ctx.scale(dpr, dpr);
+
+ // Dynamic inner geometry (CSS pixel space)
+ const innerWidth = cssWidth - MARGIN_LEFT - MARGIN_RIGHT;
+ const innerHeight = cssHeight - MARGIN_TOP - MARGIN_BOTTOM;
+
+ const valueToCanvasX = (x: number) => MARGIN_LEFT + (clamp(x, 0, 255) / 255) * innerWidth;
+ const valueToCanvasY = (y: number) => MARGIN_TOP + innerHeight - (clamp(y, 0, 255) / 255) * innerHeight;
+
+ // Clear & background
+ ctx.clearRect(0, 0, cssWidth, cssHeight);
+ ctx.fillStyle = '#111';
+ ctx.fillRect(0, 0, cssWidth, cssHeight);
+
+ // Grid
+ ctx.strokeStyle = '#2a2a2a';
+ ctx.lineWidth = 1;
+ for (let i = 0; i <= 4; i++) {
+ const y = MARGIN_TOP + (i * innerHeight) / 4;
+ ctx.beginPath();
+ ctx.moveTo(MARGIN_LEFT + 0.5, y + 0.5);
+ ctx.lineTo(MARGIN_LEFT + innerWidth - 0.5, y + 0.5);
+ ctx.stroke();
+ }
+ for (let i = 0; i <= 4; i++) {
+ const x = MARGIN_LEFT + (i * innerWidth) / 4;
+ ctx.beginPath();
+ ctx.moveTo(x + 0.5, MARGIN_TOP + 0.5);
+ ctx.lineTo(x + 0.5, MARGIN_TOP + innerHeight - 0.5);
+ ctx.stroke();
+ }
+
+ // Histogram
+ if (histogram) {
+ const logHist = histogram.map((v) => Math.log10((v ?? 0) + 1));
+ const max = Math.max(1e-6, ...logHist);
+ ctx.fillStyle = '#5557';
+
+ // If there's enough horizontal room, draw each of the 256 bins with exact (possibly fractional) width so they tessellate.
+ // Otherwise, aggregate multiple bins into per-pixel columns to avoid aliasing.
+ if (innerWidth >= 256) {
+ for (let i = 0; i < 256; i++) {
+ const v = logHist[i] ?? 0;
+ const h = (v / max) * (innerHeight - 2);
+ // Exact fractional coordinates for seamless coverage (no gaps as width grows)
+ const x0 = MARGIN_LEFT + (i / 256) * innerWidth;
+ const x1 = MARGIN_LEFT + ((i + 1) / 256) * innerWidth;
+ const w = x1 - x0;
+ if (w <= 0) {
+ continue;
+ } // safety
+ const y = MARGIN_TOP + innerHeight - h;
+ ctx.fillRect(x0, y, w, h);
+ }
+ } else {
+ // Aggregate bins per CSS pixel column (similar to previous anti-moire approach)
+ const columns = Math.max(1, Math.round(innerWidth));
+ const binsPerCol = 256 / columns;
+ for (let col = 0; col < columns; col++) {
+ const startBin = Math.floor(col * binsPerCol);
+ const endBin = Math.min(255, Math.floor((col + 1) * binsPerCol - 1));
+ let acc = 0;
+ let count = 0;
+ for (let b = startBin; b <= endBin; b++) {
+ acc += logHist[b] ?? 0;
+ count++;
+ }
+ const v = count > 0 ? acc / count : 0;
+ const h = (v / max) * (innerHeight - 2);
+ const x = MARGIN_LEFT + col;
+ const y = MARGIN_TOP + innerHeight - h;
+ ctx.fillRect(x, y, 1, h);
+ }
+ }
+ }
+
+ // Curve
+ const pts = sortPoints(points);
+ ctx.strokeStyle = channelColor[channel];
+ ctx.lineWidth = 2;
+ ctx.beginPath();
+ for (let i = 0; i < pts.length; i++) {
+ const [x, y] = pts[i]!;
+ const cx = valueToCanvasX(x);
+ const cy = valueToCanvasY(y);
+ if (i === 0) {
+ ctx.moveTo(cx, cy);
+ } else {
+ ctx.lineTo(cx, cy);
+ }
+ }
+ ctx.stroke();
+
+ // Control points
+ for (let i = 0; i < pts.length; i++) {
+ const [x, y] = pts[i]!;
+ const cx = valueToCanvasX(x);
+ const cy = valueToCanvasY(y);
+ ctx.fillStyle = '#000';
+ ctx.beginPath();
+ ctx.arc(cx, cy, 3.5, 0, Math.PI * 2);
+ ctx.fill();
+ ctx.strokeStyle = channelColor[channel];
+ ctx.lineWidth = 1.5;
+ ctx.stroke();
+ }
+};
+
+const getNearestPointIndex = (c: HTMLCanvasElement, points: ChannelPoints, mx: number, my: number) => {
+ const cssWidth = c.clientWidth || CANVAS_WIDTH;
+ const cssHeight = c.clientHeight || CANVAS_HEIGHT;
+ const innerWidth = cssWidth - MARGIN_LEFT - MARGIN_RIGHT;
+ const innerHeight = cssHeight - MARGIN_TOP - MARGIN_BOTTOM;
+ const canvasToValueX = (cx: number) => clamp(Math.round(((cx - MARGIN_LEFT) / innerWidth) * 255), 0, 255);
+ const canvasToValueY = (cy: number) => clamp(Math.round(255 - ((cy - MARGIN_TOP) / innerHeight) * 255), 0, 255);
+ const xVal = canvasToValueX(mx);
+ const yVal = canvasToValueY(my);
+ let best = -1;
+ let bestDist = 9999;
+ for (let i = 0; i < points.length; i++) {
+ const [px, py] = points[i]!;
+ const dx = px - xVal;
+ const dy = py - yVal;
+ const d = dx * dx + dy * dy;
+ if (d < bestDist) {
+ best = i;
+ bestDist = d;
+ }
+ }
+ if (best !== -1 && bestDist <= 20 * 20) {
+ return best;
+ }
+ return -1;
+};
+
+const canvasXToValueX = (c: HTMLCanvasElement, cx: number): number => {
+ const cssWidth = c.clientWidth || CANVAS_WIDTH;
+ const innerWidth = cssWidth - MARGIN_LEFT - MARGIN_RIGHT;
+ return clamp(Math.round(((cx - MARGIN_LEFT) / innerWidth) * 255), 0, 255);
+};
+
+const canvasYToValueY = (c: HTMLCanvasElement, cy: number) => {
+ const cssHeight = c.clientHeight || CANVAS_HEIGHT;
+ const innerHeight = cssHeight - MARGIN_TOP - MARGIN_BOTTOM;
+ return clamp(Math.round(255 - ((cy - MARGIN_TOP) / innerHeight) * 255), 0, 255);
+};
+
+export const RasterLayerCurvesAdjustmentsGraph = memo((props: CurveGraphProps) => {
+ const { title, channel, points, histogram, onChange } = props;
+ const canvasRef = useRef(null);
+ const [localPoints, setLocalPoints] = useState(sortPoints(points ?? DEFAULT_POINTS));
+ const [dragIndex, setDragIndex] = useState(null);
+
+ useEffect(() => {
+ setLocalPoints(sortPoints(points ?? DEFAULT_POINTS));
+ }, [points]);
+
+ useEffect(() => {
+ const c = canvasRef.current;
+ if (!c) {
+ return;
+ }
+ drawHistogram(c, channel, histogram, localPoints);
+ }, [channel, histogram, localPoints]);
+
+ const handlePointerDown = useCallback(
+ (e: React.PointerEvent) => {
+ e.preventDefault();
+ e.stopPropagation();
+ const c = canvasRef.current;
+ if (!c) {
+ return;
+ }
+ // Capture the pointer so we still get pointerup even if released outside the canvas.
+ try {
+ c.setPointerCapture(e.pointerId);
+ } catch {
+ /* ignore */
+ }
+ const rect = c.getBoundingClientRect();
+ const mx = e.clientX - rect.left; // CSS pixel coordinates
+ const my = e.clientY - rect.top;
+ const idx = getNearestPointIndex(c, localPoints, mx, my);
+ if (idx !== -1 && idx !== 0 && idx !== localPoints.length - 1) {
+ setDragIndex(idx);
+ return;
+ }
+ const xVal = canvasXToValueX(c, mx);
+ const yVal = canvasYToValueY(c, my);
+ const next = sortPoints([...localPoints, [xVal, yVal]]);
+ setLocalPoints(next);
+ setDragIndex(next.findIndex(([x, y]) => x === xVal && y === yVal));
+ },
+ [localPoints]
+ );
+
+ const handlePointerMove = useCallback(
+ (e: React.PointerEvent) => {
+ e.preventDefault();
+ e.stopPropagation();
+ if (dragIndex === null) {
+ return;
+ }
+ const c = canvasRef.current;
+ if (!c) {
+ return;
+ }
+ const rect = c.getBoundingClientRect();
+ const mx = e.clientX - rect.left;
+ const my = e.clientY - rect.top;
+ const mxVal = canvasXToValueX(c, mx);
+ const myVal = canvasYToValueY(c, my);
+ setLocalPoints((prev) => {
+ // Endpoints are immutable; safety check.
+ if (dragIndex === 0 || dragIndex === prev.length - 1) {
+ return prev;
+ }
+ const leftX = prev[dragIndex - 1]![0];
+ const rightX = prev[dragIndex + 1]![0];
+ // Constrain to strictly between neighbors so ordering is preserved & no crossing.
+ const minX = Math.min(254, leftX);
+ const maxX = Math.max(1, rightX);
+ const clampedX = clamp(mxVal, minX, maxX);
+ // If neighbors are adjacent (minX > maxX after adjustments), effectively lock X.
+ const finalX = minX > maxX ? leftX + 1 - 1 /* keep existing */ : clampedX;
+ const next = [...prev];
+ next[dragIndex] = [finalX, myVal];
+ return next; // already ordered due to constraints
+ });
+ },
+ [dragIndex]
+ );
+
+ const commit = useCallback(
+ (pts: ChannelPoints) => {
+ onChange(sortPoints(pts));
+ },
+ [onChange]
+ );
+
+ const handlePointerUp = useCallback(
+ (e: React.PointerEvent) => {
+ e.preventDefault();
+ e.stopPropagation();
+ const c = canvasRef.current;
+ if (c) {
+ try {
+ c.releasePointerCapture(e.pointerId);
+ } catch {
+ /* ignore */
+ }
+ }
+ setDragIndex(null);
+ commit(localPoints);
+ },
+ [commit, localPoints]
+ );
+
+ const handlePointerCancel = useCallback(
+ (e: React.PointerEvent) => {
+ const c = canvasRef.current;
+ if (c) {
+ try {
+ c.releasePointerCapture(e.pointerId);
+ } catch {
+ /* ignore */
+ }
+ }
+ setDragIndex(null);
+ commit(localPoints);
+ },
+ [commit, localPoints]
+ );
+
+ const handleDoubleClick = useCallback(
+ (e: React.MouseEvent) => {
+ e.preventDefault();
+ e.stopPropagation();
+ const c = canvasRef.current;
+ if (!c) {
+ return;
+ }
+ const rect = c.getBoundingClientRect();
+ const mx = e.clientX - rect.left;
+ const my = e.clientY - rect.top;
+ const idx = getNearestPointIndex(c, localPoints, mx, my);
+ if (idx > 0 && idx < localPoints.length - 1) {
+ const next = localPoints.filter((_, i) => i !== idx);
+ setLocalPoints(next);
+ commit(next);
+ }
+ },
+ [commit, localPoints]
+ );
+
+ // Observe size changes to redraw (responsive behavior)
+ useEffect(() => {
+ const c = canvasRef.current;
+ if (!c) {
+ return;
+ }
+ const ro = new ResizeObserver(() => {
+ drawHistogram(c, channel, histogram, localPoints);
+ });
+ ro.observe(c);
+ return () => ro.disconnect();
+ }, [channel, histogram, localPoints]);
+
+ const resetPoints = useCallback(() => {
+ setLocalPoints(sortPoints(DEFAULT_POINTS));
+ commit(DEFAULT_POINTS);
+ }, [commit]);
+
+ return (
+
+
+
+ {title}
+
+ }
+ aria-label="Reset"
+ size="sm"
+ variant="link"
+ onClick={resetPoints}
+ />
+
+
+
+ );
+});
+
+RasterLayerCurvesAdjustmentsGraph.displayName = 'RasterLayerCurvesAdjustmentsGraph';
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx
index 65a16a7b4f9..708f7f29cd6 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx
@@ -9,6 +9,7 @@ import { CanvasEntityMenuItemsMergeDown } from 'features/controlLayers/component
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
+import { RasterLayerMenuItemsAdjustments } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments';
import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu';
import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu';
import { memo } from 'react';
@@ -21,10 +22,10 @@ export const RasterLayerMenuItems = memo(() => {
-
+
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx
new file mode 100644
index 00000000000..86fac78cb3e
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx
@@ -0,0 +1,39 @@
+import { MenuItem } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
+import { rasterLayerAdjustmentsCancel, rasterLayerAdjustmentsSet } from 'features/controlLayers/store/canvasSlice';
+import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
+import { makeDefaultRasterLayerAdjustments } from 'features/controlLayers/store/util';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiSlidersHorizontalBold } from 'react-icons/pi';
+
+export const RasterLayerMenuItemsAdjustments = memo(() => {
+ const dispatch = useAppDispatch();
+ const entityIdentifier = useEntityIdentifierContext<'raster_layer'>();
+ const { t } = useTranslation();
+ const layer = useAppSelector((s) =>
+ s.canvas.present.rasterLayers.entities.find((e: CanvasRasterLayerState) => e.id === entityIdentifier.id)
+ );
+ const hasAdjustments = Boolean(layer?.adjustments);
+ const onToggleAdjustmentsPresence = useCallback(() => {
+ if (hasAdjustments) {
+ dispatch(rasterLayerAdjustmentsCancel({ entityIdentifier }));
+ } else {
+ dispatch(
+ rasterLayerAdjustmentsSet({
+ entityIdentifier,
+ adjustments: makeDefaultRasterLayerAdjustments('simple'),
+ })
+ );
+ }
+ }, [dispatch, entityIdentifier, hasAdjustments]);
+
+ return (
+ }>
+ {hasAdjustments ? t('controlLayers.removeAdjustments') : t('controlLayers.addAdjustments')}
+
+ );
+});
+
+RasterLayerMenuItemsAdjustments.displayName = 'RasterLayerMenuItemsAdjustments';
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerSimpleAdjustmentsEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerSimpleAdjustmentsEditor.tsx
new file mode 100644
index 00000000000..42c45e1c36d
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerSimpleAdjustmentsEditor.tsx
@@ -0,0 +1,118 @@
+import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
+import { createSelector } from '@reduxjs/toolkit';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
+import { rasterLayerAdjustmentsSimpleUpdated } from 'features/controlLayers/store/canvasSlice';
+import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
+import type { SimpleAdjustmentsConfig } from 'features/controlLayers/store/types';
+import React, { memo, useCallback, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
+
+type AdjustmentSliderRowProps = {
+ label: string;
+ name: keyof SimpleAdjustmentsConfig;
+ onChange: (v: number) => void;
+ min?: number;
+ max?: number;
+ step?: number;
+};
+
+const AdjustmentSliderRow = ({ label, name, onChange, min = -1, max = 1, step = 0.01 }: AdjustmentSliderRowProps) => {
+ const entityIdentifier = useEntityIdentifierContext<'raster_layer'>();
+ const selectValue = useMemo(() => {
+ return createSelector(
+ selectCanvasSlice,
+ (canvas) =>
+ selectEntity(canvas, entityIdentifier)?.adjustments?.simple?.[name] ?? DEFAULT_SIMPLE_ADJUSTMENTS[name]
+ );
+ }, [entityIdentifier, name]);
+ const value = useAppSelector(selectValue);
+
+ return (
+
+
+ {label}
+
+
+
+
+ );
+};
+
+const DEFAULT_SIMPLE_ADJUSTMENTS = {
+ brightness: 0,
+ contrast: 0,
+ saturation: 0,
+ temperature: 0,
+ tint: 0,
+ sharpness: 0,
+};
+
+export const RasterLayerSimpleAdjustmentsEditor = memo(() => {
+ const dispatch = useAppDispatch();
+ const entityIdentifier = useEntityIdentifierContext<'raster_layer'>();
+ const { t } = useTranslation();
+ const selectIsDisabled = useMemo(() => {
+ return createSelector(
+ selectCanvasSlice,
+ (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.enabled !== true
+ );
+ }, [entityIdentifier]);
+ const isDisabled = useAppSelector(selectIsDisabled);
+
+ const onBrightness = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { brightness: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onContrast = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { contrast: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onSaturation = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { saturation: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onTemperature = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { temperature: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onTint = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { tint: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onSharpness = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { sharpness: v } })),
+ [dispatch, entityIdentifier]
+ );
+
+ return (
+
+
+
+
+
+
+
+
+ );
+});
+
+RasterLayerSimpleAdjustmentsEditor.displayName = 'RasterLayerSimpleAdjustmentsEditor';
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts
index 6c55e949377..2b45f61b291 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts
@@ -475,7 +475,7 @@ export abstract class CanvasEntityAdapterBase {
this.log.trace({ rect }, 'Getting canvas');
// The opacity may have been changed in response to user selecting a different entity category, so we must restore
// the original opacity before rendering the canvas
- const attrs: GroupConfig = { opacity: this.state.opacity, filters: [] };
+ const attrs: GroupConfig = { opacity: this.state.opacity };
const canvas = this.renderer.getCanvas({ rect, attrs });
return canvas;
};
@@ -74,4 +80,79 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
const keysToOmit: (keyof CanvasRasterLayerState)[] = ['name', 'isLocked'];
return omit(this.state, keysToOmit);
};
+
+ private syncAdjustmentsFilter = () => {
+ const a = this.state.adjustments;
+ const apply = !!a && a.enabled;
+ // The filter operates on the renderer's object group; we can set filters at the group level via renderer
+ const group = this.renderer.konva.objectGroup;
+ if (apply) {
+ const filters = group.filters() ?? [];
+ let nextFilters = filters.filter((f) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter);
+ if (a.mode === 'simple') {
+ group.setAttr('adjustmentsSimple', a.simple);
+ group.setAttr('adjustmentsCurves', null);
+ nextFilters = [...nextFilters, AdjustmentsSimpleFilter];
+ } else {
+ // Build LUTs and set curves attr
+ const master = buildCurveLUT(a.curves.master);
+ const r = buildCurveLUT(a.curves.r);
+ const g = buildCurveLUT(a.curves.g);
+ const b = buildCurveLUT(a.curves.b);
+ group.setAttr('adjustmentsCurves', { master, r, g, b });
+ group.setAttr('adjustmentsSimple', null);
+ nextFilters = [...nextFilters, AdjustmentsCurvesFilter];
+ }
+ group.filters(nextFilters);
+ this._throttledCacheRefresh();
+ } else {
+ // Remove our filter if present
+ const filters = (group.filters() ?? []).filter(
+ (f) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter
+ );
+ group.filters(filters);
+ group.setAttr('adjustmentsSimple', null);
+ group.setAttr('adjustmentsCurves', null);
+ this._throttledCacheRefresh();
+ }
+ };
+
+ private _throttledCacheRefresh = throttle(() => this.renderer.syncKonvaCache(true), 50);
+
+ private haveAdjustmentsChanged = (prevState: CanvasRasterLayerState, currState: CanvasRasterLayerState): boolean => {
+ const pa = prevState.adjustments;
+ const ca = currState.adjustments;
+ if (pa === ca) {
+ return false;
+ }
+ if (!pa || !ca) {
+ return true;
+ }
+ if (pa.enabled !== ca.enabled) {
+ return true;
+ }
+ if (pa.mode !== ca.mode) {
+ return true;
+ }
+ // simple params
+ const ps = pa.simple;
+ const cs = ca.simple;
+ if (
+ ps.brightness !== cs.brightness ||
+ ps.contrast !== cs.contrast ||
+ ps.saturation !== cs.saturation ||
+ ps.temperature !== cs.temperature ||
+ ps.tint !== cs.tint ||
+ ps.sharpness !== cs.sharpness
+ ) {
+ return true;
+ }
+ // curves reference (UI not implemented yet) - if arrays differ by ref, consider changed
+ const pc = pa.curves;
+ const cc = ca.curves;
+ if (pc !== cc) {
+ return true;
+ }
+ return false;
+ };
}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts
index 34c5c9ac5de..044e6c52b18 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts
@@ -3,6 +3,10 @@
* https://konvajs.org/docs/filters/Custom_Filter.html
*/
+import { clamp } from 'es-toolkit/compat';
+import { zCurvesAdjustmentsLUTs, zSimpleAdjustmentsConfig } from 'features/controlLayers/store/types';
+import type Konva from 'konva';
+
/**
* Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value.
* This is useful for edge maps and other masks, to make the black areas transparent.
@@ -20,3 +24,177 @@ export const LightnessToAlphaFilter = (imageData: ImageData): void => {
imageData.data[i * 4 + 3] = Math.min(a, (cMin + cMax) / 2);
}
};
+
+/**
+ * Per-layer simple adjustments filter (brightness, contrast, saturation, temp, tint, sharpness).
+ *
+ * Parameters are read from the Konva node attr `adjustmentsSimple` set by the adapter.
+ */
+export const AdjustmentsSimpleFilter = function (this: Konva.Node, imageData: ImageData): void {
+ const paramsRaw = this.getAttr('adjustmentsSimple');
+ const parseResult = zSimpleAdjustmentsConfig.safeParse(paramsRaw);
+ if (!parseResult.success) {
+ return;
+ }
+ const params = parseResult.data;
+
+ const { brightness, contrast, saturation, temperature, tint, sharpness } = params;
+
+ const data = imageData.data;
+ const len = data.length / 4;
+ const width = imageData.width;
+ const height = imageData.height;
+
+ // Precompute factors
+ const brightnessShift = brightness * 255; // additive shift
+ const contrastFactor = 1 + contrast; // scale around 128
+
+ // Temperature/Tint multipliers
+ const tempK = 0.5;
+ const tintK = 0.5;
+ const rTempMul = 1 + temperature * tempK;
+ const bTempMul = 1 - temperature * tempK;
+ // Tint: green <-> magenta. Positive = magenta (R/B up, G down). Negative = green (G up, R/B down).
+ const t = clamp(tint, -1, 1) * tintK;
+ const mag = Math.abs(t);
+ const rTintMul = t >= 0 ? 1 + mag : 1 - mag;
+ const gTintMul = t >= 0 ? 1 - mag : 1 + mag;
+ const bTintMul = t >= 0 ? 1 + mag : 1 - mag;
+
+ // Saturation matrix (HSL-based approximation via luma coefficients)
+ const lumaR = 0.2126;
+ const lumaG = 0.7152;
+ const lumaB = 0.0722;
+ const S = 1 + saturation; // 0..2
+ const m00 = lumaR * (1 - S) + S;
+ const m01 = lumaG * (1 - S);
+ const m02 = lumaB * (1 - S);
+ const m10 = lumaR * (1 - S);
+ const m11 = lumaG * (1 - S) + S;
+ const m12 = lumaB * (1 - S);
+ const m20 = lumaR * (1 - S);
+ const m21 = lumaG * (1 - S);
+ const m22 = lumaB * (1 - S) + S;
+
+ // First pass: apply per-pixel color adjustments (excluding sharpness)
+ for (let i = 0; i < len; i++) {
+ const idx = i * 4;
+ let r = data[idx + 0] as number;
+ let g = data[idx + 1] as number;
+ let b = data[idx + 2] as number;
+ const a = data[idx + 3] as number;
+
+ // Brightness (additive)
+ r = r + brightnessShift;
+ g = g + brightnessShift;
+ b = b + brightnessShift;
+
+ // Contrast around mid-point 128
+ r = (r - 128) * contrastFactor + 128;
+ g = (g - 128) * contrastFactor + 128;
+ b = (b - 128) * contrastFactor + 128;
+
+ // Temperature (R/B axis) and Tint (G vs Magenta)
+ r = r * rTempMul * rTintMul;
+ g = g * gTintMul;
+ b = b * bTempMul * bTintMul;
+
+ // Saturation via matrix
+ const r2 = r * m00 + g * m01 + b * m02;
+ const g2 = r * m10 + g * m11 + b * m12;
+ const b2 = r * m20 + g * m21 + b * m22;
+
+ data[idx + 0] = clamp(r2, 0, 255);
+ data[idx + 1] = clamp(g2, 0, 255);
+ data[idx + 2] = clamp(b2, 0, 255);
+ data[idx + 3] = a;
+ }
+
+ // Optional sharpen (simple unsharp mask with 3x3 kernel)
+ if (Math.abs(sharpness) > 1e-3 && width > 2 && height > 2) {
+ const src = new Uint8ClampedArray(data); // copy of modified data
+ const a = Math.max(-1, Math.min(1, sharpness)) * 0.5; // amount
+ const center = 1 + 4 * a;
+ const neighbor = -a;
+ for (let y = 1; y < height - 1; y++) {
+ for (let x = 1; x < width - 1; x++) {
+ const idx = (y * width + x) * 4;
+ for (let c = 0; c < 3; c++) {
+ const centerPx = src[idx + c] ?? 0;
+ const leftPx = src[idx - 4 + c] ?? 0;
+ const rightPx = src[idx + 4 + c] ?? 0;
+ const topPx = src[idx - width * 4 + c] ?? 0;
+ const bottomPx = src[idx + width * 4 + c] ?? 0;
+ const v = centerPx * center + leftPx * neighbor + rightPx * neighbor + topPx * neighbor + bottomPx * neighbor;
+ data[idx + c] = clamp(v, 0, 255);
+ }
+ // preserve alpha
+ }
+ }
+ }
+};
+
+// Build a 256-length LUT from 0..255 control points (linear interpolation for v1)
+export const buildCurveLUT = (points: Array<[number, number]>): number[] => {
+ if (!points || points.length === 0) {
+ return Array.from({ length: 256 }, (_, i) => i);
+ }
+ const pts = points
+ .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] as [number, number])
+ .sort((a, b) => a[0] - b[0]);
+ if ((pts[0]?.[0] ?? 0) !== 0) {
+ pts.unshift([0, pts[0]?.[1] ?? 0]);
+ }
+ const last = pts[pts.length - 1];
+ if ((last?.[0] ?? 255) !== 255) {
+ pts.push([255, last?.[1] ?? 255]);
+ }
+ const lut = new Array(256);
+ let j = 0;
+ for (let x = 0; x <= 255; x++) {
+ while (j < pts.length - 2 && x > (pts[j + 1]?.[0] ?? 255)) {
+ j++;
+ }
+ const p0 = pts[j] ?? [0, 0];
+ const p1 = pts[j + 1] ?? [255, 255];
+ const [x0, y0] = p0;
+ const [x1, y1] = p1;
+ const t = x1 === x0 ? 0 : (x - x0) / (x1 - x0);
+ const y = y0 + (y1 - y0) * t;
+ lut[x] = clamp(Math.round(y), 0, 255);
+ }
+ return lut;
+};
+
+/**
+ * Per-layer curves adjustments filter (master, r, g, b)
+ *
+ * Parameters are read from the Konva node attr `adjustmentsCurves` set by the adapter.
+ */
+export const AdjustmentsCurvesFilter = function (this: Konva.Node, imageData: ImageData): void {
+ const paramsRaw = this.getAttr('adjustmentsCurves');
+ const parseResult = zCurvesAdjustmentsLUTs.safeParse(paramsRaw);
+ if (!parseResult.success) {
+ return;
+ }
+ const params = parseResult.data;
+
+ const { master, r, g, b } = params;
+ if (!master || !r || !g || !b) {
+ return;
+ }
+ const data = imageData.data;
+ const len = data.length / 4;
+ for (let i = 0; i < len; i++) {
+ const idx = i * 4;
+ const r0 = data[idx + 0] as number;
+ const g0 = data[idx + 1] as number;
+ const b0 = data[idx + 2] as number;
+ const rm = master[r0] ?? r0;
+ const gm = master[g0] ?? g0;
+ const bm = master[b0] ?? b0;
+ data[idx + 0] = clamp(r[rm] ?? rm, 0, 255);
+ data[idx + 1] = clamp(g[gm] ?? gm, 0, 255);
+ data[idx + 2] = clamp(b[bm] ?? bm, 0, 255);
+ }
+};
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts
index 46f1e620bdd..8c8987d5462 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts
@@ -19,12 +19,16 @@ import type {
CanvasEntityType,
CanvasInpaintMaskState,
CanvasMetadata,
+ ChannelName,
+ ChannelPoints,
ControlLoRAConfig,
EntityMovedByPayload,
FillStyle,
FLUXReduxImageInfluence,
+ RasterLayerAdjustments,
RegionalGuidanceRefImageState,
RgbColor,
+ SimpleAdjustmentsConfig,
} from 'features/controlLayers/store/types';
import {
calculateNewSize,
@@ -96,6 +100,7 @@ import {
initialFLUXRedux,
initialIPAdapter,
initialT2IAdapter,
+ makeDefaultRasterLayerAdjustments,
} from './util';
const slice = createSlice({
@@ -104,6 +109,96 @@ const slice = createSlice({
reducers: {
// undoable canvas state
//#region Raster layers
+ rasterLayerAdjustmentsSet: (
+ state,
+ action: PayloadAction>
+ ) => {
+ const { entityIdentifier, adjustments } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer) {
+ return;
+ }
+ if (adjustments === null) {
+ delete layer.adjustments;
+ return;
+ }
+ if (!layer.adjustments) {
+ layer.adjustments = makeDefaultRasterLayerAdjustments(adjustments.mode ?? 'simple');
+ }
+ layer.adjustments = merge(layer.adjustments, adjustments);
+ },
+ rasterLayerAdjustmentsReset: (state, action: PayloadAction>) => {
+ const { entityIdentifier } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer?.adjustments) {
+ return;
+ }
+ layer.adjustments.simple = makeDefaultRasterLayerAdjustments('simple').simple;
+ layer.adjustments.curves = makeDefaultRasterLayerAdjustments('curves').curves;
+ },
+ rasterLayerAdjustmentsCancel: (state, action: PayloadAction>) => {
+ const { entityIdentifier } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer) {
+ return;
+ }
+ delete layer.adjustments;
+ },
+ rasterLayerAdjustmentsModeChanged: (
+ state,
+ action: PayloadAction>
+ ) => {
+ const { entityIdentifier, mode } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer?.adjustments) {
+ return;
+ }
+ layer.adjustments.mode = mode;
+ },
+ rasterLayerAdjustmentsSimpleUpdated: (
+ state,
+ action: PayloadAction }, 'raster_layer'>>
+ ) => {
+ const { entityIdentifier, simple } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer?.adjustments) {
+ return;
+ }
+ layer.adjustments.simple = merge(layer.adjustments.simple, simple);
+ },
+ rasterLayerAdjustmentsCurvesUpdated: (
+ state,
+ action: PayloadAction>
+ ) => {
+ const { entityIdentifier, channel, points } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer?.adjustments) {
+ return;
+ }
+ layer.adjustments.curves[channel] = points;
+ },
+ rasterLayerAdjustmentsEnabledToggled: (
+ state,
+ action: PayloadAction>
+ ) => {
+ const { entityIdentifier } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer?.adjustments) {
+ return;
+ }
+ layer.adjustments.enabled = !layer.adjustments.enabled;
+ },
+ rasterLayerAdjustmentsCollapsedToggled: (
+ state,
+ action: PayloadAction>
+ ) => {
+ const { entityIdentifier } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer?.adjustments) {
+ return;
+ }
+ layer.adjustments.collapsed = !layer.adjustments.collapsed;
+ },
rasterLayerAdded: {
reducer: (
state,
@@ -1658,6 +1753,15 @@ export const {
entityBrushLineAdded,
entityEraserLineAdded,
entityRectAdded,
+ // Raster layer adjustments
+ rasterLayerAdjustmentsSet,
+ rasterLayerAdjustmentsCancel,
+ rasterLayerAdjustmentsReset,
+ rasterLayerAdjustmentsModeChanged,
+ rasterLayerAdjustmentsEnabledToggled,
+ rasterLayerAdjustmentsCollapsedToggled,
+ rasterLayerAdjustmentsSimpleUpdated,
+ rasterLayerAdjustmentsCurvesUpdated,
entityDeleted,
entityArrangedForwardOne,
entityArrangedToFront,
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
index d0a3414a572..d5012bd88c9 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
@@ -378,11 +378,57 @@ const zControlLoRAConfig = z.object({
});
export type ControlLoRAConfig = z.infer;
+/**
+ * All simple params normalized to `[-1, 1]` except sharpness `[0, 1]`.
+ *
+ * - Brightness: -1 (darken) to 1 (brighten)
+ * - Contrast: -1 (decrease contrast) to 1 (increase contrast)
+ * - Saturation: -1 (desaturate) to 1 (saturate)
+ * - Temperature: -1 (cooler/blue) to 1 (warmer/yellow)
+ * - Tint: -1 (greener) to 1 (more magenta)
+ * - Sharpness: 0 (no sharpening) to 1 (maximum sharpening)
+ */
+export const zSimpleAdjustmentsConfig = z.object({
+ brightness: z.number().gte(-1).lte(1),
+ contrast: z.number().gte(-1).lte(1),
+ saturation: z.number().gte(-1).lte(1),
+ temperature: z.number().gte(-1).lte(1),
+ tint: z.number().gte(-1).lte(1),
+ sharpness: z.number().gte(0).lte(1),
+});
+export type SimpleAdjustmentsConfig = z.infer;
+
+const zUint8 = z.number().int().min(0).max(255);
+const zChannelPoints = z.array(z.tuple([zUint8, zUint8])).min(2);
+const zChannelName = z.enum(['master', 'r', 'g', 'b']);
+const zCurvesAdjustmentsConfig = z.record(zChannelName, zChannelPoints);
+export type ChannelName = z.infer;
+export type ChannelPoints = z.infer;
+export type CurvesAdjustmentsConfig = z.infer;
+
+/**
+ * The curves adjustments are stored as LUTs in the Konva node attributes. Konva will use these values when applying
+ * the filter.
+ */
+export const zCurvesAdjustmentsLUTs = z.record(zChannelName, z.array(zUint8));
+
+const zRasterLayerAdjustments = z.object({
+ version: z.literal(1),
+ enabled: z.boolean(),
+ collapsed: z.boolean(),
+ mode: z.enum(['simple', 'curves']),
+ simple: zSimpleAdjustmentsConfig,
+ curves: zCurvesAdjustmentsConfig,
+});
+export type RasterLayerAdjustments = z.infer;
+
const zCanvasRasterLayerState = zCanvasEntityBase.extend({
type: z.literal('raster_layer'),
position: zCoordinate,
opacity: zOpacity,
objects: z.array(zCanvasObjectState),
+ // Optional per-layer color adjustments (simple + curves). When undefined, no adjustments are applied.
+ adjustments: zRasterLayerAdjustments.optional(),
});
export type CanvasRasterLayerState = z.infer;
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/util.ts b/invokeai/frontend/web/src/features/controlLayers/store/util.ts
index 88798a048df..cb6e816e320 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/util.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/util.ts
@@ -15,6 +15,7 @@ import type {
Gemini2_5ReferenceImageConfig,
ImageWithDims,
IPAdapterConfig,
+ RasterLayerAdjustments,
RefImageState,
RgbColor,
T2IAdapterConfig,
@@ -118,6 +119,32 @@ export const initialControlLoRA: ControlLoRAConfig = {
weight: 0.75,
};
+export const makeDefaultRasterLayerAdjustments = (mode: 'simple' | 'curves' = 'simple'): RasterLayerAdjustments => ({
+ version: 1,
+ enabled: true,
+ collapsed: false,
+ mode,
+ simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 },
+ curves: {
+ master: [
+ [0, 0],
+ [255, 255],
+ ],
+ r: [
+ [0, 0],
+ [255, 255],
+ ],
+ g: [
+ [0, 0],
+ [255, 255],
+ ],
+ b: [
+ [0, 0],
+ [255, 255],
+ ],
+ },
+});
+
export const getReferenceImageState = (id: string, overrides?: PartialDeep): RefImageState => {
const entityState: RefImageState = {
id,
@@ -187,6 +214,7 @@ export const getRasterLayerState = (
objects: [],
opacity: 1,
position: { x: 0, y: 0 },
+ adjustments: undefined,
};
merge(entityState, overrides);
return entityState;