    忽逢桃林: Reproducing a Self-Supervised Vision Paper | BYOL (PyTorch) | 2020

    Author: 忽逢桃林    Date: 2021-01-30 19:23

    Continuing from the previous post, which covered the paper and architecture of the self-supervised model Bootstrap Your Own Latent (BYOL):
    https://juejin.cn/post/6922347006144970760

    Now let's see how to implement this architecture in PyTorch, deepening our understanding of the paper along the way.
    GitHub: https://github.com/lucidrains/byol-pytorch

    [Preface]: I haven't actually run this code myself; after all, I'm just a poor soul without a GPU.

    Main model code

    import copy
    import torch
    from torch import nn
    import torch.nn.functional as F
    from torchvision import transforms as T

    # RandomApply, default, EMA, MLP, NetWrapper, singleton, get_module_device,
    # set_requires_grad, update_moving_average and loss_fn are helpers defined in
    # the same byol_pytorch.py file; several of them are shown later in this post

    class BYOL(nn.Module):
        def __init__(
            self,
            net,
            image_size,
            hidden_layer = -2,
            projection_size = 256,
            projection_hidden_size = 4096,
            augment_fn = None,
            augment_fn2 = None,
            moving_average_decay = 0.99,
            use_momentum = True
        ):
            super().__init__()
            self.net = net
    
            # default SimCLR augmentation
    
            DEFAULT_AUG = torch.nn.Sequential(
                RandomApply(
                    T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                    p = 0.3
                ),
                T.RandomGrayscale(p=0.2),
                T.RandomHorizontalFlip(),
                RandomApply(
                    T.GaussianBlur((3, 3), (1.0, 2.0)),
                    p = 0.2
                ),
                T.RandomResizedCrop((image_size, image_size)),
                T.Normalize(
                    mean=torch.tensor([0.485, 0.456, 0.406]),
                    std=torch.tensor([0.229, 0.224, 0.225])),
            )
    
            self.augment1 = default(augment_fn, DEFAULT_AUG)
            self.augment2 = default(augment_fn2, self.augment1)
    
            self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
    
            self.use_momentum = use_momentum
            self.target_encoder = None
            self.target_ema_updater = EMA(moving_average_decay)
    
            self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
    
            # get device of network and make wrapper same device
            device = get_module_device(net)
            self.to(device)
    
            # send a mock image tensor to instantiate singleton parameters
            self.forward(torch.randn(2, 3, image_size, image_size, device=device))
    
        @singleton('target_encoder')
        def _get_target_encoder(self):
            target_encoder = copy.deepcopy(self.online_encoder)
            set_requires_grad(target_encoder, False)
            return target_encoder
    
        def reset_moving_average(self):
            del self.target_encoder
            self.target_encoder = None
    
        def update_moving_average(self):
            assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
            assert self.target_encoder is not None, 'target encoder has not been created yet'
            update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
    
        def forward(self, x, return_embedding = False):
            # inference path: return the online encoder's output directly
            if return_embedding:
                return self.online_encoder(x)

            # two randomly augmented views of the same batch
            image_one, image_two = self.augment1(x), self.augment2(x)

            # both views pass through the online encoder (encoder + projector) ...
            online_proj_one, _ = self.online_encoder(image_one)
            online_proj_two, _ = self.online_encoder(image_two)

            # ... and then through the predictor
            online_pred_one = self.online_predictor(online_proj_one)
            online_pred_two = self.online_predictor(online_proj_two)

            # the target branch never receives gradients
            with torch.no_grad():
                target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
                target_proj_one, _ = target_encoder(image_one)
                target_proj_two, _ = target_encoder(image_two)
                target_proj_one.detach_()
                target_proj_two.detach_()

            # symmetric loss: each online prediction regresses the other view's target projection
            loss_one = loss_fn(online_pred_one, target_proj_two.detach())
            loss_two = loss_fn(online_pred_two, target_proj_one.detach())

            loss = loss_one + loss_two
            return loss.mean()
    
    • Start with forward(): you feed a batch of images to the model, and it returns the loss computed on that batch.
    • For inference, pass return_embedding=True; the model then returns the output of the online network's encoder, with no predictor involved. Note that the "encoder" in the code is actually the paper's encoder + projector combined.
    • self.augment1 and self.augment2 turn the input into two different images, which the previous post called views.
    • Both views go through the online encoder. You might wonder: shouldn't one view go through the online network and the other through the target network? True, but to compute the symmetric loss, each view has to pass through both the online and the target network.
    • Nothing computed in the target network needs gradients, since the target network is updated from the online network's parameters rather than by backpropagation.
    • If self.use_momentum=False, the paper's momentum update is skipped entirely and the online encoder itself stands in as the target. As for the momentum path: at first glance it looks like _get_target_encoder just copies the online network into the target network (and this repo has 600+ stars!), but the @singleton('target_encoder') decorator means that deepcopy runs only once; afterwards the training loop is expected to call update_moving_average(), which applies the EMA update (see the sketch right after this list).
    • Finally, the loss is computed with loss_fn and the function returns loss.mean().
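
    As promised, here is roughly what that momentum machinery looks like in the repo's byol_pytorch.py (lightly abridged). Note that forward() never touches it; the training loop is responsible for calling learner.update_moving_average() after each optimizer step:

    class EMA():
        def __init__(self, beta):
            super().__init__()
            self.beta = beta  # moving_average_decay, e.g. 0.99

        def update_average(self, old, new):
            if old is None:
                return new
            # exponential moving average: mostly keep the old target weights
            return old * self.beta + (1 - self.beta) * new

    def update_moving_average(ema_updater, ma_model, current_model):
        # blend the online model's weights into the target model, parameter by parameter
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = ema_updater.update_average(old_weight, up_weight)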

    So far, then, the BYOL structure is actually quite simple. Four open questions remain:

    • How is the online_encoder defined?
    • How is the predictor defined?
    • How are the image augmentations defined?
    • How is the loss function loss_fn defined?

    augment

    From the code above, note this section:

    # default SimCLR augmentation
    DEFAULT_AUG = torch.nn.Sequential(
        RandomApply(
            T.ColorJitter(0.8, 0.8, 0.8, 0.2),
            p = 0.3
        ),
        T.RandomGrayscale(p=0.2),
        T.RandomHorizontalFlip(),
        RandomApply(
            T.GaussianBlur((3, 3), (1.0, 2.0)),
            p = 0.2
        ),
        T.RandomResizedCrop((image_size, image_size)),
        T.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225])),
    )

    self.augment1 = default(augment_fn, DEFAULT_AUG)
    self.augment2 = default(augment_fn2, self.augment1)
    

    A few observations:

    • This is the image-augmentation pipeline. augment1 and augment2 can be supplied by the user; by default, both fall back to DEFAULT_AUG above. Neither default nor RandomApply comes from torchvision; they are small repo helpers, sketched right after this list;
    • T is `from torchvision import transforms as T`
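
    A minimal sketch of those two helpers, consistent with how they are used here (the real definitions live in the repo's byol_pytorch.py):

    import random
    from torch import nn

    def default(val, def_val):
        # fall back to def_val when the caller passed None
        return def_val if val is None else val

    class RandomApply(nn.Module):
        # apply fn to the input with probability p, otherwise pass it through unchanged
        def __init__(self, fn, p):
            super().__init__()
            self.fn = fn
            self.p = p

        def forward(self, x):
            if random.random() > self.p:
                return x
            return self.fn(x)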

    The least familiar piece here is probably torchvision.transforms.ColorJitter().

    Per the official API docs, it randomly perturbs an image's brightness, contrast, saturation, and hue.
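
    A quick illustration (the four positional arguments are brightness, contrast, saturation, hue; brightness=0.8 samples a scaling factor uniformly from [0.2, 1.8], while hue=0.2 samples an offset from [-0.2, 0.2]):

    import torch
    from torchvision import transforms as T

    jitter = T.ColorJitter(0.8, 0.8, 0.8, 0.2)  # brightness, contrast, saturation, hue
    img = torch.rand(3, 224, 224)               # a fake RGB image with values in [0, 1]
    out = jitter(img)                           # same shape, randomly shifted colors
    print(out.shape)                            # torch.Size([3, 224, 224])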

    encoder+projector

    class NetWrapper(nn.Module):
        def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
            super().__init__()
            self.net = net
            self.layer = layer
    
            self.projector = None
            self.projection_size = projection_size
            self.projection_hidden_size = projection_hidden_size
    
            self.hidden = None
            self.hook_registered = False
    
        def _find_layer(self):
            if type(self.layer) == str:
                modules = dict([*self.net.named_modules()])
                return modules.get(self.layer, None)
            elif type(self.layer) == int:
                children = [*self.net.children()]
                return children[self.layer]
            return None
    
        def _hook(self, _, __, output):
            self.hidden = flatten(output)
    
        def _register_hook(self):
            layer = self._find_layer()
            assert layer is not None, f'hidden layer ({self.layer}) not found'
            handle = layer.register_forward_hook(self._hook)
            self.hook_registered = True
    
        @singleton('projector')
        def _get_projector(self, hidden):
            _, dim = hidden.shape
            projector = MLP(dim, self.projection_size, self.projection_hidden_size)
            return projector.to(hidden)
    
        def get_representation(self, x):
            if self.layer == -1:
                return self.net(x)
    
            if not self.hook_registered:
                self._register_hook()
    
            _ = self.net(x)
            hidden = self.hidden
            self.hidden = None
            assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
            return hidden
    
        def forward(self, x, return_embedding = False):
            representation = self.get_representation(x)
    
            if return_embedding:
                return representation
    
            projector = self._get_projector(representation)
            projection = projector(representation)
            return projection, representation
    

    This NetWrapper is the paper's encoder + projector: it wraps the backbone network (the encoder) and lazily attaches an MLP projector, grabbing the backbone's hidden representation through a forward hook on the chosen layer.
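
    Two more repo helpers show up here: flatten, and the singleton decorator. singleton('projector') lazily builds the projector on the first forward pass (its input dimension is only known once a real hidden activation has been seen) and caches it on the instance; the same trick caches the target encoder in BYOL. Roughly:

    from functools import wraps

    def flatten(t):
        # (batch, ...) -> (batch, features)
        return t.reshape(t.shape[0], -1)

    def singleton(cache_key):
        # build the attribute on the first call, then always return the cached instance
        def inner_fn(fn):
            @wraps(fn)
            def wrapper(self, *args, **kwargs):
                instance = getattr(self, cache_key)
                if instance is not None:
                    return instance
                instance = fn(self, *args, **kwargs)
                setattr(self, cache_key, instance)
                return instance
            return wrapper
        return inner_fn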

    encoder

    The backbone has to be passed in when NetWrapper is constructed. Looking at the training script, the model is:

    from torchvision import models, transforms
    resnet = models.resnet50(pretrained=True)
    

    So the encoder is a ResNet-50, exactly as in the paper. The full ResNet does output a (batch_size, 1000) classification tensor, but remember hidden_layer=-2: the forward hook grabs the output of the second-to-last child module (the global average pool), so the representation actually handed to the projector is (batch_size, 2048) after flattening.
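
    A quick check of what layer=-2 actually hooks into:

    import torch
    from torchvision import models

    resnet = models.resnet50(pretrained=True)
    children = list(resnet.children())
    print(children[-1])  # Linear(in_features=2048, out_features=1000) -- the classifier head
    print(children[-2])  # AdaptiveAvgPool2d(output_size=(1, 1))       -- what layer=-2 hooks

    x = torch.randn(2, 3, 256, 256)
    feats = torch.nn.Sequential(*children[:-1])(x)  # everything up to and including avgpool
    print(feats.flatten(1).shape)                   # torch.Size([2, 2048]) -- the representation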

    projector

    It uses the MLP class:

    class MLP(nn.Module):
        def __init__(self, dim, projection_size, hidden_size = 4096):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_size, projection_size)
            )
    
        def forward(self, x):
            return self.net(x)
    

    It's a Linear + BatchNorm + ReLU stack, much as the paper describes, and there is no BN + ReLU after the final linear layer. Given a (batch_size, dim) input, the MLP returns a tensor of shape (batch_size, projection_size).
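
    For example, with the 2048-dimensional ResNet-50 representation from above and the default sizes:

    import torch

    mlp = MLP(dim=2048, projection_size=256, hidden_size=4096)
    h = torch.randn(8, 2048)   # a batch of flattened representations
    z = mlp(h)                 # note: BatchNorm1d requires batch size > 1 in training mode
    print(z.shape)             # torch.Size([8, 256])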

    predictor

    self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
    

    The predictor is exactly the same MLP structure as the projector; note that both its input and output feature dimensions are projection_size.

    Since I haven't read the self-supervised literature systematically and BYOL is the first such paper I've studied, I can't fully explain why the predictor exists. Judging from its effect, it breaks the symmetry between the online and target networks: if the two branches were structurally identical, training might collapse both into producing exactly the same outputs, i.e. a trivial loss = 0 solution.

    loss_fn

    def loss_fn(x, y):
        # L2-normalize both, then compute 2 - 2 * cos_sim(x, y)
        x = F.normalize(x, dim=-1, p=2)
        y = F.normalize(y, dim=-1, p=2)
        return 2 - 2 * (x * y).sum(dim=-1)
    

    This matches the paper: for unit vectors, 2 − 2⟨x, y⟩ is exactly the squared Euclidean distance between the normalized prediction and target.
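
    A quick sanity check of that identity, using the loss_fn defined above:

    import torch
    import torch.nn.functional as F

    x, y = torch.randn(4, 256), torch.randn(4, 256)
    lhs = loss_fn(x, y)
    rhs = (F.normalize(x, dim=-1) - F.normalize(y, dim=-1)).pow(2).sum(dim=-1)
    print(torch.allclose(lhs, rhs, atol=1e-6))  # True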

    In summary, BYOL is a simple yet interesting self-supervised framework.
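
    To tie everything together, here is a minimal training sketch modeled on the repo's README (the optimizer, learning rate, and the fake data loader are illustrative placeholders, not the paper's settings):

    import torch
    from torchvision import models
    from byol_pytorch import BYOL

    resnet = models.resnet50(pretrained=True)
    learner = BYOL(resnet, image_size=256, hidden_layer='avgpool')
    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

    def sample_unlabelled_images():
        # stand-in for a real unlabelled data loader
        return torch.randn(20, 3, 256, 256)

    for _ in range(100):
        images = sample_unlabelled_images()
        loss = learner(images)
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average()  # EMA update of the target encoder

    # after training, the backbone alone is the artifact you keep
    torch.save(resnet.state_dict(), './improved-net.pt')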
