PredRNN++
视频预测问题是在给定一定数量的视频帧前提下,算法自动预测生成后面的视频帧像素,比如给出舞蹈视频前半段,算法生成后半段,有趣且有广泛的应用前景。
这是 PredRNN++ 算法的TensoFlow实现, 一种用于视频预测的循环神经网络,算法来自于下列论文:
PredRNN++: Towards A Resolution of the Deep-in-Time Dilemma in Spatiotemporal Predictive Learning, by Yunbo Wang, Zhifeng Gao, Mingsheng Long, Jianmin Wang and Philip S. Yu.
文章摘要:PredRNN ++,一个改进的视频预测学习循环网络。为了追求更高的时空建模能力,我们的方法通过利用一种新颖的循环单元来增加相邻状态之间的转换深度,该单元被命名为因果LSTM,用于在级联机制中重新组织空间和时间记忆。然而,视频预测学习仍然存在一个困境:越来越多的时间深度模型被设计用于捕捉复杂的变化,同时在梯度后向传播中引入更多困难。为了缓解这种不良影响,我们提出了一种梯度公路结构,该梯度公路结构为从输出到长程输入的梯度流提供了备选的较短路线。这种架构与因果LSTM无缝协作,使PredRNN ++能够自适应捕获短期和长期相关性。我们在合成和真实视频数据集上评估了我们的模型,显示了它能够缓解消失梯度问题,并即使在困难的物体遮挡情况下也能产生最先进的预测结果。
安装
所需的Python库: tensorflow (>=1.0) + opencv + numpy.
验证环境: ubuntu/centOS + nvidia titan X (Pascal) with cuda (>=8.0) and cudnn (>=5.0).
数据集
在三个视频数据库上做了实验: Moving Mnist, Human3.6M, KTH Actions.
对于视频格式的数据集,需要把原始视频中每一帧提取出来,然后移动到 data/
文件夹.
训练
Use the train.py script to train the model. 使用train.py脚本来训练模型,如果你要训练在Moving MNIST 数据集上缺省的模型,只需要简单运行下列代码:
python train.py
你需要改变--train_data_paths
,--valid_data_paths
并且--save_dir
指向的路径,他们分别代表训练数据路径,验证数据路径和保存checkpoints的路径。
要训练您自己的数据集,请查看文件夹中的InputHandle
类data_provider/
。你必须为你自己的数据集编写一个类似的迭代器对象。
推断时,算法生成的未来视频帧将保存在--results
文件夹中。
预测样例
下面的动画展示了PredRNN++算法的视频预测结果,画面中有三列,分别为:
第一列是 ground truth | 第二列是PredRNN++预测结果 | 第三列是baseline model.
共有20帧,对于第二列和第三列,前10帧是给定的,后10帧是预测出来的。
Citation
Please cite the following paper if you find this repository useful.
@inproceedings{wang2018predrnn,
title={PredRNN++: Towards A Resolution of the Deep-in-Time Dilemma in Spatiotemporal Predictive Learning},
author={Wang, Yunbo and Gao, zhifeng and Long, Mingsheng and Wang, Jianmin and Yu, Philip S.},
journal={ICML},
year={2018}
}
转载请注明:《PredRNN++:视频预测的循环神经网络》