diff --git a/lib/templates/image_to_fields.rb b/lib/templates/image_to_fields.rb index f2bb7244..c8d86473 100755 --- a/lib/templates/image_to_fields.rb +++ b/lib/templates/image_to_fields.rb @@ -297,7 +297,7 @@ module Templates [img_array.reshape(1, 3, resolution, resolution), transform_info] end - def nms(boxes, scores, iou_threshold = 0.5) + def nms(boxes, scores, iou_threshold = 0.5, containment_threshold = 0.7) return Numo::Int32[] if boxes.shape[0].zero? x1 = boxes[true, 0] @@ -328,7 +328,11 @@ module Templates iou = intersection / (areas[i] + areas[order[1..]] - intersection) - inds = iou.le(iou_threshold).where + other_areas = areas[order[1..]] + containment = intersection / (other_areas + 1e-6) + + suppress_mask = iou.gt(iou_threshold) | containment.gt(containment_threshold) + inds = suppress_mask.eq(0).where order = order[inds + 1] end