containment threshold

pull/572/head
Pete Matsyburka 4 weeks ago
parent f1db7c7e82
commit 7df729b8b5

@ -297,7 +297,7 @@ module Templates
[img_array.reshape(1, 3, resolution, resolution), transform_info] [img_array.reshape(1, 3, resolution, resolution), transform_info]
end end
def nms(boxes, scores, iou_threshold = 0.5) def nms(boxes, scores, iou_threshold = 0.5, containment_threshold = 0.7)
return Numo::Int32[] if boxes.shape[0].zero? return Numo::Int32[] if boxes.shape[0].zero?
x1 = boxes[true, 0] x1 = boxes[true, 0]
@ -328,7 +328,11 @@ module Templates
iou = intersection / (areas[i] + areas[order[1..]] - intersection) iou = intersection / (areas[i] + areas[order[1..]] - intersection)
inds = iou.le(iou_threshold).where other_areas = areas[order[1..]]
containment = intersection / (other_areas + 1e-6)
suppress_mask = iou.gt(iou_threshold) | containment.gt(containment_threshold)
inds = suppress_mask.eq(0).where
order = order[inds + 1] order = order[inds + 1]
end end

Loading…
Cancel
Save