Skip to content

Commit

Permalink
feat(ui,api): add guidance as a default setting option for FLUX models
Browse files Browse the repository at this point in the history
  • Loading branch information
Mary Hipp authored and maryhipp committed Sep 30, 2024
1 parent ca55ef1 commit c224971
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 5 deletions.
1 change: 1 addition & 0 deletions invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ class MainModelDefaultSettings(BaseModel):
)
width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model")

model_config = ConfigDict(extra="forbid")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaS
import {
setCfgRescaleMultiplier,
setCfgScale,
setGuidance,
setScheduler,
setSteps,
vaePrecisionChanged,
Expand All @@ -13,6 +14,7 @@ import { setDefaultSettings } from 'features/parameters/store/actions';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterGuidance,
isParameterHeight,
isParameterPrecision,
isParameterScheduler,
Expand Down Expand Up @@ -49,7 +51,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}

if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler, width, height } =
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler, width, height, guidance } =
modelConfig.default_settings;

if (vae) {
Expand All @@ -73,6 +75,12 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
}

if (guidance) {
if (isParameterGuidance(guidance)) {
dispatch(setGuidance(guidance));
}
}

if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { MainModelConfig } from 'services/api/types';

const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
const { guidance: fluxGuidance } = config.flux;

return {
initialSteps: steps.initial,
Expand All @@ -16,6 +17,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
initialVaePrecision: vaePrecision,
initialWidth: width.initial,
initialHeight: height.initial,
initialGuidance: fluxGuidance.initial,
};
});

Expand All @@ -28,6 +30,7 @@ export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
initialVaePrecision,
initialWidth,
initialHeight,
initialGuidance,
} = useAppSelector(initialStatesSelector);

const defaultSettingsDefaults = useMemo(() => {
Expand Down Expand Up @@ -64,6 +67,10 @@ export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
isEnabled: !isNil(modelConfig?.default_settings?.height),
value: modelConfig?.default_settings?.height || initialHeight,
},
guidance: {
isEnabled: !isNil(modelConfig?.default_settings?.guidance),
value: modelConfig?.default_settings?.guidance || initialGuidance,
},
};
}, [
modelConfig,
Expand All @@ -74,6 +81,7 @@ export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
initialCfgRescaleMultiplier,
initialWidth,
initialHeight,
initialGuidance,
]);

return defaultSettingsDefaults;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { selectGuidanceConfig } from 'features/system/store/configSlice';
import { memo, useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';

type DefaultGuidanceType = MainModelDefaultSettingsFormData['guidance'];

export const DefaultGuidance = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
const { field } = useController(props);

const config = useAppSelector(selectGuidanceConfig);
const { t } = useTranslation();
const marks = useMemo(
() => [
config.sliderMin,
Math.floor(config.sliderMax - (config.sliderMax - config.sliderMin) / 2),
config.sliderMax,
],
[config.sliderMax, config.sliderMin]
);

const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultGuidanceType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);

const value = useMemo(() => {
return (field.value as DefaultGuidanceType).value;
}, [field.value]);

const isDisabled = useMemo(() => {
return !(field.value as DefaultGuidanceType).isEnabled;
}, [field.value]);

return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<InformationalPopover feature="paramGuidance">
<FormLabel>{t('parameters.guidance')}</FormLabel>
</InformationalPopover>
<SettingToggle control={props.control} name="guidance" />
</Flex>

<Flex w="full" gap={4}>
<CompositeSlider
value={value}
min={config.sliderMin}
max={config.sliderMax}
step={config.coarseStep}
fineStep={config.fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={config.numberInputMin}
max={config.numberInputMax}
step={config.coarseStep}
fineStep={config.fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
});

DefaultGuidance.displayName = 'DefaultGuidance';
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import type { MainModelConfig } from 'services/api/types';

import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale';
import { DefaultGuidance } from './DefaultGuidance';
import { DefaultScheduler } from './DefaultScheduler';
import { DefaultSteps } from './DefaultSteps';
import { DefaultVae } from './DefaultVae';
Expand All @@ -36,6 +37,7 @@ export type MainModelDefaultSettingsFormData = {
cfgRescaleMultiplier: FormField<number>;
width: FormField<number>;
height: FormField<number>;
guidance: FormField<number>;
};

type Props = {
Expand All @@ -46,6 +48,10 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
const selectedModelKey = useAppSelector(selectSelectedModelKey);
const { t } = useTranslation();

const isFlux = useMemo(() => {
return modelConfig.base === 'flux';
}, [modelConfig]);

const defaultSettingsDefaults = useMainModelDefaultSettings(modelConfig);
const optimalDimension = useMemo(() => {
const modelBase = modelConfig?.base;
Expand All @@ -72,6 +78,7 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
width: data.width.isEnabled ? data.width.value : null,
height: data.height.isEnabled ? data.height.value : null,
guidance: data.guidance.isEnabled ? data.guidance.value : null,
};

updateModel({
Expand Down Expand Up @@ -118,11 +125,12 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {

<SimpleGrid columns={2} gap={8}>
<DefaultVae control={control} name="vae" />
<DefaultVaePrecision control={control} name="vaePrecision" />
<DefaultScheduler control={control} name="scheduler" />
{!isFlux && <DefaultVaePrecision control={control} name="vaePrecision" />}
{!isFlux && <DefaultScheduler control={control} name="scheduler" />}
<DefaultSteps control={control} name="steps" />
<DefaultCfgScale control={control} name="cfgScale" />
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
{isFlux && <DefaultGuidance control={control} name="guidance" />}
{!isFlux && <DefaultCfgScale control={control} name="cfgScale" />}
{!isFlux && <DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />}
<DefaultWidth control={control} optimalDimension={optimalDimension} />
<DefaultHeight control={control} optimalDimension={optimalDimension} />
</SimpleGrid>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
// #region Guidance parameter
const zParameterGuidance = z.number().min(1);
export type ParameterGuidance = z.infer<typeof zParameterGuidance>;
export const isParameterGuidance = (val: unknown): val is ParameterGuidance =>
zParameterGuidance.safeParse(val).success;
// #endregion

// #region CFG Rescale Multiplier
Expand Down
5 changes: 5 additions & 0 deletions invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11037,6 +11037,11 @@ export type components = {
* @description Default height for this model
*/
height?: number | null;
/**
* Guidance
* @description Default Guidance for this model
*/
guidance?: number | null;
};
/**
* Main Model
Expand Down

0 comments on commit c224971

Please sign in to comment.