RoIPooling in Object Detection: PyTorch Implementation (with CUDA)

 

RoIPooling Explanation

Region of interest pooling (also known as RoI pooling) is an operation widely used in object detection tasks with convolutional neural networks, for example to detect multiple cars and pedestrians in a single image. Its purpose is to perform max pooling on inputs of nonuniform sizes to obtain fixed-size feature maps (e.g. 7×7).

The layer takes two inputs:

  1. A fixed-size feature map obtained from a deep convolutional network with several convolutions and max pooling layers.
  2. An N × 5 matrix representing a list of regions of interest, where N is the number of RoIs. The first column is the image index and the remaining four are the coordinates of the top-left and bottom-right corners of the region.

For every region of interest from the input list, it takes a section of the input feature map that corresponds to it and scales it to some pre-defined size (e.g., 7×7). The scaling is done by:

  1. Dividing the region proposal into equal-sized sections (the number of which is the same as the dimension of the output)
  2. Finding the largest value in each section
  3. Copying these max values to the output buffer

RoIPooling Example

The main idea of RoIPooling:

RoIPooling

RoIPooling CUDA Kernel

CUDA KERNEL Forward:

  // Forward RoI max pooling. Every iteration of this loop produces one
  // element of the pooled output: the maximum of the input feature map over
  // the bin that the element covers, plus (optionally) the offset of that
  // maximum inside its feature plane for use by the backward pass.
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_w;
    int ph = (index / pooled_w) % pooled_h;
    int c = (index / pooled_w / pooled_h) % channels;
    int n = index / pooled_w / pooled_h / channels;

    // each roi occupies 5 consecutive values; the first one is read as the
    // batch index, the other four as x1, y1, x2, y2
    const scalar_t *offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    // calculate the roi region on feature maps
    // (the +1 presumably treats x2 / y2 as inclusive coordinates)
    scalar_t roi_x1 = offset_rois[1] * spatial_scale;
    scalar_t roi_y1 = offset_rois[2] * spatial_scale;
    scalar_t roi_x2 = (offset_rois[3] + 1) * spatial_scale;
    scalar_t roi_y2 = (offset_rois[4] + 1) * spatial_scale;

    // skip malformed rois (non-positive width or height): nothing is written
    // to top_data / argmax_data for them, so those entries keep whatever
    // values the caller initialized the buffers with
    scalar_t roi_w = roi_x2 - roi_x1;
    scalar_t roi_h = roi_y2 - roi_y1;
    if (roi_w <= 0 || roi_h <= 0) continue;

    scalar_t bin_size_w = roi_w / static_cast<scalar_t>(pooled_w);
    scalar_t bin_size_h = roi_h / static_cast<scalar_t>(pooled_h);

    // the corresponding bin region: floor the start and ceil the end so the
    // integer bin covers the whole fractional bin
    int bin_x1 = floor(static_cast<scalar_t>(pw) * bin_size_w + roi_x1);
    int bin_y1 = floor(static_cast<scalar_t>(ph) * bin_size_h + roi_y1);
    int bin_x2 = ceil(static_cast<scalar_t>(pw + 1) * bin_size_w + roi_x1);
    int bin_y2 = ceil(static_cast<scalar_t>(ph + 1) * bin_size_h + roi_y1);

    // clip to input boundaries
    bin_x1 = min(max(bin_x1, 0), width);
    bin_y1 = min(max(bin_y1, 0), height);
    bin_x2 = min(max(bin_x2, 0), width);
    bin_y2 = min(max(bin_y2, 0), height);
    bool is_empty = (bin_y2 <= bin_y1) || (bin_x2 <= bin_x1);

    // Address the (roi_batch_ind, c) feature plane through a loop-local
    // pointer. Advancing bottom_data itself would carry the offset over into
    // every later iteration executed by the same thread and make those
    // iterations read from the wrong plane.
    const scalar_t *offset_bottom_data =
        bottom_data + (roi_batch_ind * channels + c) * height * width;

    // If nothing is pooled, argmax = -1 and the output is defined to be zero.
    // Otherwise start from the first element of the bin (value and offset),
    // so the result does not depend on "value - 1" being strictly smaller
    // than "value", which does not hold for every floating-point value.
    int max_idx = -1;
    scalar_t max_val = static_cast<scalar_t>(0);
    if (!is_empty) {
      max_idx = bin_y1 * width + bin_x1;
      max_val = offset_bottom_data[max_idx];
    }

    // strict '>' keeps the first occurrence of the maximum
    for (int h = bin_y1; h < bin_y2; ++h) {
      for (int w = bin_x1; w < bin_x2; ++w) {
        int offset = h * width + w;
        if (offset_bottom_data[offset] > max_val) {
          max_val = offset_bottom_data[offset];
          max_idx = offset;
        }
      }
    }
    top_data[index] = max_val;
    if (argmax_data != NULL) argmax_data[index] = max_idx;
  }

CUDA KERNEL Backward:

  // Backward RoI max pooling. Every iteration takes the gradient of one
  // pooled output element and adds it to the input-gradient entry that was
  // selected as the maximum (recorded in argmax_data) for that element.
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_w;
    int ph = (index / pooled_w) % pooled_h;
    int c = (index / pooled_w / pooled_h) % channels;
    int n = index / pooled_w / pooled_h / channels;

    // first value of each 5-value roi row is read as the batch index
    int roi_batch_ind = rois[n * 5];
    // offset of the selected maximum inside its height x width plane
    int bottom_index = argmax_data[(n * channels + c) * pooled_h * pooled_w +
                                   ph * pooled_w + pw];

    // A negative argmax marks an element for which nothing was pooled. It has
    // no source entry, and adding at that offset would write outside the
    // (roi_batch_ind, c) plane, so no gradient is propagated for it.
    if (bottom_index < 0) continue;

    // atomicAdd: several output elements may add into the same
    // bottom_diff entry
    atomicAdd(bottom_diff + (roi_batch_ind * channels + c) * height * width +
                  bottom_index,
              top_diff[index]);
  }

RoIPooling PyTorch Wrapper

PyTorch Wrapper:


class RoIPoolFunction(Function):
    """Autograd function for RoI pooling backed by ``roi_pool_cuda``.

    ``forward`` allocates the output and argmax buffers and lets the
    extension fill them; ``backward`` hands the stored argmax back to the
    extension to build the gradient with respect to ``features``. No
    gradient is produced for ``rois``, ``out_size`` or ``spatial_scale``.
    """

    @staticmethod
    def forward(ctx, features, rois, out_size, spatial_scale):
        """Pool every RoI to a fixed spatial size.

        Args:
            ctx: Autograd context; receives ``rois`` via
                ``save_for_backward`` plus ``spatial_scale``,
                ``feature_size`` and ``argmax`` as attributes.
            features: CUDA tensor; its dimension 1 is taken as the number
                of channels.
            rois: Tensor whose dimension 0 is taken as the number of RoIs.
            out_size: Either one int (used for both height and width) or a
                2-tuple of ints ``(height, width)``.
            spatial_scale: Passed through unchanged to the extension.

        Returns:
            Tensor of shape ``(num_rois, num_channels, out_h, out_w)``
            created from ``features`` and filled by the extension.

        Raises:
            TypeError: If ``out_size`` is neither an int nor a tuple.
            AssertionError: If a tuple ``out_size`` does not consist of
                exactly two ints, or ``features`` is not a CUDA tensor.
        """
        # Reject unsupported types up front, then resolve (height, width).
        if not isinstance(out_size, (int, tuple)):
            raise TypeError(
                '"out_size" must be an integer or tuple of integers')
        if isinstance(out_size, int):
            pooled_h = pooled_w = out_size
        else:
            assert len(out_size) == 2
            pooled_h, pooled_w = out_size
            assert isinstance(pooled_h, int)
            assert isinstance(pooled_w, int)

        assert features.is_cuda
        ctx.save_for_backward(rois)

        out_shape = (rois.size(0), features.size(1), pooled_h, pooled_w)
        # Both buffers start zero-filled and are written by the extension.
        pooled = features.new_zeros(out_shape)
        argmax = features.new_zeros(out_shape, dtype=torch.int)
        roi_pool_cuda.forward(features, rois, pooled_h, pooled_w,
                              spatial_scale, pooled, argmax)

        # State needed again in backward().
        ctx.spatial_scale = spatial_scale
        ctx.feature_size = features.size()
        ctx.argmax = argmax
        return pooled

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradient with respect to ``features``.

        Args:
            ctx: The context populated by ``forward``.
            grad_output: CUDA tensor holding the gradient of the pooled
                output.

        Returns:
            A 4-tuple matching the inputs of ``forward``: the gradient for
            ``features`` (``None`` when it is not required) followed by
            ``None`` for ``rois``, ``out_size`` and ``spatial_scale``.

        Raises:
            AssertionError: If ``grad_output`` is not a CUDA tensor or the
                stored feature size is ``None``.
        """
        assert grad_output.is_cuda
        scale = ctx.spatial_scale
        input_shape = ctx.feature_size
        argmax = ctx.argmax
        rois = ctx.saved_tensors[0]
        assert input_shape is not None

        if not ctx.needs_input_grad[0]:
            return None, None, None, None

        # The extension accumulates into this zero-initialised buffer.
        grad_features = grad_output.new_zeros(input_shape)
        roi_pool_cuda.backward(grad_output.contiguous(), rois, argmax,
                               scale, grad_features)
        return grad_features, None, None, None

Code from mmdetection