Closed
Description
Currently, the following test is failing:
```bash
RUN_SLOW=1 pytest tests/lora/test_lora_layers_peft.py::LoraSDXLIntegrationTests::test_sdxl_1_0_fuse_unfuse_all
```

Error:
```
self = <test_lora_layers_peft.LoraSDXLIntegrationTests testMethod=test_sdxl_1_0_fuse_unfuse_all>

    def test_sdxl_1_0_fuse_unfuse_all(self):
        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
        text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
        text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
        unet_sd = copy.deepcopy(pipe.unet.state_dict())

        pipe.load_lora_weights(
            "davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors", torch_dtype=torch.float16
        )

        fused_te_state_dict = pipe.text_encoder.state_dict()
        fused_te_2_state_dict = pipe.text_encoder_2.state_dict()
        unet_state_dict = pipe.unet.state_dict()

        for key, value in text_encoder_1_sd.items():
>           self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
E           KeyError: 'text_model.encoder.layers.0.self_attn.k_proj.weight'
```

When I investigated `fused_te_state_dict` with the following:
```python
targeted_layers = list(filter(lambda x: "text_model.encoder.layers.0.self_attn" in x, fused_te_state_dict.keys()))
print(targeted_layers)
```

I got:
```
['text_model.encoder.layers.0.self_attn.k_proj.base_layer.weight',
 'text_model.encoder.layers.0.self_attn.k_proj.base_layer.bias',
 'text_model.encoder.layers.0.self_attn.k_proj.lora_A.default_0.weight',
 'text_model.encoder.layers.0.self_attn.k_proj.lora_B.default_0.weight',
 'text_model.encoder.layers.0.self_attn.v_proj.base_layer.weight',
 'text_model.encoder.layers.0.self_attn.v_proj.base_layer.bias',
 'text_model.encoder.layers.0.self_attn.v_proj.lora_A.default_0.weight',
 'text_model.encoder.layers.0.self_attn.v_proj.lora_B.default_0.weight',
 'text_model.encoder.layers.0.self_attn.q_proj.base_layer.weight',
 'text_model.encoder.layers.0.self_attn.q_proj.base_layer.bias',
 'text_model.encoder.layers.0.self_attn.q_proj.lora_A.default_0.weight',
 'text_model.encoder.layers.0.self_attn.q_proj.lora_B.default_0.weight',
 'text_model.encoder.layers.0.self_attn.out_proj.base_layer.weight',
 'text_model.encoder.layers.0.self_attn.out_proj.base_layer.bias',
 'text_model.encoder.layers.0.self_attn.out_proj.lora_A.default_0.weight',
 'text_model.encoder.layers.0.self_attn.out_proj.lora_B.default_0.weight']
```

So the error makes sense to me: after `load_lora_weights`, each targeted projection is wrapped in a LoRA layer, so the original key `...k_proj.weight` now appears as `...k_proj.base_layer.weight`, next to the new `lora_A`/`lora_B` adapter weights. A lookup by the original key name therefore raises `KeyError`. @BenjaminBossan @younesbelkada could you take a look here?
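The key listing above suggests the renaming is mechanical: a `base_layer` segment is inserted into the path of each wrapped module, and extra `lora_A`/`lora_B` entries are added. As a minimal sketch (assuming only those two patterns occur, as in the listing), the hypothetical helpers `strip_peft_key` and `normalize_state_dict_keys` below map the wrapped keys back to the original names so the two state dicts can be compared key by key. They are not part of diffusers or PEFT, and whether the test should instead be changed (for example, to compare only after fusing/unloading) is a question for the maintainers.

```python
# Hypothetical helpers, NOT part of diffusers or PEFT.
# Assumption: the only changes to the state dict keys are an inserted
# "base_layer" path segment and added "lora_A"/"lora_B" adapter entries,
# matching the keys printed in this issue.
ADAPTER_SEGMENTS = {"lora_A", "lora_B"}


def strip_peft_key(key):
    """Map a PEFT-wrapped key back to the original module's key.

    Returns None for adapter-only entries (lora_A / lora_B weights),
    which have no counterpart in the original state dict.
    """
    parts = key.split(".")
    if any(part in ADAPTER_SEGMENTS for part in parts):
        return None
    # Drop the wrapper segment: "...k_proj.base_layer.weight" -> "...k_proj.weight"
    return ".".join(part for part in parts if part != "base_layer")


def normalize_state_dict_keys(state_dict):
    """Return a new dict keyed by original names, skipping adapter weights."""
    normalized = {}
    for key, value in state_dict.items():
        original_key = strip_peft_key(key)
        if original_key is not None:
            normalized[original_key] = value
    return normalized


if __name__ == "__main__":
    # String placeholders stand in for tensors in this illustration.
    wrapped = {
        "text_model.encoder.layers.0.self_attn.k_proj.base_layer.weight": "w",
        "text_model.encoder.layers.0.self_attn.k_proj.base_layer.bias": "b",
        "text_model.encoder.layers.0.self_attn.k_proj.lora_A.default_0.weight": "a",
        "text_model.encoder.layers.0.self_attn.k_proj.lora_B.default_0.weight": "z",
    }
    print(normalize_state_dict_keys(wrapped))
```

In the test, the loop could then look up `normalize_state_dict_keys(pipe.text_encoder.state_dict())[key]` instead of the raw state dict. This only reconciles the key names; it does not say anything about whether the base weights are expected to match at that point in the test.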