-
Notifications
You must be signed in to change notification settings - Fork 6.5k
PipelineTesterMixin parameter configuration refactor #2502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PipelineTesterMixin parameter configuration refactor #2502
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
0c1420a to
bf21775
Compare
bf21775 to
e788d12
Compare
| # These are canonical sets of parameters for different types of pipelines. | ||
| # They are set on subclasses of `PipelineTesterMixin` as `params` and | ||
| # `batch_params`. | ||
| # | ||
| # If a pipeline's set of arguments has minor changes from one of the common sets | ||
| # of arguments, do not make modifications to the existing common sets of arguments. | ||
| # I.e. a text to image pipeline with non-configurable height and width arguments | ||
| # should set its attribute as `params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. | ||
|
|
||
| TEXT_TO_IMAGE_PARAMS = frozenset( | ||
| [ | ||
| "prompt", | ||
| "height", | ||
| "width", | ||
| "guidance_scale", | ||
| "negative_prompt", | ||
| "prompt_embeds", | ||
| "negative_prompt_embeds", | ||
| "cross_attention_kwargs", | ||
| ] | ||
| ) | ||
|
|
||
| TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Default param values for families of pipelines
| params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} | ||
| required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} | ||
| batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Example modifying default params
| # Canonical parameters that are passed to `__call__` regardless | ||
| # of the type of pipeline. They are always optional and have common | ||
| # sense default values. | ||
| required_optional_params = frozenset( | ||
| [ | ||
| "num_inference_steps", | ||
| "num_images_per_prompt", | ||
| "eta", | ||
| "generator", | ||
| "latents", | ||
| "output_type", | ||
| "return_dict", | ||
| "callback", | ||
| "callback_steps", | ||
| ] | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new set of required_optional_params
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Love it.
| @property | ||
| def batch_params(self) -> frozenset: | ||
| raise NotImplementedError( | ||
| "You need to set the attribute `batch_params` in the child test class. " | ||
| "`batch_params` are the parameters required to be batched when passed to the pipeline's " | ||
| "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as " | ||
| "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's " | ||
| "set of batch arguments has minor changes from one of the common sets of batch arguments, " | ||
| "do not make modifications to the existing common sets of batch arguments. I.e. a text to " | ||
| "image pipeline `negative_prompt` is not batched should set the attribute as " | ||
| "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. " | ||
| "See existing pipeline tests for reference." | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
batch_params with error message/docs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool!
|
note failing onnx tests seem unrelated |
| "callback", | ||
| "latents", | ||
| "callback_steps", | ||
| "output_type", | ||
| "eta", | ||
| "num_images_per_prompt", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd group them into variable and then use it instead of defining ad-hoc. Applies here and elsewhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually ok for me as is - I'd just move "eta" out of the "required" optional base params. ETA is too specific & we shouldn't incentivize adding it in the future
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks very nice! My only suggestion is to move "eta" out of the required/default list of optional parameters. ETA is only relevant for DDIM and is also generally not used.
Going forward I actually don't think it's a good idea to have a "eta" argument for pipelines
tests/test_pipelines_common.py
Outdated
| [ | ||
| "num_inference_steps", | ||
| "num_images_per_prompt", | ||
| "eta", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| "eta", |
I'd remove "eta" here. "eta" is a very specific case that doesn't apply to all pipelines. We probably shouldn't have added this at all in the first place
| @property | ||
| def batch_params(self) -> frozenset: | ||
| raise NotImplementedError( | ||
| "You need to set the attribute `batch_params` in the child test class. " | ||
| "`batch_params` are the parameters required to be batched when passed to the pipeline's " | ||
| "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as " | ||
| "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's " | ||
| "set of batch arguments has minor changes from one of the common sets of batch arguments, " | ||
| "do not make modifications to the existing common sets of batch arguments. I.e. a text to " | ||
| "image pipeline `negative_prompt` is not batched should set the attribute as " | ||
| "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. " | ||
| "See existing pipeline tests for reference." | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool!
| "callback", | ||
| "latents", | ||
| "callback_steps", | ||
| "output_type", | ||
| "eta", | ||
| "num_images_per_prompt", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually ok for me as is - I'd just move "eta" out of the "required" optional base params. ETA is too specific & we shouldn't incentivize adding it in the future
|
|
||
| UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"]) | ||
|
|
||
| UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Super nice! This is really helpful to classify the pipelines into certain tasks and should help us write new tests going forward
* attend and excite batch test causing timeouts * PipelineTesterMixin argument configuration refactor * error message text re: @yiyixuxu * remove eta re: @patrickvonplaten
* attend and excite batch test causing timeouts * PipelineTesterMixin argument configuration refactor * error message text re: @yiyixuxu * remove eta re: @patrickvonplaten
Currently used test config
allowed_required_argspromptwhich you would think is the canonical required parameter is really technically optional in most pipelines as we allow overriding it withprompt_embeddings.required_optional_paramsThis is a good config but we could add a few more of the commonly included optional parameters
num_inference_steps_argsThis was a bad hack that I added for testing batching on the unclip pipelines with many different inference stages.
Refactor
num_inference_steps_args-> remove and pass as argument to test helper functionrequired_optional_parametersis good but we want to add a few more of the default optional parameters i.e. callback/callback_steps.Really we want to separate
allowed_required_argsinto two parts:paramsbatch_paramsBoth
paramsandbatch_paramsare generally determined by the type of pipeline -- i.e. text to image, image variation. They are configured on the specific pipeline subclass with a good error message explaining why and how to configure them when they are not set. Because we don't have strict interfaces (i.e. not all image variation pipelines take height and width) on pipelines nor families of pipelines,paramsare easily tweaked from provided defaults with operations onfrozenset.