Focal Loss in Object Detection: PyTorch Implementation (with CUDA)

 

Focal Loss for Object Detection: Idea

Focal Loss PyTorch

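  • The loss function is reshaped to down-weight easy examples and thus focus training on hard negatives. With $p_t = p$ for the ground-truth class and $p_t = 1 - p$ otherwise, a modulating factor $(1 - p_t)^{\gamma}$ is added to the cross-entropy loss, giving $\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$ (the implementation below uses the α-balanced form $-\alpha_t (1 - p_t)^{\gamma} \log(p_t)$). Values of $\gamma \in [0, 5]$ were tested in the experiments.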
  • The FL has two properties:
  1. When an example is misclassified and $p_t$ is small, the modulating factor is near 1 and the loss is unaffected. As $p_t \to 1$, the factor goes to 0 and the loss for well-classified examples is down-weighted.
  2. The focusing parameter $\gamma$ smoothly adjusts the rate at which easy examples are down-weighted. When $\gamma = 0$, FL is equivalent to CE; as $\gamma$ increases, so does the effect of the modulating factor ($\gamma = 2$ worked best in the experiments).
  • For instance, with $\gamma = 2$, an example classified with $p_t = 0.9$ has 100× lower loss than with CE, and with $p_t \approx 0.968$ it has about 1000× lower loss. This in turn increases the relative importance of correcting misclassified examples.
  • For $p_t \le 0.5$ and $\gamma = 2$, the loss is scaled down by at most 4×.

Focal Loss: CUDA Kernel

// Sigmoid focal loss, forward pass: one loss value per (sample, class) element.
//
// Element i maps to sample n = i / num_classes and class d = i % num_classes,
// so `logits` and `losses` are indexed as row-major [num_samples, num_classes]
// buffers (assumed contiguous -- TODO confirm at the launch site). `nthreads`
// is the number of elements to process; iteration over i is provided by
// CUDA_1D_KERNEL_LOOP, which is defined elsewhere (presumably a grid-stride
// loop over [0, nthreads)).
//
// targets[n] is the label t of sample n:
//   t == d + 1         -> positive: loss = -alpha * (1-p)^gamma * log(p)
//   t >= 0, otherwise  -> negative: loss = -(1-alpha) * p^gamma * log(1-p)
//   t <  0             -> ignored:  loss = 0
// so t == 0 makes every class of the sample a negative (background).
// Every losses[i] for i in [0, nthreads) is written. `num` is unused.
template <typename scalar_t>
__global__ void SigmoidFocalLossForward(const int nthreads,
                                        const scalar_t *logits,
                                        const long *targets,
                                        const int num_classes,
                                        const float gamma, const float alpha,
                                        const int num, scalar_t *losses) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    const int n = i / num_classes;               // sample index
    const int d = i % num_classes;               // 0-based class index
    const int t = static_cast<int>(targets[n]);  // 1-based label, 0 = background

    // Exactly one of these holds when t >= 0; neither when t < 0 (ignored).
    const bool is_pos = (t == d + 1);
    const bool is_neg = (t >= 0) && !is_pos;

    const scalar_t x = logits[i];
    // Float literals keep float inputs in single precision; bare `1.` literals
    // would promote every operation to double.
    const scalar_t p = 1.0f / (1.0f + expf(-x));  // sigmoid(x)
    const scalar_t one_minus_p = 1.0f - p;

    // Numerically stable log-sigmoid terms, with tail = log(1 + e^{-|x|}):
    //   log(1 - p) = -max(x, 0) - tail
    //   log(p)     =  min(x, 0) - tail
    // Deriving log(p) from x (instead of logf(p)) stays accurate when
    // sigmoid(x) underflows for very negative logits.
    const scalar_t relu_x = (x >= 0) ? x : scalar_t(0);
    const scalar_t tail = log1pf(expf(-fabsf(x)));
    const scalar_t log_1mp = -relu_x - tail;
    const scalar_t log_p = (x - relu_x) - tail;  // x - max(x, 0) == min(x, 0)

    // Only the term selected by the label is evaluated; the result is kept
    // in a register and stored to global memory once.
    scalar_t loss = scalar_t(0);
    if (is_pos) {
      loss = -alpha * powf(one_minus_p, gamma) * log_p;
    } else if (is_neg) {
      loss = -(1.0f - alpha) * powf(p, gamma) * log_1mp;
    }
    losses[i] = loss;
  }  // CUDA_1D_KERNEL_LOOP
}  // SigmoidFocalLossForward

// Sigmoid focal loss, backward pass:
//   d_logits[i] = d(loss_i)/d(logits[i]) * d_losses[i]
//
// Element i maps to sample n = i / num_classes and class d = i % num_classes;
// `logits`, `d_losses` and `d_logits` are all indexed per element, so
// `d_losses` must hold one upstream gradient per logit (a reduced/scalar
// gradient is NOT supported). Iteration over i is provided by
// CUDA_1D_KERNEL_LOOP, defined elsewhere (presumably a grid-stride loop over
// [0, nthreads)).
//
// Label convention for t = targets[n]: t == d + 1 -> positive; t >= 0
// otherwise -> negative; t < 0 -> ignored (gradient 0 times d_losses[i]).
// Every d_logits[i] for i in [0, nthreads) is written. `num` is unused.
template <typename scalar_t>
__global__ void SigmoidFocalLossBackward(
    const int nthreads, const scalar_t *logits, const long *targets,
    const scalar_t *d_losses, const int num_classes, const float gamma,
    const float alpha, const int num, scalar_t *d_logits) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    const int n = i / num_classes;               // sample index
    const int d = i % num_classes;               // 0-based class index
    const int t = static_cast<int>(targets[n]);  // 1-based label, 0 = background

    // Exactly one of these holds when t >= 0; neither when t < 0 (ignored).
    const bool is_pos = (t == d + 1);
    const bool is_neg = (t >= 0) && !is_pos;

    const scalar_t x = logits[i];
    // Float literals keep float inputs in single precision; bare `1.` literals
    // would promote every operation to double.
    const scalar_t p = 1.0f / (1.0f + expf(-x));  // sigmoid(x)
    const scalar_t one_minus_p = 1.0f - p;

    // Numerically stable log-sigmoid terms, with tail = log(1 + e^{-|x|}):
    //   log(1 - p) = -max(x, 0) - tail
    //   log(p)     =  min(x, 0) - tail
    const scalar_t relu_x = (x >= 0) ? x : scalar_t(0);
    const scalar_t tail = log1pf(expf(-fabsf(x)));
    const scalar_t log_1mp = -relu_x - tail;
    const scalar_t log_p = (x - relu_x) - tail;  // x - max(x, 0) == min(x, 0)

    scalar_t grad = scalar_t(0);
    if (is_pos) {
      // d/dx[-alpha (1-p)^g log p] = -alpha (1-p)^g (1 - p - g p log p)
      grad = -alpha * powf(one_minus_p, gamma) *
             (one_minus_p - gamma * p * log_p);
    } else if (is_neg) {
      // d/dx[-(1-alpha) p^g log(1-p)] = -(1-alpha) p^g (g (1-p) log(1-p) - p)
      grad = -(1.0f - alpha) * powf(p, gamma) *
             (gamma * one_minus_p * log_1mp - p);
    }
    // Single global store, chained with the per-element upstream gradient.
    d_logits[i] = grad * d_losses[i];
  }  // CUDA_1D_KERNEL_LOOP
}  // SigmoidFocalLossBackward

Focal Loss: PyTorch Wrapper


class SigmoidFocalLossFunction(Function):
    """Autograd function for the sigmoid focal loss backed by ``sigmoid_focal_loss_cuda``.

    ``forward`` evaluates the per-element loss with the extension and applies
    the requested reduction. ``backward`` first undoes that reduction, turning
    the incoming gradient back into one value per element of ``input``. It
    then calls the extension's backward, which consumes a per-element
    upstream gradient.
    """

    @staticmethod
    def forward(ctx, input, target, gamma=2.0, alpha=0.25, reduction='mean'):
        """Compute the (optionally reduced) sigmoid focal loss.

        Args:
            ctx: autograd context; stores ``input``/``target`` and the
                hyper-parameters needed by ``backward``.
            input (Tensor): logits; ``input.shape[1]`` is used as the number
                of classes.
            target (Tensor): one label per row of ``input``, passed to the
                extension unchanged.
            gamma (float): focusing parameter, passed to the extension.
            alpha (float): balancing weight, passed to the extension.
            reduction (str): ``'none'``, ``'mean'`` or ``'sum'``.

        Returns:
            Tensor: the per-element loss for ``'none'``; otherwise a 0-dim
            tensor holding its mean or sum.

        Raises:
            ValueError: if ``reduction`` is not a valid reduction name.
        """
        ctx.save_for_backward(input, target)
        num_classes = input.shape[1]
        ctx.num_classes = num_classes
        ctx.gamma = gamma
        ctx.alpha = alpha

        loss = sigmoid_focal_loss_cuda.forward(input, target, num_classes,
                                               gamma, alpha)
        reduction_enum = F._Reduction.get_enum(reduction)
        # backward needs to know how the per-element loss was reduced so it
        # can broadcast the gradient back to one value per element.
        ctx.reduction_enum = reduction_enum
        # none: 0, mean: 1, sum: 2
        if reduction_enum == 0:
            return loss
        if reduction_enum == 1:
            return loss.mean()
        if reduction_enum == 2:
            return loss.sum()
        raise ValueError(
            '{} is not a valid value for reduction'.format(reduction))

    @staticmethod
    @once_differentiable
    def backward(ctx, d_loss):
        """Gradient w.r.t. ``input``; the remaining inputs get ``None``.

        Args:
            ctx: autograd context populated by ``forward``.
            d_loss (Tensor): upstream gradient shaped like ``forward``'s
                output (0-dim when a reduction was applied).

        Returns:
            tuple: ``(d_input, None, None, None, None)``, one entry per
            ``forward`` argument after ``ctx``.
        """
        input, target = ctx.saved_tensors
        num_classes = ctx.num_classes
        gamma = ctx.gamma
        alpha = ctx.alpha
        # The CUDA backward (SigmoidFocalLossBackward) reads d_losses[i] for
        # every logit, so a reduced 0-dim gradient must be expanded back to
        # one value per element. Passing it as-is would read far past its
        # single element. Assumes the unreduced loss has the same shape as
        # ``input`` -- TODO confirm in the extension.
        if ctx.reduction_enum == 1:  # mean: each element contributed 1/N
            d_loss = d_loss.expand_as(input) / input.numel()
        elif ctx.reduction_enum == 2:  # sum: each element contributed 1
            d_loss = d_loss.expand_as(input)
        # expand_as yields a stride-0 view; the extension expects dense memory.
        d_loss = d_loss.contiguous()
        d_input = sigmoid_focal_loss_cuda.backward(input, target, d_loss,
                                                   num_classes, gamma, alpha)
        return d_input, None, None, None, None

Code from mmdetection