
AsymmetricAutoencoderKL: missing generator argument in decode() called from forward()  #8317

@townwish4git

Description


Describe the bug

AsymmetricAutoencoderKL.decode() takes a generator parameter as its second argument:

def decode(
    self,
    z: torch.FloatTensor,
    generator: Optional[torch.Generator] = None,
    image: Optional[torch.FloatTensor] = None,
    mask: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
    ...

However, forward() calls self.decode() with only three positional arguments:

...
dec = self.decode(z, sample, mask).sample
...

As a result, the arguments are bound as follows:

  1. argument z => parameter z;
  2. argument sample => parameter generator;
  3. argument mask => parameter image;
  4. default None => parameter mask
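The binding above can be checked without torch by mirroring the parameter order in a plain function. The `decode` below is a hypothetical stand-in that copies only the signature of AsymmetricAutoencoderKL.decode() (tensors replaced by strings); it is a sketch of Python's positional binding, not the diffusers code.

```python
import inspect
from typing import Any, Optional


def decode(  # hypothetical stand-in with the same parameter order as AsymmetricAutoencoderKL.decode()
    z: Any,
    generator: Optional[Any] = None,
    image: Optional[Any] = None,
    mask: Optional[Any] = None,
    return_dict: bool = True,
) -> dict:
    # Report which value each parameter actually received.
    return {"z": z, "generator": generator, "image": image, "mask": mask}


# The call as written in forward(): three positional arguments.
buggy = decode("z", "sample", "mask")
print(buggy)  # {'z': 'z', 'generator': 'sample', 'image': 'mask', 'mask': None}

# inspect.signature(...).bind() shows the same mapping without calling the function
# (only explicitly passed arguments appear; defaults are not filled in).
bound = inspect.signature(decode).bind("z", "sample", "mask")
print(dict(bound.arguments))  # {'z': 'z', 'generator': 'sample', 'image': 'mask'}
```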

This appears to be a bug; the call should be corrected to:

- dec = self.decode(z, sample, mask).sample
+ dec = self.decode(z, generator, sample, mask).sample
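Passing the optional tensors by keyword is one way to make the call independent of parameter order. A minimal sketch, assuming a toy `ToyAsymmetricVAE` class (hypothetical, not part of diffusers) whose decode() mirrors the signature above and whose forward methods take a generator argument; only the argument plumbing is modeled, with strings standing in for tensors and no real encoding or decoding.

```python
from typing import Any, Optional


class ToyAsymmetricVAE:  # hypothetical class; models only how arguments reach decode()
    def decode(
        self,
        z: Any,
        generator: Optional[Any] = None,
        image: Optional[Any] = None,
        mask: Optional[Any] = None,
        return_dict: bool = True,
    ) -> dict:
        # Report what decode() received instead of decoding anything.
        return {"z": z, "generator": generator, "image": image, "mask": mask}

    def forward_buggy(self, sample: Any, mask: Optional[Any] = None, generator: Optional[Any] = None) -> dict:
        z = "latent"  # placeholder for the encoded (and possibly sampled) latent
        return self.decode(z, sample, mask)  # positional call, as reported in the issue

    def forward_fixed(self, sample: Any, mask: Optional[Any] = None, generator: Optional[Any] = None) -> dict:
        z = "latent"
        # Keyword arguments keep each value bound to the intended parameter.
        return self.decode(z, generator=generator, image=sample, mask=mask)


vae = ToyAsymmetricVAE()
buggy_out = vae.forward_buggy("img", "msk", "gen")
fixed_out = vae.forward_fixed("img", "msk", "gen")
print(buggy_out)  # {'z': 'latent', 'generator': 'img', 'image': 'msk', 'mask': None}
print(fixed_out)  # {'z': 'latent', 'generator': 'gen', 'image': 'img', 'mask': 'msk'}
```

The positional form in the proposed diff is equally correct; the keyword form only guards against a future reordering of decode()'s parameters.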

Reproduction

from diffusers import AsymmetricAutoencoderKL

...

Logs

No response

System Info

Environment-independent

Who can help?

@sayakpaul @DN6 @yiyixuxu @cross-attention

Metadata


Assignees

No one assigned

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
