-
Notifications
You must be signed in to change notification settings - Fork 6.5k
PipelineTesterMixin parameter configuration refactor #2502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dde7095
e788d12
86ec209
c50eb66
98df334
470ec3a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # These are canonical sets of parameters for different types of pipelines. | ||
| # They are set on subclasses of `PipelineTesterMixin` as `params` and | ||
| # `batch_params`. | ||
| # | ||
| # If a pipeline's set of arguments has minor changes from one of the common sets | ||
| # of arguments, do not make modifications to the existing common sets of arguments. | ||
| # I.e. a text to image pipeline with non-configurable height and width arguments | ||
| # should set its attribute as `params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. | ||
|
|
||
| TEXT_TO_IMAGE_PARAMS = frozenset( | ||
| [ | ||
| "prompt", | ||
| "height", | ||
| "width", | ||
| "guidance_scale", | ||
| "negative_prompt", | ||
| "prompt_embeds", | ||
| "negative_prompt_embeds", | ||
| "cross_attention_kwargs", | ||
| ] | ||
| ) | ||
|
|
||
| TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"]) | ||
|
|
||
| IMAGE_VARIATION_PARAMS = frozenset( | ||
| [ | ||
| "image", | ||
| "height", | ||
| "width", | ||
| "guidance_scale", | ||
| ] | ||
| ) | ||
|
|
||
| IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"]) | ||
|
|
||
| TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset( | ||
| [ | ||
| "prompt", | ||
| "image", | ||
| "height", | ||
| "width", | ||
| "guidance_scale", | ||
| "negative_prompt", | ||
| "prompt_embeds", | ||
| "negative_prompt_embeds", | ||
| ] | ||
| ) | ||
|
|
||
| TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"]) | ||
|
|
||
| TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset( | ||
| [ | ||
| # Text guided image variation with an image mask | ||
| "prompt", | ||
| "image", | ||
| "mask_image", | ||
| "height", | ||
| "width", | ||
| "guidance_scale", | ||
| "negative_prompt", | ||
| "prompt_embeds", | ||
| "negative_prompt_embeds", | ||
| ] | ||
| ) | ||
|
|
||
| TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"]) | ||
|
|
||
| IMAGE_INPAINTING_PARAMS = frozenset( | ||
| [ | ||
| # image variation with an image mask | ||
| "image", | ||
| "mask_image", | ||
| "height", | ||
| "width", | ||
| "guidance_scale", | ||
| ] | ||
| ) | ||
|
|
||
| IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"]) | ||
|
|
||
| IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset( | ||
| [ | ||
| "example_image", | ||
| "image", | ||
| "mask_image", | ||
| "height", | ||
| "width", | ||
| "guidance_scale", | ||
| ] | ||
| ) | ||
|
|
||
| IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"]) | ||
|
|
||
| CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"]) | ||
|
|
||
| CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"]) | ||
|
|
||
| UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"]) | ||
|
|
||
| UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([]) | ||
|
|
||
| UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"]) | ||
|
|
||
| UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([]) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Super nice! This is really helpful to classify the pipelines into certain tasks and should help us write new tests going forward |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,6 +33,7 @@ | |
| from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device | ||
| from diffusers.utils.testing_utils import require_torch_gpu, skip_mps | ||
|
|
||
| from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS | ||
| from ...test_pipelines_common import PipelineTesterMixin | ||
|
|
||
|
|
||
|
|
@@ -41,6 +42,9 @@ | |
|
|
||
| class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): | ||
| pipeline_class = StableDiffusionImg2ImgPipeline | ||
| params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} | ||
| required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} | ||
| batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS | ||
|
Comment on lines
+45
to
+47
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Example modifying default params |
||
|
|
||
| def get_dummy_components(self): | ||
| torch.manual_seed(0) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Default param values for families of pipelines