Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions invokeai/frontend/web/.storybook/ReduxInit.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { PropsWithChildren } from 'react';
import { memo, useEffect } from 'react';

import { useAppDispatch } from '../src/app/store/storeHooks';
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
import { modelChanged } from 'features/controlLayers/store/actions';
/**
* Initializes some state for storybook. Must be in a different component
* so that it is run inside the redux context.
Expand All @@ -13,7 +13,9 @@ export const ReduxInit = memo(({ children }: PropsWithChildren) => {
useGlobalModifiersInit();
useEffect(() => {
dispatch(
modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
modelChanged({
model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' },
})
);
}, [dispatch]);

Expand Down
2 changes: 1 addition & 1 deletion invokeai/frontend/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"@invoke-ai/ui-library": "^0.0.47",
"@nanostores/react": "^1.0.0",
"@observ33r/object-equals": "^1.1.5",
"@reduxjs/toolkit": "2.8.2",
"@reduxjs/toolkit": "2.9.0",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.8.2",
"ag-psd": "^28.2.2",
Expand Down
10 changes: 5 additions & 5 deletions invokeai/frontend/web/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import type { Middleware, UnknownAction } from '@reduxjs/toolkit';
import { injectTabActionContext } from 'app/store/util';
import { isCanvasInstanceAction } from 'features/controlLayers/store/canvasSlice';
import { selectActiveCanvasId, selectActiveTab } from 'features/controlLayers/store/selectors';
import { isTabInstanceParamsAction } from 'features/controlLayers/store/tabSlice';

export const actionContextMiddleware: Middleware = (store) => (next) => (action) => {
const currentAction = action as UnknownAction;

if (isTabActionContextRequired(currentAction)) {
const state = store.getState();
const tab = selectActiveTab(state);
const canvasId = tab === 'canvas' ? selectActiveCanvasId(state) : undefined;

injectTabActionContext(currentAction, tab, canvasId);
}

return next(action);
};

const isTabActionContextRequired = (action: UnknownAction) => {
return isTabInstanceParamsAction(action) || isCanvasInstanceAction(action);
};
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import type { AppStartListening } from 'app/store/store';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { selectActiveTabParams, setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';

export const addAppConfigReceivedListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: (action, { getState, dispatch }) => {
effect: (action, api) => {
const { getState, dispatch } = api;
const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
const infillMethod = getState().params.infillMethod;
const infillMethod = selectActiveTabParams(getState()).infillMethod;

if (!infill_methods.includes(infillMethod)) {
// If the selected infill method does not exist, prefer 'lama' if it's in the list, otherwise 'tile'.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { AppStartListening } from 'app/store/store';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { selectCanvases } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
Expand All @@ -19,12 +19,12 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS

const state = getState();
const nodes = selectNodesSlice(state);
const canvas = selectCanvasSlice(state);
const canvases = selectCanvases(state);
const upscale = selectUpscaleSlice(state);
const refImages = selectRefImagesSlice(state);

deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(nodes, canvas, upscale, refImages, image_name);
const imageUsage = getImageUsage(nodes, canvases, upscale, refImages, image_name);

if (imageUsage.isNodesImage && !wasNodeEditorReset) {
dispatch(nodeEditorReset());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/store';
import { modelChanged } from 'features/controlLayers/store/actions';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraIsEnabledChanged } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectActiveCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraIsEnabledChanged, selectAddedLoRAs } from 'features/controlLayers/store/lorasSlice';
import { selectActiveTabParams, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
import {
selectActiveCanvas,
selectAllEntitiesOfType,
selectBboxModelBase,
selectCanvasSlice,
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { modelSelected } from 'features/parameters/store/actions';
Expand All @@ -31,7 +32,8 @@ const log = logger('models');
export const addModelSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => {
effect: (action, api) => {
const { getState, dispatch } = api;
const state = getState();
const result = zParameterModel.safeParse(action.payload);

Expand All @@ -42,22 +44,23 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =

const newModel = result.data;
const newBase = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBase;
const params = selectActiveTabParams(state);
const didBaseModelChange = params.model?.base !== newBase;

if (didBaseModelChange) {
// we may need to reset some incompatible submodels
let modelsUpdatedDisabledOrCleared = 0;

// handle incompatible loras
state.loras.loras.forEach((lora) => {
selectAddedLoRAs(state).forEach((lora) => {
if (lora.model.base !== newBase) {
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: false }));
modelsUpdatedDisabledOrCleared += 1;
}
});

// handle incompatible vae
const { vae } = state.params;
const { vae } = params;
if (vae && vae.base !== newBase) {
dispatch(vaeSelected(null));
modelsUpdatedDisabledOrCleared += 1;
Expand Down Expand Up @@ -118,7 +121,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;

// All regional guidance entities are updated to use the same new model.
const canvasState = selectCanvasSlice(state);
const canvasState = selectActiveCanvas(state);
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
for (const entity of canvasRegionalGuidanceEntities) {
for (const refImage of entity.referenceImages) {
Expand Down Expand Up @@ -152,14 +155,14 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
}

dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
dispatch(modelChanged({ model: newModel, previousModel: params.model }));

const modelBase = selectBboxModelBase(state);

if (modelBase !== state.params.model?.base) {
if (modelBase !== params.model?.base) {
// Sync generate tab settings whenever the model base changes
dispatch(syncedToOptimalDimension());
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
const isStaging = selectActiveCanvasIsStaging(state);
if (!isStaging) {
// Canvas tab only syncs if not staging
dispatch(bboxSyncedToOptimalDimension());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, AppStartListening, RootState } from 'app/store/store';
import { modelChanged } from 'features/controlLayers/store/actions';
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { loraDeleted, selectAddedLoRAs } from 'features/controlLayers/store/lorasSlice';
import {
clipEmbedModelSelected,
fluxVAESelected,
modelChanged,
refinerModelChanged,
selectActiveTabParams,
t5EncoderModelSelected,
vaeSelected,
} from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { selectActiveCanvas } from 'features/controlLayers/store/selectors';
import {
getEntityIdentifier,
isFLUXReduxConfig,
Expand Down Expand Up @@ -103,7 +104,7 @@ type ModelHandler = (
) => undefined;

const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
const selectedMainModel = state.params.model;
const selectedMainModel = selectActiveTabParams(state).model;
const allMainModels = models.filter(isNonRefinerMainModelConfig).sort((a) => (a.base === 'sdxl' ? -1 : 1));

const firstModel = allMainModels[0];
Expand Down Expand Up @@ -144,7 +145,7 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
};

const handleRefinerModels: ModelHandler = (models, state, dispatch, log) => {
const selectedRefinerModel = state.params.refinerModel;
const selectedRefinerModel = selectActiveTabParams(state).refinerModel;

// `null` is a valid refiner model - no need to do anything.
if (selectedRefinerModel === null) {
Expand All @@ -168,7 +169,7 @@ const handleRefinerModels: ModelHandler = (models, state, dispatch, log) => {
};

const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const selectedVAEModel = state.params.vae;
const selectedVAEModel = selectActiveTabParams(state).vae;

// `null` is a valid VAE - it means "use the VAE baked into the currently-selected main model"
if (selectedVAEModel === null) {
Expand All @@ -193,7 +194,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {

const handleLoRAModels: ModelHandler = (models, state, dispatch, log) => {
const loraModels = models.filter(isLoRAModelConfig);
state.loras.loras.forEach((lora) => {
selectAddedLoRAs(state).forEach((lora) => {
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;
Expand Down Expand Up @@ -221,7 +222,7 @@ const handleVideoModels: ModelHandler = (models, state, dispatch, log) => {

const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const caModels = models.filter(isControlLayerModelConfig);
selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
selectActiveCanvas(state).controlLayers.entities.forEach((entity) => {
const selectedControlAdapterModel = entity.controlAdapter.model;
// `null` is a valid control adapter model - no need to do anything.
if (!selectedControlAdapterModel) {
Expand Down Expand Up @@ -256,7 +257,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
});

selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
selectActiveCanvas(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isRegionalGuidanceIPAdapterConfig(config)) {
return;
Expand Down Expand Up @@ -299,7 +300,7 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
});

selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
selectActiveCanvas(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isRegionalGuidanceFLUXReduxConfig(config)) {
return;
Expand Down Expand Up @@ -417,7 +418,7 @@ const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) =
};

const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const selectedT5EncoderModel = state.params.t5EncoderModel;
const selectedT5EncoderModel = selectActiveTabParams(state).t5EncoderModel;
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));

// If the currently selected model is available, we don't need to do anything
Expand Down Expand Up @@ -445,7 +446,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
};

const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
const selectedCLIPEmbedModel = selectActiveTabParams(state).clipEmbedModel;
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));

// If the currently selected model is available, we don't need to do anything
Expand Down Expand Up @@ -473,7 +474,7 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
};

const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => {
const selectedFLUXVAEModel = state.params.fluxVAE;
const selectedFLUXVAEModel = selectActiveTabParams(state).fluxVAE;
const fluxVAEModels = models.filter((m) => isFluxVAEModelConfig(m));

// If the currently selected model is available, we don't need to do anything
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import type { AppStartListening } from 'app/store/store';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectActiveCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
heightChanged,
selectActiveTabParams,
setCfgRescaleMultiplier,
setCfgScale,
setGuidance,
Expand All @@ -13,6 +14,7 @@ import {
vaeSelected,
widthChanged,
} from 'features/controlLayers/store/paramsSlice';
import { selectActiveTab } from 'features/controlLayers/store/selectors';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
isParameterCFGRescaleMultiplier,
Expand All @@ -26,18 +28,18 @@ import {
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { t } from 'i18next';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';

export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: setDefaultSettings,
effect: async (action, { dispatch, getState }) => {
effect: async (action, api) => {
const { dispatch, getState } = api;
const state = getState();

const currentModel = state.params.model;
const currentModel = selectActiveTabParams(state).model;

if (!currentModel) {
return;
Expand Down Expand Up @@ -115,7 +117,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
const setSizeOptions = { updateAspectRatio: true, clamp: true };

const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
const isStaging = selectActiveCanvasIsStaging(state);

const activeTab = selectActiveTab(getState());
if (activeTab === 'generate') {
Expand Down
Loading