添加了循环一致性损失得DDPM模型
- train_ddpm.py 训练该模型,基础配置文件
- cycle_ddpm.py 加入了cycle_loss的DDPM模型
- diffusion_unet.py 用于DDPM的Unet模型,采用的Hugging face的专用于扩散模型的Unet
- plot.py 用于画图
- scehedule.py 设置Diffusion model 基本的相关函数
- load.py 用于加载数据集,数据集的格式为
- t1:
- 0001.png
- 0002.png
- ......
- t2:
- 0001.png
- 0002.png
- ......
- test.py 用于测试模型的转换和生成效果
- 配置数据集
- 数据集的格式为png
- 安装依赖
pip install -r requirements.txt
- 训练模型
- 运行train_ddpm.py
- 测试模型
- 运行test.py
num_train_timesteps=500, num_inference_timesteps=20, # 增加推理步数提高生成质量 beta_start=0.0001, # 降低beta起始值 beta_end=0.02, # 调整beta结束值