|
|
|
|
@ -297,7 +297,7 @@ module Templates
|
|
|
|
|
[img_array.reshape(1, 3, resolution, resolution), transform_info]
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
def nms(boxes, scores, iou_threshold = 0.5)
|
|
|
|
|
def nms(boxes, scores, iou_threshold = 0.5, containment_threshold = 0.7)
|
|
|
|
|
return Numo::Int32[] if boxes.shape[0].zero?
|
|
|
|
|
|
|
|
|
|
x1 = boxes[true, 0]
|
|
|
|
|
@ -328,7 +328,11 @@ module Templates
|
|
|
|
|
|
|
|
|
|
iou = intersection / (areas[i] + areas[order[1..]] - intersection)
|
|
|
|
|
|
|
|
|
|
inds = iou.le(iou_threshold).where
|
|
|
|
|
other_areas = areas[order[1..]]
|
|
|
|
|
containment = intersection / (other_areas + 1e-6)
|
|
|
|
|
|
|
|
|
|
suppress_mask = iou.gt(iou_threshold) | containment.gt(containment_threshold)
|
|
|
|
|
inds = suppress_mask.eq(0).where
|
|
|
|
|
|
|
|
|
|
order = order[inds + 1]
|
|
|
|
|
end
|
|
|
|
|
|