Focal Loss for Object Detection: Idea
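- The loss function is reshaped to down-weight easy examples and thus focus training on hard negatives. A modulating factor $(1-p_t)^\gamma$ multiplies the cross-entropy loss, giving $\mathrm{FL}(p_t) = -(1-p_t)^\gamma \log(p_t)$, where $p_t = p$ for the ground-truth class and $p_t = 1 - p$ otherwise. γ is tested in the range [0, 5] in the experiments.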
- There are two properties of the FL:
- When an example is misclassified and $p_t$ is small, the modulating factor is near 1 and the loss is unaffected. As $p_t \to 1$, the factor goes to 0 and the loss for well-classified examples is down-weighted.
- The focusing parameter γ smoothly adjusts the rate at which easy examples are down-weighted. When γ = 0, FL is equivalent to CE; as γ increases, the effect of the modulating factor increases likewise (γ = 2 works best in the experiments).
- For instance, with γ = 2, an example classified with $p_t = 0.9$ has a 100× lower loss than with CE, since $(1-0.9)^2 = 0.01$. An example with $p_t \approx 0.968$ has a 1000× lower loss. This in turn increases the relative importance of correcting misclassified examples.
- For $p_t \le 0.5$ and γ = 2, the loss is scaled down by at most 4×, since $(1-0.5)^2 = 0.25$.
Focal Loss: CUDA Kernel
// Element-wise sigmoid focal loss, forward pass.
//
// Each index i (from CUDA_1D_KERNEL_LOOP, a macro defined elsewhere that
// presumably covers [0, nthreads)) addresses one (sample, class) pair of a
// flattened buffer: n = i / num_classes, d = i % num_classes.
// Labels are 1-based: column d is the positive class when targets[n] == d + 1;
// targets[n] == 0 makes every column a negative, and targets[n] < 0 makes every
// column contribute zero loss. With x = logits[i] and p = sigmoid(x):
//
//   positive column: losses[i] = -alpha       * (1 - p)^gamma * log(p)
//   negative column: losses[i] = -(1 - alpha) * p^gamma       * log(1 - p)
//
// Both log terms use the overflow-free identities
//   log(p)     =  min(x, 0) - log(1 + exp(-|x|))
//   log(1 - p) = -max(x, 0) - log(1 + exp(-|x|))
// Using the first identity instead of logf(max(p, FLT_MIN)) avoids saturating
// the loss near -log(FLT_MIN) ~= 87.3 for very negative logits on positive
// columns, which would disagree with the gradient reported by the backward pass.
//
// NOTE(review): expf/logf/powf are single-precision, so a double scalar_t is
// still evaluated in float precision for these calls.
// `num` is unused; it is kept for interface compatibility.
template <typename scalar_t>
__global__ void SigmoidFocalLossForward(const int nthreads,
                                        const scalar_t *logits,
                                        const long *targets,
                                        const int num_classes,
                                        const float gamma, const float alpha,
                                        const int num, scalar_t *losses) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    const int n = i / num_classes;  // sample index
    const int d = i % num_classes;  // 0-based class column
    const int t = targets[n];       // 1-based target label
    // Positive: this column is the target class. Negative: the sample is not
    // ignored (t >= 0) and this column is any other class.
    const scalar_t c1 = (t == (d + 1));
    const scalar_t c2 = (t >= 0 && t != (d + 1));
    const scalar_t zn = (1.0 - alpha);
    const scalar_t zp = (alpha);

    const scalar_t x = logits[i];
    // p = sigmoid(x) = 1 / (1 + exp(-x))
    const scalar_t p = 1. / (1. + expf(-x));
    // -max(x, 0)
    const auto neg_relu = -1. * x * (x >= 0);
    // log(1 + exp(-|x|)); x - 2 * x * (x >= 0) == -|x|, so exp() never overflows.
    const auto softplus_tail = logf(1. + expf(x - 2. * x * (x >= 0)));
    // log(1 - p) = -max(x, 0) - log(1 + exp(-|x|))
    const auto log_1mp = neg_relu - softplus_tail;
    // log(p) = min(x, 0) - log(1 + exp(-|x|)), with min(x, 0) = x - max(x, 0)
    const auto log_p = x + neg_relu - softplus_tail;

    // (1 - p)^gamma * log(p)
    const scalar_t term1 = powf((1. - p), gamma) * log_p;
    // p^gamma * log(1 - p)
    const scalar_t term2 = powf(p, gamma) * log_1mp;

    // Single store instead of zero-init followed by two read-modify-writes.
    losses[i] = -c1 * term1 * zp - c2 * term2 * zn;
  }  // CUDA_1D_KERNEL_LOOP
}  // SigmoidFocalLossForward
// Gradient of the element-wise sigmoid focal loss with respect to the logits.
//
// Indexing and labels match the forward kernel: for each i from
// CUDA_1D_KERNEL_LOOP (macro defined elsewhere), sample = i / num_classes and
// the column's 1-based label is i % num_classes + 1. With p = sigmoid(logits[i]):
//
//   positive column: -alpha       * (1 - p)^gamma * (1 - p - gamma * p * log(p))
//   negative column: -(1 - alpha) * p^gamma * (gamma * (1 - p) * log(1 - p) - p)
//
// The result is multiplied element-wise by the upstream gradient d_losses[i].
// A negative label yields zero for every column (for finite logits).
// log(p) is evaluated as logf(max(p, FLT_MIN)); log(1 - p) uses the stable
// softplus form. `num` is unused.
template <typename scalar_t>
__global__ void SigmoidFocalLossBackward(
    const int nthreads, const scalar_t *logits, const long *targets,
    const scalar_t *d_losses, const int num_classes, const float gamma,
    const float alpha, const int num, scalar_t *d_logits) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    const int label = targets[i / num_classes];
    const int column_label = i % num_classes + 1;  // 1-based label of column
    const bool is_positive = (label == column_label);
    const bool is_negative = (label >= 0) && (label != column_label);

    const scalar_t pos_mask = is_positive;
    const scalar_t neg_mask = is_negative;
    const scalar_t w_pos = (alpha);
    const scalar_t w_neg = (1.0 - alpha);

    const scalar_t x = logits[i];
    const bool x_nonneg = (x >= 0);
    const scalar_t prob = 1. / (1. + expf(-x));

    // (1 - p)^gamma * (1 - p - gamma * p * log(p))
    const scalar_t pos_grad =
        powf((1. - prob), gamma) *
        (1. - prob - (prob * gamma * logf(max(prob, FLT_MIN))));

    // log(1 - p) = -max(x, 0) - log(1 + exp(-|x|)), kept in promoted precision.
    const auto log_one_minus_p =
        -1. * x * x_nonneg - logf(1. + expf(x - 2. * x * x_nonneg));
    // p^gamma * (gamma * (1 - p) * log(1 - p) - p)
    const scalar_t neg_grad =
        powf(prob, gamma) * (log_one_minus_p * (1. - prob) * gamma - prob);

    scalar_t grad = 0.0;
    grad += -pos_mask * pos_grad * w_pos;
    grad += -neg_mask * neg_grad * w_neg;
    d_logits[i] = grad * d_losses[i];
  }  // CUDA_1D_KERNEL_LOOP
}  // SigmoidFocalLossBackward
Focal Loss: PyTorch Wrapper
class SigmoidFocalLossFunction(Function):
    """Autograd function for sigmoid focal loss backed by ``sigmoid_focal_loss_cuda``.

    ``forward`` obtains the element-wise loss from the CUDA extension and then
    applies ``reduction``. ``backward`` undoes that reduction on the incoming
    gradient before calling the extension's backward, because the CUDA backward
    kernel reads one upstream gradient per loss element.
    """

    @staticmethod
    def forward(ctx, input, target, gamma=2.0, alpha=0.25, reduction='mean'):
        """Compute the (optionally reduced) sigmoid focal loss.

        Args:
            input: Logits; ``input.shape[1]`` is taken as the number of classes.
            target: Labels passed straight to the CUDA extension (the kernels
                treat them as 1-based, 0 = background, negative = ignored --
                assuming the extension wraps those kernels).
            gamma: Focusing parameter.
            alpha: Weight of the positive term; negatives get ``1 - alpha``.
            reduction: ``'none'``, ``'mean'`` or ``'sum'``.

        Returns:
            The element-wise loss for ``'none'``, otherwise a 0-dim tensor.

        An unsupported reduction raises instead of silently returning ``None``.
        """
        ctx.save_for_backward(input, target)
        num_classes = input.shape[1]
        ctx.num_classes = num_classes
        ctx.gamma = gamma
        ctx.alpha = alpha
        loss = sigmoid_focal_loss_cuda.forward(input, target, num_classes,
                                               gamma, alpha)
        # Remember the element-wise shape so backward can undo the reduction.
        ctx.loss_shape = loss.shape
        ctx.loss_numel = loss.numel()
        reduction_enum = F._Reduction.get_enum(reduction)
        ctx.reduction_enum = reduction_enum
        # none: 0, mean: 1, sum: 2
        if reduction_enum == 0:
            return loss
        if reduction_enum == 1:
            return loss.mean()
        if reduction_enum == 2:
            return loss.sum()
        raise ValueError(
            'unsupported reduction for sigmoid focal loss: {}'.format(reduction))

    @staticmethod
    @once_differentiable
    def backward(ctx, d_loss):
        """Return the gradient w.r.t. ``input``; the other inputs get ``None``.

        For ``'mean'``/``'sum'`` the incoming gradient is a scalar, but the CUDA
        backward reads ``d_losses[i]`` for every element. So it is broadcast back
        to the element-wise loss shape first, and divided by the element count
        for ``'mean'``.
        """
        input, target = ctx.saved_tensors
        num_classes = ctx.num_classes
        gamma = ctx.gamma
        alpha = ctx.alpha
        if ctx.reduction_enum == 1:
            d_loss = d_loss.expand(ctx.loss_shape) / ctx.loss_numel
        elif ctx.reduction_enum == 2:
            d_loss = d_loss.expand(ctx.loss_shape)
        # expand() returns a stride-0 view; the kernel needs dense memory.
        d_loss = d_loss.contiguous()
        d_input = sigmoid_focal_loss_cuda.backward(input, target, d_loss,
                                                   num_classes, gamma, alpha)
        # One gradient per forward argument: input, target, gamma, alpha, reduction.
        return d_input, None, None, None, None
Code from mmdetection
Related
- Focal Loss in Object Detection: PyTorch Implementation(with CUDA)
- Deformable Convolution in Object Detection: PyTorch Implementation(with CUDA)
- (Soft)NMS in Object Detection: PyTorch Implementation(with CUDA)
- FPN for Object Detection: PyTorch Implementation
- RoIPooling in Object Detection: PyTorch Implementation(with CUDA)
- From Classification to Panoptic Segmentation: 7 years of Visual Understanding with Deep Learning
- Convolutional Neural Network Must Reads: Xception, ShuffleNet, ResNeXt and DenseNet
- Object Detection Must Reads(1): Fast RCNN, Faster RCNN, R-FCN and FPN
- Object Detection Must Reads(2): YOLO, YOLO9000, and RetinaNet
- Object Detection Must Reads(3): SNIP, SNIPER, OHEM, and DSOD
- Anchor-Free Object Detection(Part 1): CornerNet, CornerNet-Lite, ExtremeNet, CenterNet
- Anchor-Free Object Detection(Part 2): FSAF, FoveaBox, FCOS, RepPoints