Info

Title: Test-Time Training for Out-of-Distribution Generalization

Author: Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei A. Efros, Moritz Hardt

Date: Sep. 2019

arXiv: 1909.13231
Highlights
 Theoretically proves the benefit of test-time self-supervised training for convex loss functions, and demonstrates it empirically for non-convex ones.
Abstract
We introduce a general approach, called test-time training, for improving the performance of predictive models when test and training data come from different distributions. Test-time training turns a single unlabeled test instance into a self-supervised learning problem, on which we update the model parameters before making a prediction on the test sample. We show that this simple idea leads to surprising improvements on diverse image classification benchmarks aimed at evaluating robustness to distribution shifts. Theoretical investigations on a convex model reveal helpful intuitions for when we can expect our approach to help.
Motivation & Design
Method Overview
The route we take is to create a self-supervised learning problem based only on the single test instance x, which we use to update the model parameters θ at test time before making a prediction on x. Self-supervised learning uses an auxiliary task that automatically creates labels from unlabeled data. For the visual data we work with, the auxiliary task rotates an image x by a multiple of 90 degrees and assigns the angle as the label.
Theorem (for convex functions)
Empirical Results (for non-convex functions)
The Corruption Benchmark
Hendrycks & Dietterich (2019) propose to benchmark the robustness of neural networks on 15 types of corruption from four broad categories: noise, blur, weather, and digital. Each corruption type comes in five levels of severity, with level 5 the most severe. The corruptions are algorithmically simulated to mimic real-world corruptions as closely as possible, applied to copies of the test sets of both CIFAR-10 and ImageNet.
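As a concrete illustration, here is a minimal sketch of one corruption type, Gaussian noise; the severity constants are illustrative, not the benchmark's exact values:

```python
import numpy as np

def gaussian_noise(x, severity=5):
    # x: float image array with values in [0, 1]; higher severity -> more noise
    std = [0.04, 0.06, 0.08, 0.09, 0.10][severity - 1]  # illustrative levels
    noisy = x + np.random.normal(scale=std, size=x.shape)
    return np.clip(noisy, 0.0, 1.0)
```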
Figure: Scatter plot of the inner product between the gradients (on the shared feature extractor θ_e) of the main task $l_m$ and the self-supervised task $l_s$, against the improvement in test error (%) from test-time training, for the standard (left) and online (right) versions. Each point is the average over a test set; each scatter plot has 75 test sets, from all 15 corruption types over five levels. The blue lines and bands are the best linear fits and the 99% confidence intervals. The linear correlation coefficients are 0.93 and 0.89 respectively, indicating the strong positive correlation between the two quantities suggested by Theorem 1.
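The gradient inner product plotted in the figure can be computed as below. This is a toy sketch with stand-in squared-output losses and a linear extractor, not the paper's code:

```python
import torch
import torch.nn as nn

# Toy shared extractor theta_e with a main head and a self-supervised head
ext = nn.Linear(4, 3)
main_head, ssh_head = nn.Linear(3, 2), nn.Linear(3, 4)
x = torch.randn(5, 4)

# Gradient of each (stand-in) loss with respect to the shared parameters
loss_m = main_head(ext(x)).pow(2).mean()
g_m = torch.autograd.grad(loss_m, list(ext.parameters()))
loss_s = ssh_head(ext(x)).pow(2).mean()
g_s = torch.autograd.grad(loss_s, list(ext.parameters()))

# Inner product <grad l_m, grad l_s> on theta_e
inner = sum((a * b).sum() for a, b in zip(g_m, g_s))
```

When this inner product is positive, a gradient step on the self-supervised loss also decreases the main-task loss, which is the intuition behind Theorem 1.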
Experiments & Ablation Study
Test Error on CIFAR-10-C, level 5
Test Accuracy on ImageNet-C, level 5
The lower panel shows the accuracy of the online version, averaged over a sliding window of 100 samples; online test-time training generalizes better as more samples are tested, without hurting performance on the original distribution. We report accuracy instead of error here because the baseline performance is very poor under most corruptions.
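The sliding-window average used for these curves can be computed as below (a simple sketch; `correct` is assumed to be a list of 0/1 prediction outcomes in test order):

```python
def sliding_accuracy(correct, window=100):
    # Average accuracy over the last `window` samples, at each test step
    out = []
    for i in range(len(correct)):
        chunk = correct[max(0, i - window + 1):i + 1]
        out.append(sum(chunk) / len(chunk))
    return out
```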
Code
Test-Time Training with Rotation as the Auxiliary Task
During joint training, an additional self-supervised step is added after the main-task prediction:
for batch_idx, (inputs, labels) in enumerate(trloader):
    optimizer.zero_grad()
    # Main classification task
    inputs_cls, labels_cls = inputs.cuda(), labels.cuda()
    outputs_cls = net(inputs_cls)
    loss = criterion(outputs_cls, labels_cls)
    # Self-supervised rotation task on the shared feature extractor
    inputs_ssh, labels_ssh = rotate_batch(inputs, args.rotation_type)
    outputs_ssh = ssh(inputs_ssh)
    loss_ssh = criterion(outputs_ssh, labels_ssh)
    loss += loss_ssh
    # One joint gradient step on the sum of the two losses
    loss.backward()
    optimizer.step()
Rotation Label Generation
def rotate_batch(batch, label):
    # 'rand': each image gets a random rotation label in {0, 1, 2, 3}
    if label == 'rand':
        labels = torch.randint(4, (len(batch),), dtype=torch.long)
    # 'expand': replicate the batch four times, once per rotation angle
    elif label == 'expand':
        labels = torch.cat([torch.zeros(len(batch), dtype=torch.long),
                            torch.zeros(len(batch), dtype=torch.long) + 1,
                            torch.zeros(len(batch), dtype=torch.long) + 2,
                            torch.zeros(len(batch), dtype=torch.long) + 3])
        batch = batch.repeat((4, 1, 1, 1))
    # otherwise: a fixed integer rotation label for the whole batch
    else:
        assert isinstance(label, int)
        labels = torch.zeros((len(batch),), dtype=torch.long) + label
    return rotate_batch_with_labels(batch, labels), labels
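`rotate_batch_with_labels` is called above but not shown; a minimal sketch, assuming (N, C, H, W) square images and label k meaning a rotation by k*90 degrees:

```python
import torch

def rotate_batch_with_labels(batch, labels):
    # Rotate each image by its label times 90 degrees in the spatial dims
    return torch.stack([torch.rot90(img, int(k), dims=(1, 2))
                        for img, k in zip(batch, labels)])
```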
Self-Supervised Head (SSH) Model Architecture
Extracting the head from the backbone
net = ResNet(args.depth, args.width, channels=3, classes=classes, norm_layer=norm_layer).cuda()
ext = extractor_from_layer2(net)           # shared feature extractor theta_e
head = head_on_layer2(net, args.width, 4)  # rotation head: 4 angle classes
ssh = ExtractorHead(ext, head)
class ExtractorHead(nn.Module):
    def __init__(self, ext, head):
        super(ExtractorHead, self).__init__()
        self.ext = ext
        self.head = head

    def forward(self, x):
        return self.head(self.ext(x))
def extractor_from_layer2(net):
    # Shared layers: everything up to and including layer2
    layers = [net.conv1, net.layer1, net.layer2]
    return nn.Sequential(*layers)

def head_on_layer2(net, width, classes):
    # Branch head: a copy of the remaining layers plus a new linear classifier
    head = copy.deepcopy([net.layer3, net.bn, net.relu, net.avgpool])
    head.append(ViewFlatten())
    head.append(nn.Linear(64 * width, classes))
    return nn.Sequential(*head)
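`ViewFlatten` is referenced in `head_on_layer2` but not defined above; a minimal sketch:

```python
import torch
import torch.nn as nn

class ViewFlatten(nn.Module):
    # Flatten every dimension except the batch dimension
    def forward(self, x):
        return x.view(x.size(0), -1)
```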