对抗学习新进展:MIT和微软联合出品“元对抗扰动”

作者 | 孙裕道
编辑 | CV君
报道 | 我爱计算机视觉(微信id:aicvml)

引言

元学习是过去几年[公式]领域非常热门的的学习方法之一,各种研究工作都是基于元学习展开的。元学习的目标是使得分类模型能够获取一种学会学习调参的能力,使得模型可以在获取已有知识的基础上能够自适应快速学习新的任务。该论文是将元学习应用到对抗扰动中,作者提出一种元对抗扰动生成方法,该方法是一种更好的对抗扰动初始化的方法。

元对抗扰动可以使得干净样本图像在仅通过一步梯度上升更新后以高概率使得模型错误分类。实验结果表明,各种先进的神经网络模型容易受到元对抗扰动的攻击,而且元对抗扰动具有很好的可迁移性,并能很好地推广到不可见的数据样本点和不同的神经网络模型中。

详细信息如下:

论文链接:arxiv.org/abs/2111.1029

论文方法

受元学习方法[公式]的启发,该论文的主要目的是训练一种元对抗扰动,该扰动可以在一步或几步更新内对新数据点进行更有效的攻击。

给定随机初始化的元对抗扰动[公式],参数为[公式]的目标分类器[公式],交叉熵损失函数[公式],数据集[公式]。在数据集[公式]中采样出一批数据[公式],通过梯度上升法生成新的对抗扰动[公式]。作者的目标是找到一个单一的元对抗扰动[公式],以便新的样本数据在几次迭代后能够以高概率使得模型分类器分类出错,具体的优化形式如下所示:

[公式]

该论文的具体步骤如下,首先使用单步或多步梯度上升法去计算新数据点的对抗扰动[公式],具体的计算公式如下所示:

[公式]

其中步长[公式]是一个超参数。更精确的数学表述如下所示:

[公式]

元对抗扰动非常重视自适应性,接下来利用以上生成的扰动[公式]在新的样本数据[公式]上去更新扰动[公式],具体的公式如下所示:

[公式]

其中[公式]是元步长。最后将得到的元对抗扰动投影到可行空间中。为了能够更直观的理解论文算法的细节,做了如下[公式]算法示意图:

这里需要注意的是,在该论文中求解梯度[公式]时,遇到了元学习学习初始化参数方法[公式]一样的问题,即求解梯度的过程中会涉及到[公式]矩阵的计算。

梯度补充计算过程如下所示:[公式]

如上公式所示,[公式]是一个[公式]矩阵。要知道[公式]矩阵的计算会往往会涉及到大量的计算成本,作者在本文中没有明确给出这个问题应对方法。在实际编程的时候,我采用的是[公式]的简化计算量的方法,在下面的程序代码里会有所体现。

整个元对抗扰动[公式]算法流程图如下所示:

实验结果

下表为元对抗扰动[公式][公式][公式]三种方法的对比结果。可以发现[公式]要显著优于其它两种方法。 对于所有网络,[公式]的改善率约为10-20%。

下表总结了[公式]在七个模型之间的可迁移性。对于每个分类模型,计算一个元对抗扰动,并显示该扰动迁移到所有其他模型的精度,同时对目标模型进行一步更新,其中在最下面一行显示了在初始化时不使用[公式]的精度。实验结果可以发现,[公式]生成的对抗扰动在模型之间具有很好的迁移性。

如下图所示,数据集的规模越大,性能越好。另外,即使只使用10幅图像计算元对抗扰动,这种扰动仍然会导致模型的分类精确度比原始[公式]下降15%左右。这说明了元对抗扰动在看不见的数据点上具有惊人的泛化能力,可以在非常小的训练数据集上进行计算。

程序代码

论文中没有给出元对抗扰动[公式]的具体代码,以下是自己根据[公式]算法流程图以及对[公式]矩阵如何计算的理解编写的简化的完整程序。

import torch
import torch.nn as nn
import torch.utils.data as Data
import numpy as np
import os
import torch.nn.functional as F
import random

def generate_dataset(sample_num, class_num, X_shape):
    Label_list = []
    Sample_list = []
    for i in range(sample_num):
        y = np.random.randint(0, class_num)
        Label_list.append(y)
        Sample_list.append(np.random.normal(y, 0.2, X_shape))
    return Sample_list, Label_list

def Sample_dataset(numpy_dataset, batch_size):
    index_list = random.sample(range(0, len(numpy_dataset[0])), batch_size)
    data_list = []
    label_list = []
    for index in index_list:
        data_list.append(numpy_dataset[0][index])
        label_list.append(numpy_dataset[1][index])
    return torch.tensor(data_list).to(torch.float32), torch.tensor(label_list).to(torch.int64)


class Normal_Dataset(Data.Dataset):
    def __init__(self, Numpy_Dataset):
        super(Normal_Dataset, self).__init__()
        self.data_tensor = torch.tensor(Numpy_Dataset[0]).to(torch.float32)
        self.target_tensor = torch.tensor(Numpy_Dataset[1]).to(torch.int64)

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)

class Classifer(nn.Module):
    def __init__(self):
        super(Classifer, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 10, kernel_size = 9)  # 10, 36x36
        self.conv2 = nn.Conv2d(in_channels = 10, out_channels = 20, kernel_size = 17 ) # 20, 20x20
        self.fc1 = nn.Linear(20*20*20, 512)
        self.fc2 = nn.Linear(512, 7)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)
        out = F.relu(self.conv2(out))
        out = out.view(in_size, -1)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        out = F.softmax(out, dim=1)
        return out


def MAP(DataLoader, alpha, beta, model, loss_fn, PI_epsilon,input_shape, epoches, numpy_dataset):
    v = torch.zeros(input_shape)  # [3, 44, 44]
    for epoch in range(epoches):
        for batch_x, batch_y in DataLoader:
            v.requires_grad = True
            # Evaluate nable_vL(f_theta) using B with v
            outputs = model(batch_x + v)     # [2, 3, 44, 44] + [3, 44, 44] = [2, 3, 44, 44]
            loss = loss_fn(outputs, batch_y)
            loss.backward()
            # Compute v_prime
            v_prime = (v + alpha * v.grad.data).detach_()
            v_prime.requires_grad = True
            # Sample B_prime dataset
            batch_x_prime, batch_y_prime = Sample_dataset(numpy_dataset, 2)
            # Evaluate nable_vL(f_theta) using B_prime with v_prime
            outputs_prime = model(batch_x_prime + v_prime)
            loss_prime = loss_fn(outputs_prime, batch_y_prime)
            loss_prime.backward()
            # Update v
            v = (v + beta * v_prime.grad).detach()
    return v

if __name__ == '__main__':
    input_shape = (3,44,44)
    numpy_dataset = generate_dataset(100, 7, input_shape)
    Dataset = Normal_Dataset(numpy_dataset)
    DataLoader = Data.DataLoader(
                        dataset = Dataset,
                        batch_size = 2,
                        shuffle = True,
                        num_workers = 0,
                        )
    model = Classifer()
    loss_fn = nn.CrossEntropyLoss()
    alpha = 0.01
    belta = 0.01
    epoches = 1
    MAP(DataLoader, alpha, belta, model, loss_fn, 'PI',
        input_shape, epoches, numpy_dataset)

转载请注明:《对抗学习新进展:MIT和微软联合出品“元对抗扰动”