Skip to content

加入循环一致性损失的带条件的ddpm模型

nebulaHZH/cycle-ddpm

Repository files navigation

Cycle-DDPM

添加了循环一致性损失得DDPM模型

项目结构

- train_ddpm.py 训练该模型,基础配置文件
- cycle_ddpm.py 加入了cycle_loss的DDPM模型
- diffusion_unet.py 用于DDPM的Unet模型,采用的Hugging face的专用于扩散模型的Unet
- plot.py 用于画图
- scehedule.py 设置Diffusion model 基本的相关函数
- load.py 用于加载数据集,数据集的格式为
    - t1:
        - 0001.png
        - 0002.png
        - ......
    - t2:
        - 0001.png
        - 0002.png
        - ......
- test.py 用于测试模型的转换和生成效果

使用方法

  1. 配置数据集
    • 数据集的格式为png
  2. 安装依赖
    pip install -r requirements.txt
  3. 训练模型
    • 运行train_ddpm.py
  4. 测试模型
    • 运行test.py

num_train_timesteps=500, num_inference_timesteps=20, # 增加推理步数提高生成质量 beta_start=0.0001, # 降低beta起始值 beta_end=0.02, # 调整beta结束值

About

加入循环一致性损失的带条件的ddpm模型

Resources

Stars

Watchers

Forks

Packages

No packages published

Languages