diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 43530efca28..e2739d181a6 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2083,6 +2083,24 @@ "pullBboxIntoLayerError": "Problem Pulling BBox Into Layer", "pullBboxIntoReferenceImageOk": "Bbox Pulled Into ReferenceImage", "pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage", + "addAdjustments": "Add Adjustments", + "removeAdjustments": "Remove Adjustments", + "adjustments": { + "simple": "Simple", + "curves": "Curves", + "heading": "Adjustments", + "expand": "Expand adjustments", + "collapse": "Collapse adjustments", + "brightness": "Brightness", + "contrast": "Contrast", + "saturation": "Saturation", + "temperature": "Temperature", + "tint": "Tint", + "sharpness": "Sharpness", + "finish": "Finish", + "reset": "Reset", + "master": "Master" + }, "regionIsEmpty": "Selected region is empty", "mergeVisible": "Merge Visible", "mergeDown": "Merge Down", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx index ddaefb1073e..13dc30dea20 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx @@ -4,6 +4,7 @@ import { CanvasEntityHeader } from 'features/controlLayers/components/common/Can import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions'; import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage'; import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit'; +import { RasterLayerAdjustmentsPanel } from 'features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel'; import { CanvasEntityStateGate } from 'features/controlLayers/contexts/CanvasEntityStateGate'; import { RasterLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext'; import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; @@ -39,6 +40,7 @@ export const RasterLayer = memo(({ id }: Props) => { + { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const canvasManager = useCanvasManager(); + + const selectHasAdjustments = useMemo(() => { + return createSelector(selectCanvasSlice, (canvas) => Boolean(selectEntity(canvas, entityIdentifier)?.adjustments)); + }, [entityIdentifier]); + + const hasAdjustments = useAppSelector(selectHasAdjustments); + + const selectMode = useMemo(() => { + return createSelector( + selectCanvasSlice, + (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.mode ?? 'simple' + ); + }, [entityIdentifier]); + const mode = useAppSelector(selectMode); + + const selectEnabled = useMemo(() => { + return createSelector( + selectCanvasSlice, + (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.enabled ?? false + ); + }, [entityIdentifier]); + const enabled = useAppSelector(selectEnabled); + + const selectCollapsed = useMemo(() => { + return createSelector( + selectCanvasSlice, + (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.collapsed ?? false + ); + }, [entityIdentifier]); + const collapsed = useAppSelector(selectCollapsed); + + const onToggleEnabled = useCallback(() => { + dispatch(rasterLayerAdjustmentsEnabledToggled({ entityIdentifier })); + }, [dispatch, entityIdentifier]); + + const onReset = useCallback(() => { + // Reset values to defaults but keep adjustments present; preserve enabled/collapsed/mode + dispatch(rasterLayerAdjustmentsReset({ entityIdentifier })); + }, [dispatch, entityIdentifier]); + + const onCancel = useCallback(() => { + // Clear out adjustments entirely + dispatch(rasterLayerAdjustmentsCancel({ entityIdentifier })); + }, [dispatch, entityIdentifier]); + + const onToggleCollapsed = useCallback(() => { + dispatch(rasterLayerAdjustmentsCollapsedToggled({ entityIdentifier })); + }, [dispatch, entityIdentifier]); + + const onClickModeSimple = useCallback( + () => dispatch(rasterLayerAdjustmentsModeChanged({ entityIdentifier, mode: 'simple' })), + [dispatch, entityIdentifier] + ); + + const onClickModeCurves = useCallback( + () => dispatch(rasterLayerAdjustmentsModeChanged({ entityIdentifier, mode: 'curves' })), + [dispatch, entityIdentifier] + ); + + const onFinish = useCallback(async () => { + // Bake current visual into layer pixels, then clear adjustments + const adapter = canvasManager.getAdapter(entityIdentifier); + if (!adapter || adapter.type !== 'raster_layer_adapter') { + return; + } + const rect = adapter.transformer.getRelativeRect(); + try { + await adapter.renderer.rasterize({ rect, replaceObjects: true }); + // Clear adjustments after baking + dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: null })); + } catch { + // no-op; leave state unchanged on failure + } + }, [canvasManager, entityIdentifier, dispatch]); + + // Hide the panel entirely until adjustments are added via context menu + if (!hasAdjustments) { + return null; + } + + return ( + <> + + + } + /> + + Adjustments + + + + + + + } + variant="ghost" + /> + } + variant="ghost" + /> + } + variant="ghost" + /> + + + {!collapsed && mode === 'simple' && } + + {!collapsed && mode === 'curves' && } + + ); +}); + +RasterLayerAdjustmentsPanel.displayName = 'RasterLayerAdjustmentsPanel'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsEditor.tsx new file mode 100644 index 00000000000..9610927e016 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsEditor.tsx @@ -0,0 +1,179 @@ +import { Box, Flex } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { rasterLayerAdjustmentsCurvesUpdated } from 'features/controlLayers/store/canvasSlice'; +import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors'; +import type { ChannelName, ChannelPoints, CurvesAdjustmentsConfig } from 'features/controlLayers/store/types'; +import { memo, useCallback, useEffect, useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; + +import { RasterLayerCurvesAdjustmentsGraph } from './RasterLayerCurvesAdjustmentsGraph'; + +const DEFAULT_POINTS: ChannelPoints = [ + [0, 0], + [255, 255], +]; + +const DEFAULT_CURVES: CurvesAdjustmentsConfig = { + master: DEFAULT_POINTS, + r: DEFAULT_POINTS, + g: DEFAULT_POINTS, + b: DEFAULT_POINTS, +}; + +type ChannelHistograms = Record; + +const calculateHistogramsFromImageData = (imageData: ImageData): ChannelHistograms | null => { + try { + const data = imageData.data; + const len = data.length / 4; + const master = new Array(256).fill(0); + const r = new Array(256).fill(0); + const g = new Array(256).fill(0); + const b = new Array(256).fill(0); + // sample every 4th pixel to lighten work + for (let i = 0; i < len; i += 4) { + const idx = i * 4; + const rv = data[idx] as number; + const gv = data[idx + 1] as number; + const bv = data[idx + 2] as number; + const m = Math.round(0.2126 * rv + 0.7152 * gv + 0.0722 * bv); + if (m >= 0 && m < 256) { + master[m] = (master[m] ?? 0) + 1; + } + if (rv >= 0 && rv < 256) { + r[rv] = (r[rv] ?? 0) + 1; + } + if (gv >= 0 && gv < 256) { + g[gv] = (g[gv] ?? 0) + 1; + } + if (bv >= 0 && bv < 256) { + b[bv] = (b[bv] ?? 0) + 1; + } + } + return { + master, + r, + g, + b, + }; + } catch { + return null; + } +}; + +export const RasterLayerCurvesAdjustmentsEditor = memo(() => { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const adapter = useEntityAdapterContext<'raster_layer'>('raster_layer'); + const { t } = useTranslation(); + const selectCurves = useMemo(() => { + return createSelector( + selectCanvasSlice, + (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.curves ?? DEFAULT_CURVES + ); + }, [entityIdentifier]); + const curves = useAppSelector(selectCurves); + + const selectIsDisabled = useMemo(() => { + return createSelector( + selectCanvasSlice, + (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.enabled !== true + ); + }, [entityIdentifier]); + const isDisabled = useAppSelector(selectIsDisabled); + // The canvas cache for the layer serves as a proxy for when the layer changes and can be used to trigger histo recalc + const canvasCache = useStore(adapter.$canvasCache); + + const [histMaster, setHistMaster] = useState(null); + const [histR, setHistR] = useState(null); + const [histG, setHistG] = useState(null); + const [histB, setHistB] = useState(null); + + const recalcHistogram = useCallback(() => { + try { + const rect = adapter.transformer.getRelativeRect(); + if (rect.width === 0 || rect.height === 0) { + setHistMaster(Array(256).fill(0)); + setHistR(Array(256).fill(0)); + setHistG(Array(256).fill(0)); + setHistB(Array(256).fill(0)); + return; + } + const imageData = adapter.renderer.getImageData({ rect }); + const h = calculateHistogramsFromImageData(imageData); + if (h) { + setHistMaster(h.master); + setHistR(h.r); + setHistG(h.g); + setHistB(h.b); + } + } catch { + // ignore + } + }, [adapter]); + + useEffect(() => { + recalcHistogram(); + }, [canvasCache, recalcHistogram]); + + const onChangePoints = useCallback( + (channel: ChannelName, pts: ChannelPoints) => { + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel, points: pts })); + }, + [dispatch, entityIdentifier] + ); + + // Memoize per-channel change handlers to avoid inline lambdas in JSX + const onChangeMaster = useCallback((pts: ChannelPoints) => onChangePoints('master', pts), [onChangePoints]); + const onChangeR = useCallback((pts: ChannelPoints) => onChangePoints('r', pts), [onChangePoints]); + const onChangeG = useCallback((pts: ChannelPoints) => onChangePoints('g', pts), [onChangePoints]); + const onChangeB = useCallback((pts: ChannelPoints) => onChangePoints('b', pts), [onChangePoints]); + + return ( + + + + + + + + + ); +}); + +RasterLayerCurvesAdjustmentsEditor.displayName = 'RasterLayerCurvesEditor'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsGraph.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsGraph.tsx new file mode 100644 index 00000000000..d8166ef686f --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesAdjustmentsGraph.tsx @@ -0,0 +1,432 @@ +import { Flex, IconButton, Text } from '@invoke-ai/ui-library'; +import type { ChannelName, ChannelPoints } from 'features/controlLayers/store/types'; +import React, { memo, useCallback, useEffect, useRef, useState } from 'react'; +import { PiArrowCounterClockwiseBold } from 'react-icons/pi'; + +const DEFAULT_POINTS: ChannelPoints = [ + [0, 0], + [255, 255], +]; + +const channelColor: Record = { + master: '#888', + r: '#e53e3e', + g: '#38a169', + b: '#3182ce', +}; + +const clamp = (v: number, min: number, max: number) => (v < min ? min : v > max ? max : v); + +const sortPoints = (pts: ChannelPoints) => + [...pts] + .sort((a, b) => { + const xDiff = a[0] - b[0]; + if (xDiff) { + return xDiff; + } + if (a[0] === 0 || a[0] === 255) { + return a[1] - b[1]; + } + return 0; + }) + // Finally, clamp to valid range and round to integers + .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] satisfies [number, number]); + +// Base canvas logical coordinate system (used for aspect ratio & initial sizing) +const CANVAS_WIDTH = 256; +const CANVAS_HEIGHT = 160; +const MARGIN_LEFT = 8; +const MARGIN_RIGHT = 8; +const MARGIN_TOP = 8; +const MARGIN_BOTTOM = 10; + +const CANVAS_STYLE: React.CSSProperties = { + width: '100%', + // Maintain aspect ratio while allowing responsive width. Height is set automatically via aspect-ratio. + aspectRatio: `${CANVAS_WIDTH} / ${CANVAS_HEIGHT}`, + height: 'auto', + touchAction: 'none', + borderRadius: 4, + background: '#111', + display: 'block', +}; + +type CurveGraphProps = { + title: string; + channel: ChannelName; + points: ChannelPoints | undefined; + histogram: number[] | null; + onChange: (pts: ChannelPoints) => void; +}; + +const drawHistogram = ( + c: HTMLCanvasElement, + channel: ChannelName, + histogram: number[] | null, + points: ChannelPoints +) => { + // Use device pixel ratio for crisp rendering on HiDPI displays. + const dpr = window.devicePixelRatio || 1; + const cssWidth = c.clientWidth || CANVAS_WIDTH; // CSS pixels + const cssHeight = (cssWidth * CANVAS_HEIGHT) / CANVAS_WIDTH; // maintain aspect ratio + + // Ensure the backing store matches current display size * dpr (only if changed). + const targetWidth = Math.round(cssWidth * dpr); + const targetHeight = Math.round(cssHeight * dpr); + if (c.width !== targetWidth || c.height !== targetHeight) { + c.width = targetWidth; + c.height = targetHeight; + } + // Guarantee the CSS height stays synced (width is 100%). + if (c.style.height !== `${cssHeight}px`) { + c.style.height = `${cssHeight}px`; + } + + const ctx = c.getContext('2d'); + if (!ctx) { + return; + } + + // Reset transform then scale for dpr so we can draw in CSS pixel coordinates. + ctx.setTransform(1, 0, 0, 1, 0, 0); + ctx.scale(dpr, dpr); + + // Dynamic inner geometry (CSS pixel space) + const innerWidth = cssWidth - MARGIN_LEFT - MARGIN_RIGHT; + const innerHeight = cssHeight - MARGIN_TOP - MARGIN_BOTTOM; + + const valueToCanvasX = (x: number) => MARGIN_LEFT + (clamp(x, 0, 255) / 255) * innerWidth; + const valueToCanvasY = (y: number) => MARGIN_TOP + innerHeight - (clamp(y, 0, 255) / 255) * innerHeight; + + // Clear & background + ctx.clearRect(0, 0, cssWidth, cssHeight); + ctx.fillStyle = '#111'; + ctx.fillRect(0, 0, cssWidth, cssHeight); + + // Grid + ctx.strokeStyle = '#2a2a2a'; + ctx.lineWidth = 1; + for (let i = 0; i <= 4; i++) { + const y = MARGIN_TOP + (i * innerHeight) / 4; + ctx.beginPath(); + ctx.moveTo(MARGIN_LEFT + 0.5, y + 0.5); + ctx.lineTo(MARGIN_LEFT + innerWidth - 0.5, y + 0.5); + ctx.stroke(); + } + for (let i = 0; i <= 4; i++) { + const x = MARGIN_LEFT + (i * innerWidth) / 4; + ctx.beginPath(); + ctx.moveTo(x + 0.5, MARGIN_TOP + 0.5); + ctx.lineTo(x + 0.5, MARGIN_TOP + innerHeight - 0.5); + ctx.stroke(); + } + + // Histogram + if (histogram) { + const logHist = histogram.map((v) => Math.log10((v ?? 0) + 1)); + const max = Math.max(1e-6, ...logHist); + ctx.fillStyle = '#5557'; + + // If there's enough horizontal room, draw each of the 256 bins with exact (possibly fractional) width so they tessellate. + // Otherwise, aggregate multiple bins into per-pixel columns to avoid aliasing. + if (innerWidth >= 256) { + for (let i = 0; i < 256; i++) { + const v = logHist[i] ?? 0; + const h = (v / max) * (innerHeight - 2); + // Exact fractional coordinates for seamless coverage (no gaps as width grows) + const x0 = MARGIN_LEFT + (i / 256) * innerWidth; + const x1 = MARGIN_LEFT + ((i + 1) / 256) * innerWidth; + const w = x1 - x0; + if (w <= 0) { + continue; + } // safety + const y = MARGIN_TOP + innerHeight - h; + ctx.fillRect(x0, y, w, h); + } + } else { + // Aggregate bins per CSS pixel column (similar to previous anti-moire approach) + const columns = Math.max(1, Math.round(innerWidth)); + const binsPerCol = 256 / columns; + for (let col = 0; col < columns; col++) { + const startBin = Math.floor(col * binsPerCol); + const endBin = Math.min(255, Math.floor((col + 1) * binsPerCol - 1)); + let acc = 0; + let count = 0; + for (let b = startBin; b <= endBin; b++) { + acc += logHist[b] ?? 0; + count++; + } + const v = count > 0 ? acc / count : 0; + const h = (v / max) * (innerHeight - 2); + const x = MARGIN_LEFT + col; + const y = MARGIN_TOP + innerHeight - h; + ctx.fillRect(x, y, 1, h); + } + } + } + + // Curve + const pts = sortPoints(points); + ctx.strokeStyle = channelColor[channel]; + ctx.lineWidth = 2; + ctx.beginPath(); + for (let i = 0; i < pts.length; i++) { + const [x, y] = pts[i]!; + const cx = valueToCanvasX(x); + const cy = valueToCanvasY(y); + if (i === 0) { + ctx.moveTo(cx, cy); + } else { + ctx.lineTo(cx, cy); + } + } + ctx.stroke(); + + // Control points + for (let i = 0; i < pts.length; i++) { + const [x, y] = pts[i]!; + const cx = valueToCanvasX(x); + const cy = valueToCanvasY(y); + ctx.fillStyle = '#000'; + ctx.beginPath(); + ctx.arc(cx, cy, 3.5, 0, Math.PI * 2); + ctx.fill(); + ctx.strokeStyle = channelColor[channel]; + ctx.lineWidth = 1.5; + ctx.stroke(); + } +}; + +const getNearestPointIndex = (c: HTMLCanvasElement, points: ChannelPoints, mx: number, my: number) => { + const cssWidth = c.clientWidth || CANVAS_WIDTH; + const cssHeight = c.clientHeight || CANVAS_HEIGHT; + const innerWidth = cssWidth - MARGIN_LEFT - MARGIN_RIGHT; + const innerHeight = cssHeight - MARGIN_TOP - MARGIN_BOTTOM; + const canvasToValueX = (cx: number) => clamp(Math.round(((cx - MARGIN_LEFT) / innerWidth) * 255), 0, 255); + const canvasToValueY = (cy: number) => clamp(Math.round(255 - ((cy - MARGIN_TOP) / innerHeight) * 255), 0, 255); + const xVal = canvasToValueX(mx); + const yVal = canvasToValueY(my); + let best = -1; + let bestDist = 9999; + for (let i = 0; i < points.length; i++) { + const [px, py] = points[i]!; + const dx = px - xVal; + const dy = py - yVal; + const d = dx * dx + dy * dy; + if (d < bestDist) { + best = i; + bestDist = d; + } + } + if (best !== -1 && bestDist <= 20 * 20) { + return best; + } + return -1; +}; + +const canvasXToValueX = (c: HTMLCanvasElement, cx: number): number => { + const cssWidth = c.clientWidth || CANVAS_WIDTH; + const innerWidth = cssWidth - MARGIN_LEFT - MARGIN_RIGHT; + return clamp(Math.round(((cx - MARGIN_LEFT) / innerWidth) * 255), 0, 255); +}; + +const canvasYToValueY = (c: HTMLCanvasElement, cy: number) => { + const cssHeight = c.clientHeight || CANVAS_HEIGHT; + const innerHeight = cssHeight - MARGIN_TOP - MARGIN_BOTTOM; + return clamp(Math.round(255 - ((cy - MARGIN_TOP) / innerHeight) * 255), 0, 255); +}; + +export const RasterLayerCurvesAdjustmentsGraph = memo((props: CurveGraphProps) => { + const { title, channel, points, histogram, onChange } = props; + const canvasRef = useRef(null); + const [localPoints, setLocalPoints] = useState(sortPoints(points ?? DEFAULT_POINTS)); + const [dragIndex, setDragIndex] = useState(null); + + useEffect(() => { + setLocalPoints(sortPoints(points ?? DEFAULT_POINTS)); + }, [points]); + + useEffect(() => { + const c = canvasRef.current; + if (!c) { + return; + } + drawHistogram(c, channel, histogram, localPoints); + }, [channel, histogram, localPoints]); + + const handlePointerDown = useCallback( + (e: React.PointerEvent) => { + e.preventDefault(); + e.stopPropagation(); + const c = canvasRef.current; + if (!c) { + return; + } + // Capture the pointer so we still get pointerup even if released outside the canvas. + try { + c.setPointerCapture(e.pointerId); + } catch { + /* ignore */ + } + const rect = c.getBoundingClientRect(); + const mx = e.clientX - rect.left; // CSS pixel coordinates + const my = e.clientY - rect.top; + const idx = getNearestPointIndex(c, localPoints, mx, my); + if (idx !== -1 && idx !== 0 && idx !== localPoints.length - 1) { + setDragIndex(idx); + return; + } + const xVal = canvasXToValueX(c, mx); + const yVal = canvasYToValueY(c, my); + const next = sortPoints([...localPoints, [xVal, yVal]]); + setLocalPoints(next); + setDragIndex(next.findIndex(([x, y]) => x === xVal && y === yVal)); + }, + [localPoints] + ); + + const handlePointerMove = useCallback( + (e: React.PointerEvent) => { + e.preventDefault(); + e.stopPropagation(); + if (dragIndex === null) { + return; + } + const c = canvasRef.current; + if (!c) { + return; + } + const rect = c.getBoundingClientRect(); + const mx = e.clientX - rect.left; + const my = e.clientY - rect.top; + const mxVal = canvasXToValueX(c, mx); + const myVal = canvasYToValueY(c, my); + setLocalPoints((prev) => { + // Endpoints are immutable; safety check. + if (dragIndex === 0 || dragIndex === prev.length - 1) { + return prev; + } + const leftX = prev[dragIndex - 1]![0]; + const rightX = prev[dragIndex + 1]![0]; + // Constrain to strictly between neighbors so ordering is preserved & no crossing. + const minX = Math.min(254, leftX); + const maxX = Math.max(1, rightX); + const clampedX = clamp(mxVal, minX, maxX); + // If neighbors are adjacent (minX > maxX after adjustments), effectively lock X. + const finalX = minX > maxX ? leftX + 1 - 1 /* keep existing */ : clampedX; + const next = [...prev]; + next[dragIndex] = [finalX, myVal]; + return next; // already ordered due to constraints + }); + }, + [dragIndex] + ); + + const commit = useCallback( + (pts: ChannelPoints) => { + onChange(sortPoints(pts)); + }, + [onChange] + ); + + const handlePointerUp = useCallback( + (e: React.PointerEvent) => { + e.preventDefault(); + e.stopPropagation(); + const c = canvasRef.current; + if (c) { + try { + c.releasePointerCapture(e.pointerId); + } catch { + /* ignore */ + } + } + setDragIndex(null); + commit(localPoints); + }, + [commit, localPoints] + ); + + const handlePointerCancel = useCallback( + (e: React.PointerEvent) => { + const c = canvasRef.current; + if (c) { + try { + c.releasePointerCapture(e.pointerId); + } catch { + /* ignore */ + } + } + setDragIndex(null); + commit(localPoints); + }, + [commit, localPoints] + ); + + const handleDoubleClick = useCallback( + (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + const c = canvasRef.current; + if (!c) { + return; + } + const rect = c.getBoundingClientRect(); + const mx = e.clientX - rect.left; + const my = e.clientY - rect.top; + const idx = getNearestPointIndex(c, localPoints, mx, my); + if (idx > 0 && idx < localPoints.length - 1) { + const next = localPoints.filter((_, i) => i !== idx); + setLocalPoints(next); + commit(next); + } + }, + [commit, localPoints] + ); + + // Observe size changes to redraw (responsive behavior) + useEffect(() => { + const c = canvasRef.current; + if (!c) { + return; + } + const ro = new ResizeObserver(() => { + drawHistogram(c, channel, histogram, localPoints); + }); + ro.observe(c); + return () => ro.disconnect(); + }, [channel, histogram, localPoints]); + + const resetPoints = useCallback(() => { + setLocalPoints(sortPoints(DEFAULT_POINTS)); + commit(DEFAULT_POINTS); + }, [commit]); + + return ( + + + + {title} + + } + aria-label="Reset" + size="sm" + variant="link" + onClick={resetPoints} + /> + + + + ); +}); + +RasterLayerCurvesAdjustmentsGraph.displayName = 'RasterLayerCurvesAdjustmentsGraph'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx index 65a16a7b4f9..708f7f29cd6 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx @@ -9,6 +9,7 @@ import { CanvasEntityMenuItemsMergeDown } from 'features/controlLayers/component import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave'; import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject'; import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform'; +import { RasterLayerMenuItemsAdjustments } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments'; import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu'; import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu'; import { memo } from 'react'; @@ -21,10 +22,10 @@ export const RasterLayerMenuItems = memo(() => { - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx new file mode 100644 index 00000000000..86fac78cb3e --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx @@ -0,0 +1,39 @@ +import { MenuItem } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { rasterLayerAdjustmentsCancel, rasterLayerAdjustmentsSet } from 'features/controlLayers/store/canvasSlice'; +import type { CanvasRasterLayerState } from 'features/controlLayers/store/types'; +import { makeDefaultRasterLayerAdjustments } from 'features/controlLayers/store/util'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiSlidersHorizontalBold } from 'react-icons/pi'; + +export const RasterLayerMenuItemsAdjustments = memo(() => { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const { t } = useTranslation(); + const layer = useAppSelector((s) => + s.canvas.present.rasterLayers.entities.find((e: CanvasRasterLayerState) => e.id === entityIdentifier.id) + ); + const hasAdjustments = Boolean(layer?.adjustments); + const onToggleAdjustmentsPresence = useCallback(() => { + if (hasAdjustments) { + dispatch(rasterLayerAdjustmentsCancel({ entityIdentifier })); + } else { + dispatch( + rasterLayerAdjustmentsSet({ + entityIdentifier, + adjustments: makeDefaultRasterLayerAdjustments('simple'), + }) + ); + } + }, [dispatch, entityIdentifier, hasAdjustments]); + + return ( + }> + {hasAdjustments ? t('controlLayers.removeAdjustments') : t('controlLayers.addAdjustments')} + + ); +}); + +RasterLayerMenuItemsAdjustments.displayName = 'RasterLayerMenuItemsAdjustments'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerSimpleAdjustmentsEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerSimpleAdjustmentsEditor.tsx new file mode 100644 index 00000000000..42c45e1c36d --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerSimpleAdjustmentsEditor.tsx @@ -0,0 +1,118 @@ +import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { rasterLayerAdjustmentsSimpleUpdated } from 'features/controlLayers/store/canvasSlice'; +import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors'; +import type { SimpleAdjustmentsConfig } from 'features/controlLayers/store/types'; +import React, { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; + +type AdjustmentSliderRowProps = { + label: string; + name: keyof SimpleAdjustmentsConfig; + onChange: (v: number) => void; + min?: number; + max?: number; + step?: number; +}; + +const AdjustmentSliderRow = ({ label, name, onChange, min = -1, max = 1, step = 0.01 }: AdjustmentSliderRowProps) => { + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const selectValue = useMemo(() => { + return createSelector( + selectCanvasSlice, + (canvas) => + selectEntity(canvas, entityIdentifier)?.adjustments?.simple?.[name] ?? DEFAULT_SIMPLE_ADJUSTMENTS[name] + ); + }, [entityIdentifier, name]); + const value = useAppSelector(selectValue); + + return ( + + + {label} + + + + + ); +}; + +const DEFAULT_SIMPLE_ADJUSTMENTS = { + brightness: 0, + contrast: 0, + saturation: 0, + temperature: 0, + tint: 0, + sharpness: 0, +}; + +export const RasterLayerSimpleAdjustmentsEditor = memo(() => { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const { t } = useTranslation(); + const selectIsDisabled = useMemo(() => { + return createSelector( + selectCanvasSlice, + (canvas) => selectEntity(canvas, entityIdentifier)?.adjustments?.enabled !== true + ); + }, [entityIdentifier]); + const isDisabled = useAppSelector(selectIsDisabled); + + const onBrightness = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { brightness: v } })), + [dispatch, entityIdentifier] + ); + const onContrast = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { contrast: v } })), + [dispatch, entityIdentifier] + ); + const onSaturation = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { saturation: v } })), + [dispatch, entityIdentifier] + ); + const onTemperature = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { temperature: v } })), + [dispatch, entityIdentifier] + ); + const onTint = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { tint: v } })), + [dispatch, entityIdentifier] + ); + const onSharpness = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { sharpness: v } })), + [dispatch, entityIdentifier] + ); + + return ( + + + + + + + + + ); +}); + +RasterLayerSimpleAdjustmentsEditor.displayName = 'RasterLayerSimpleAdjustmentsEditor'; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts index 6c55e949377..2b45f61b291 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts @@ -475,7 +475,7 @@ export abstract class CanvasEntityAdapterBase { this.log.trace({ rect }, 'Getting canvas'); // The opacity may have been changed in response to user selecting a different entity category, so we must restore // the original opacity before rendering the canvas - const attrs: GroupConfig = { opacity: this.state.opacity, filters: [] }; + const attrs: GroupConfig = { opacity: this.state.opacity }; const canvas = this.renderer.getCanvas({ rect, attrs }); return canvas; }; @@ -74,4 +80,79 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase< const keysToOmit: (keyof CanvasRasterLayerState)[] = ['name', 'isLocked']; return omit(this.state, keysToOmit); }; + + private syncAdjustmentsFilter = () => { + const a = this.state.adjustments; + const apply = !!a && a.enabled; + // The filter operates on the renderer's object group; we can set filters at the group level via renderer + const group = this.renderer.konva.objectGroup; + if (apply) { + const filters = group.filters() ?? []; + let nextFilters = filters.filter((f) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter); + if (a.mode === 'simple') { + group.setAttr('adjustmentsSimple', a.simple); + group.setAttr('adjustmentsCurves', null); + nextFilters = [...nextFilters, AdjustmentsSimpleFilter]; + } else { + // Build LUTs and set curves attr + const master = buildCurveLUT(a.curves.master); + const r = buildCurveLUT(a.curves.r); + const g = buildCurveLUT(a.curves.g); + const b = buildCurveLUT(a.curves.b); + group.setAttr('adjustmentsCurves', { master, r, g, b }); + group.setAttr('adjustmentsSimple', null); + nextFilters = [...nextFilters, AdjustmentsCurvesFilter]; + } + group.filters(nextFilters); + this._throttledCacheRefresh(); + } else { + // Remove our filter if present + const filters = (group.filters() ?? []).filter( + (f) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter + ); + group.filters(filters); + group.setAttr('adjustmentsSimple', null); + group.setAttr('adjustmentsCurves', null); + this._throttledCacheRefresh(); + } + }; + + private _throttledCacheRefresh = throttle(() => this.renderer.syncKonvaCache(true), 50); + + private haveAdjustmentsChanged = (prevState: CanvasRasterLayerState, currState: CanvasRasterLayerState): boolean => { + const pa = prevState.adjustments; + const ca = currState.adjustments; + if (pa === ca) { + return false; + } + if (!pa || !ca) { + return true; + } + if (pa.enabled !== ca.enabled) { + return true; + } + if (pa.mode !== ca.mode) { + return true; + } + // simple params + const ps = pa.simple; + const cs = ca.simple; + if ( + ps.brightness !== cs.brightness || + ps.contrast !== cs.contrast || + ps.saturation !== cs.saturation || + ps.temperature !== cs.temperature || + ps.tint !== cs.tint || + ps.sharpness !== cs.sharpness + ) { + return true; + } + // curves reference (UI not implemented yet) - if arrays differ by ref, consider changed + const pc = pa.curves; + const cc = ca.curves; + if (pc !== cc) { + return true; + } + return false; + }; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts index 34c5c9ac5de..044e6c52b18 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts @@ -3,6 +3,10 @@ * https://konvajs.org/docs/filters/Custom_Filter.html */ +import { clamp } from 'es-toolkit/compat'; +import { zCurvesAdjustmentsLUTs, zSimpleAdjustmentsConfig } from 'features/controlLayers/store/types'; +import type Konva from 'konva'; + /** * Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value. * This is useful for edge maps and other masks, to make the black areas transparent. @@ -20,3 +24,177 @@ export const LightnessToAlphaFilter = (imageData: ImageData): void => { imageData.data[i * 4 + 3] = Math.min(a, (cMin + cMax) / 2); } }; + +/** + * Per-layer simple adjustments filter (brightness, contrast, saturation, temp, tint, sharpness). + * + * Parameters are read from the Konva node attr `adjustmentsSimple` set by the adapter. + */ +export const AdjustmentsSimpleFilter = function (this: Konva.Node, imageData: ImageData): void { + const paramsRaw = this.getAttr('adjustmentsSimple'); + const parseResult = zSimpleAdjustmentsConfig.safeParse(paramsRaw); + if (!parseResult.success) { + return; + } + const params = parseResult.data; + + const { brightness, contrast, saturation, temperature, tint, sharpness } = params; + + const data = imageData.data; + const len = data.length / 4; + const width = imageData.width; + const height = imageData.height; + + // Precompute factors + const brightnessShift = brightness * 255; // additive shift + const contrastFactor = 1 + contrast; // scale around 128 + + // Temperature/Tint multipliers + const tempK = 0.5; + const tintK = 0.5; + const rTempMul = 1 + temperature * tempK; + const bTempMul = 1 - temperature * tempK; + // Tint: green <-> magenta. Positive = magenta (R/B up, G down). Negative = green (G up, R/B down). + const t = clamp(tint, -1, 1) * tintK; + const mag = Math.abs(t); + const rTintMul = t >= 0 ? 1 + mag : 1 - mag; + const gTintMul = t >= 0 ? 1 - mag : 1 + mag; + const bTintMul = t >= 0 ? 1 + mag : 1 - mag; + + // Saturation matrix (HSL-based approximation via luma coefficients) + const lumaR = 0.2126; + const lumaG = 0.7152; + const lumaB = 0.0722; + const S = 1 + saturation; // 0..2 + const m00 = lumaR * (1 - S) + S; + const m01 = lumaG * (1 - S); + const m02 = lumaB * (1 - S); + const m10 = lumaR * (1 - S); + const m11 = lumaG * (1 - S) + S; + const m12 = lumaB * (1 - S); + const m20 = lumaR * (1 - S); + const m21 = lumaG * (1 - S); + const m22 = lumaB * (1 - S) + S; + + // First pass: apply per-pixel color adjustments (excluding sharpness) + for (let i = 0; i < len; i++) { + const idx = i * 4; + let r = data[idx + 0] as number; + let g = data[idx + 1] as number; + let b = data[idx + 2] as number; + const a = data[idx + 3] as number; + + // Brightness (additive) + r = r + brightnessShift; + g = g + brightnessShift; + b = b + brightnessShift; + + // Contrast around mid-point 128 + r = (r - 128) * contrastFactor + 128; + g = (g - 128) * contrastFactor + 128; + b = (b - 128) * contrastFactor + 128; + + // Temperature (R/B axis) and Tint (G vs Magenta) + r = r * rTempMul * rTintMul; + g = g * gTintMul; + b = b * bTempMul * bTintMul; + + // Saturation via matrix + const r2 = r * m00 + g * m01 + b * m02; + const g2 = r * m10 + g * m11 + b * m12; + const b2 = r * m20 + g * m21 + b * m22; + + data[idx + 0] = clamp(r2, 0, 255); + data[idx + 1] = clamp(g2, 0, 255); + data[idx + 2] = clamp(b2, 0, 255); + data[idx + 3] = a; + } + + // Optional sharpen (simple unsharp mask with 3x3 kernel) + if (Math.abs(sharpness) > 1e-3 && width > 2 && height > 2) { + const src = new Uint8ClampedArray(data); // copy of modified data + const a = Math.max(-1, Math.min(1, sharpness)) * 0.5; // amount + const center = 1 + 4 * a; + const neighbor = -a; + for (let y = 1; y < height - 1; y++) { + for (let x = 1; x < width - 1; x++) { + const idx = (y * width + x) * 4; + for (let c = 0; c < 3; c++) { + const centerPx = src[idx + c] ?? 0; + const leftPx = src[idx - 4 + c] ?? 0; + const rightPx = src[idx + 4 + c] ?? 0; + const topPx = src[idx - width * 4 + c] ?? 0; + const bottomPx = src[idx + width * 4 + c] ?? 0; + const v = centerPx * center + leftPx * neighbor + rightPx * neighbor + topPx * neighbor + bottomPx * neighbor; + data[idx + c] = clamp(v, 0, 255); + } + // preserve alpha + } + } + } +}; + +// Build a 256-length LUT from 0..255 control points (linear interpolation for v1) +export const buildCurveLUT = (points: Array<[number, number]>): number[] => { + if (!points || points.length === 0) { + return Array.from({ length: 256 }, (_, i) => i); + } + const pts = points + .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] as [number, number]) + .sort((a, b) => a[0] - b[0]); + if ((pts[0]?.[0] ?? 0) !== 0) { + pts.unshift([0, pts[0]?.[1] ?? 0]); + } + const last = pts[pts.length - 1]; + if ((last?.[0] ?? 255) !== 255) { + pts.push([255, last?.[1] ?? 255]); + } + const lut = new Array(256); + let j = 0; + for (let x = 0; x <= 255; x++) { + while (j < pts.length - 2 && x > (pts[j + 1]?.[0] ?? 255)) { + j++; + } + const p0 = pts[j] ?? [0, 0]; + const p1 = pts[j + 1] ?? [255, 255]; + const [x0, y0] = p0; + const [x1, y1] = p1; + const t = x1 === x0 ? 0 : (x - x0) / (x1 - x0); + const y = y0 + (y1 - y0) * t; + lut[x] = clamp(Math.round(y), 0, 255); + } + return lut; +}; + +/** + * Per-layer curves adjustments filter (master, r, g, b) + * + * Parameters are read from the Konva node attr `adjustmentsCurves` set by the adapter. + */ +export const AdjustmentsCurvesFilter = function (this: Konva.Node, imageData: ImageData): void { + const paramsRaw = this.getAttr('adjustmentsCurves'); + const parseResult = zCurvesAdjustmentsLUTs.safeParse(paramsRaw); + if (!parseResult.success) { + return; + } + const params = parseResult.data; + + const { master, r, g, b } = params; + if (!master || !r || !g || !b) { + return; + } + const data = imageData.data; + const len = data.length / 4; + for (let i = 0; i < len; i++) { + const idx = i * 4; + const r0 = data[idx + 0] as number; + const g0 = data[idx + 1] as number; + const b0 = data[idx + 2] as number; + const rm = master[r0] ?? r0; + const gm = master[g0] ?? g0; + const bm = master[b0] ?? b0; + data[idx + 0] = clamp(r[rm] ?? rm, 0, 255); + data[idx + 1] = clamp(g[gm] ?? gm, 0, 255); + data[idx + 2] = clamp(b[bm] ?? bm, 0, 255); + } +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 46f1e620bdd..8c8987d5462 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -19,12 +19,16 @@ import type { CanvasEntityType, CanvasInpaintMaskState, CanvasMetadata, + ChannelName, + ChannelPoints, ControlLoRAConfig, EntityMovedByPayload, FillStyle, FLUXReduxImageInfluence, + RasterLayerAdjustments, RegionalGuidanceRefImageState, RgbColor, + SimpleAdjustmentsConfig, } from 'features/controlLayers/store/types'; import { calculateNewSize, @@ -96,6 +100,7 @@ import { initialFLUXRedux, initialIPAdapter, initialT2IAdapter, + makeDefaultRasterLayerAdjustments, } from './util'; const slice = createSlice({ @@ -104,6 +109,96 @@ const slice = createSlice({ reducers: { // undoable canvas state //#region Raster layers + rasterLayerAdjustmentsSet: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, adjustments } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + if (adjustments === null) { + delete layer.adjustments; + return; + } + if (!layer.adjustments) { + layer.adjustments = makeDefaultRasterLayerAdjustments(adjustments.mode ?? 'simple'); + } + layer.adjustments = merge(layer.adjustments, adjustments); + }, + rasterLayerAdjustmentsReset: (state, action: PayloadAction>) => { + const { entityIdentifier } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer?.adjustments) { + return; + } + layer.adjustments.simple = makeDefaultRasterLayerAdjustments('simple').simple; + layer.adjustments.curves = makeDefaultRasterLayerAdjustments('curves').curves; + }, + rasterLayerAdjustmentsCancel: (state, action: PayloadAction>) => { + const { entityIdentifier } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + delete layer.adjustments; + }, + rasterLayerAdjustmentsModeChanged: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, mode } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer?.adjustments) { + return; + } + layer.adjustments.mode = mode; + }, + rasterLayerAdjustmentsSimpleUpdated: ( + state, + action: PayloadAction }, 'raster_layer'>> + ) => { + const { entityIdentifier, simple } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer?.adjustments) { + return; + } + layer.adjustments.simple = merge(layer.adjustments.simple, simple); + }, + rasterLayerAdjustmentsCurvesUpdated: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier, channel, points } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer?.adjustments) { + return; + } + layer.adjustments.curves[channel] = points; + }, + rasterLayerAdjustmentsEnabledToggled: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer?.adjustments) { + return; + } + layer.adjustments.enabled = !layer.adjustments.enabled; + }, + rasterLayerAdjustmentsCollapsedToggled: ( + state, + action: PayloadAction> + ) => { + const { entityIdentifier } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer?.adjustments) { + return; + } + layer.adjustments.collapsed = !layer.adjustments.collapsed; + }, rasterLayerAdded: { reducer: ( state, @@ -1658,6 +1753,15 @@ export const { entityBrushLineAdded, entityEraserLineAdded, entityRectAdded, + // Raster layer adjustments + rasterLayerAdjustmentsSet, + rasterLayerAdjustmentsCancel, + rasterLayerAdjustmentsReset, + rasterLayerAdjustmentsModeChanged, + rasterLayerAdjustmentsEnabledToggled, + rasterLayerAdjustmentsCollapsedToggled, + rasterLayerAdjustmentsSimpleUpdated, + rasterLayerAdjustmentsCurvesUpdated, entityDeleted, entityArrangedForwardOne, entityArrangedToFront, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index d0a3414a572..d5012bd88c9 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -378,11 +378,57 @@ const zControlLoRAConfig = z.object({ }); export type ControlLoRAConfig = z.infer; +/** + * All simple params normalized to `[-1, 1]` except sharpness `[0, 1]`. + * + * - Brightness: -1 (darken) to 1 (brighten) + * - Contrast: -1 (decrease contrast) to 1 (increase contrast) + * - Saturation: -1 (desaturate) to 1 (saturate) + * - Temperature: -1 (cooler/blue) to 1 (warmer/yellow) + * - Tint: -1 (greener) to 1 (more magenta) + * - Sharpness: 0 (no sharpening) to 1 (maximum sharpening) + */ +export const zSimpleAdjustmentsConfig = z.object({ + brightness: z.number().gte(-1).lte(1), + contrast: z.number().gte(-1).lte(1), + saturation: z.number().gte(-1).lte(1), + temperature: z.number().gte(-1).lte(1), + tint: z.number().gte(-1).lte(1), + sharpness: z.number().gte(0).lte(1), +}); +export type SimpleAdjustmentsConfig = z.infer; + +const zUint8 = z.number().int().min(0).max(255); +const zChannelPoints = z.array(z.tuple([zUint8, zUint8])).min(2); +const zChannelName = z.enum(['master', 'r', 'g', 'b']); +const zCurvesAdjustmentsConfig = z.record(zChannelName, zChannelPoints); +export type ChannelName = z.infer; +export type ChannelPoints = z.infer; +export type CurvesAdjustmentsConfig = z.infer; + +/** + * The curves adjustments are stored as LUTs in the Konva node attributes. Konva will use these values when applying + * the filter. + */ +export const zCurvesAdjustmentsLUTs = z.record(zChannelName, z.array(zUint8)); + +const zRasterLayerAdjustments = z.object({ + version: z.literal(1), + enabled: z.boolean(), + collapsed: z.boolean(), + mode: z.enum(['simple', 'curves']), + simple: zSimpleAdjustmentsConfig, + curves: zCurvesAdjustmentsConfig, +}); +export type RasterLayerAdjustments = z.infer; + const zCanvasRasterLayerState = zCanvasEntityBase.extend({ type: z.literal('raster_layer'), position: zCoordinate, opacity: zOpacity, objects: z.array(zCanvasObjectState), + // Optional per-layer color adjustments (simple + curves). When undefined, no adjustments are applied. + adjustments: zRasterLayerAdjustments.optional(), }); export type CanvasRasterLayerState = z.infer; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/util.ts b/invokeai/frontend/web/src/features/controlLayers/store/util.ts index 88798a048df..cb6e816e320 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/util.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/util.ts @@ -15,6 +15,7 @@ import type { Gemini2_5ReferenceImageConfig, ImageWithDims, IPAdapterConfig, + RasterLayerAdjustments, RefImageState, RgbColor, T2IAdapterConfig, @@ -118,6 +119,32 @@ export const initialControlLoRA: ControlLoRAConfig = { weight: 0.75, }; +export const makeDefaultRasterLayerAdjustments = (mode: 'simple' | 'curves' = 'simple'): RasterLayerAdjustments => ({ + version: 1, + enabled: true, + collapsed: false, + mode, + simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 }, + curves: { + master: [ + [0, 0], + [255, 255], + ], + r: [ + [0, 0], + [255, 255], + ], + g: [ + [0, 0], + [255, 255], + ], + b: [ + [0, 0], + [255, 255], + ], + }, +}); + export const getReferenceImageState = (id: string, overrides?: PartialDeep): RefImageState => { const entityState: RefImageState = { id, @@ -187,6 +214,7 @@ export const getRasterLayerState = ( objects: [], opacity: 1, position: { x: 0, y: 0 }, + adjustments: undefined, }; merge(entityState, overrides); return entityState;