Shift-Net: 深度特征重排的图像修复Image Inpainting via Deep Feature Rearrangement

该代码是基于torch7,但是如果你喜欢pytorch,请点击这里Shift-Net_pytorch。为使用pytorch的实现。

                   

入门

我们希望你有一个nvidia GPU并安装了CUDA。该代码现在不支持在CPU上运行。

安装

luarocks install nngraph
luarocks install cudnn
luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec
  • 克隆这份代码:
git clone https://github.com/Zhaoyi-Yan/Shift-Net
 cd Shift-Net

下载预先训练的模型

bash scripts/download_models.sh

该模型将被下载并解压缩。

训练

  • 下载你自己的数据集。
  • train.lua根据数据集的路径更改选项。通常,你至少应该指定三个选项。他们是DATA_ROOTphase name

例如:

DATA_ROOT: ./datasets/Paris_StreetView_Dataset/

phase: paris_train

name: paris_train_shiftNet

这意味着训练图像在文件夹下./datasets/Paris_StreetView_Dataset/paris_train/。至于name,它给你的实验一个名字,例如paris_train_shiftNet。训练时,checkpoint存储在文件夹下 ./checkpoints/paris_train_shiftNet/

  • 训练模型:
th train.lua
  • 在浏览器上显示临时结果。设置display = 1,然后打开另一个控制台,
th -ldisplay.start
  • 在浏览器中打开此URL:http://localhost:8000
  • 如果你想训练一个可以处理随机蒙版的模型,然后设置

mask_type: 'random'

fixed_mask: false

您可以设置一个浮点数res,例如,local res = 0.06它越低,输出就越连续。在测试模型时,这两个选项应该与您模型训练的选项保持一致。

测试

测试之前,您应该改变DATA_ROOTphasenamecheckpoint_dirwhich_epoch。例如,如果你想测试你的训练模型的第30个时代,那么

DATA_ROOT: ./datasets/Paris_StreetView_Dataset/

phase: paris_train

name: paris_train_shiftNet

checkpoint_dir./checkpoints/

which_epoch: '30'

前两个选项确定数据集的位置,其余的定义存储模型的文件夹。

  • 最后,测试模型:
th test.lua

致谢

我们从pix2pixDCGAN中受益匪浅。数据加载器是从pix2pix进行修改的,并且实例标准化的实现借鉴了Instance Normalization。shift操作受到style-swap启发。

https://github.com/Zhaoyi-Yan/Shift-Net

转载请注明:《Shift-Net: 深度特征重排的图像修复Image Inpainting via Deep Feature Rearrangement

发表评论