pull/572/head
Pete Matsyburka 2 months ago
parent b3eb8a70bf
commit 4375459856

@ -48,6 +48,9 @@ Style/MultipleComparison:
Style/NumericPredicate:
Enabled: false
Style/MinMaxComparison:
Enabled: false
Naming/PredicateMethod:
Enabled: false

@ -60,24 +60,24 @@ module Templates
# rubocop:disable Metrics, Style
def call(io, attachment: nil, confidence: 0.3, temperature: 1, inference: Templates::ImageToFields, nms: 0.1,
split_page: false, aspect_ratio: true, padding: 20, regexp_type: true, &)
nmm: 0.5, split_page: false, aspect_ratio: true, padding: 20, regexp_type: true, &)
fields, head_node =
if attachment&.image?
process_image_attachment(io, attachment:, confidence:, nms:, split_page:, inference:,
process_image_attachment(io, attachment:, confidence:, nms:, nmm:, split_page:, inference:,
temperature:, aspect_ratio:, padding:, &)
else
process_pdf_attachment(io, attachment:, confidence:, nms:, split_page:, inference:,
process_pdf_attachment(io, attachment:, confidence:, nms:, nmm:, split_page:, inference:,
temperature:, aspect_ratio:, regexp_type:, padding:, &)
end
[fields, head_node]
end
def process_image_attachment(io, attachment:, confidence:, nms:, temperature:, inference:,
def process_image_attachment(io, attachment:, confidence:, nms:, nmm:, temperature:, inference:,
split_page: false, aspect_ratio: false, padding: nil)
image = Vips::Image.new_from_buffer(io.read, '')
fields = inference.call(image, confidence:, nms:, split_page:,
fields = inference.call(image, confidence:, nms:, nmm:, split_page:,
temperature:, aspect_ratio:, padding:)
fields = sort_fields(fields, y_threshold: 10.0 / image.height)
@ -104,7 +104,7 @@ module Templates
[fields, nil]
end
def process_pdf_attachment(io, attachment:, confidence:, nms:, temperature:, inference:,
def process_pdf_attachment(io, attachment:, confidence:, nms:, nmm:, temperature:, inference:,
split_page: false, aspect_ratio: false, padding: nil, regexp_type: false)
doc = Pdfium::Document.open_bytes(io.read)
@ -121,7 +121,7 @@ module Templates
image = Vips::Image.new_from_memory(data, width, height, 4, :uchar)
fields = inference.call(image, confidence: confidence / 3.0, nms:, split_page:,
fields = inference.call(image, confidence: confidence / 3.0, nms:, nmm:, split_page:,
temperature:, aspect_ratio:, padding:)
text_fields = extract_text_fields_from_page(page)

@ -26,7 +26,7 @@ module Templates
CPU_THREADS = Etc.nprocessors
# rubocop:disable Metrics
def call(image, confidence: 0.3, nms: 0.1, temperature: 1,
def call(image, confidence: 0.3, nms: 0.1, nmm: 0.9, temperature: 1,
split_page: false, aspect_ratio: true, padding: nil, resolution: self.resolution)
image = image.extract_band(0, n: 3) if image.bands > 3
@ -68,7 +68,7 @@ module Templates
detections = postprocess_outputs(boxes, logits, transform_info, confidence:, temperature:, resolution:)
end
detections = apply_nms(detections, nms)
detections = apply_nms_nmm(detections, nms_threshold: nms, nmm_threshold: nmm)
build_fields_from_detections(detections, image)
end
@ -297,7 +297,10 @@ module Templates
[img_array.reshape(1, 3, resolution, resolution), transform_info]
end
def nms(boxes, scores, iou_threshold = 0.5, containment_threshold = 0.7)
def nms(detections, iou_threshold = 0.5)
boxes = detections[:xyxy]
scores = detections[:confidence]
return Numo::Int32[] if boxes.shape[0].zero?
x1 = boxes[true, 0]
@ -328,16 +331,78 @@ module Templates
iou = intersection / (areas[i] + areas[order[1..]] - intersection)
other_areas = areas[order[1..]]
containment = intersection / (other_areas + 1e-6)
suppress_mask = iou.gt(iou_threshold) | containment.gt(containment_threshold)
inds = suppress_mask.eq(0).where
inds = iou.le(iou_threshold).where
order = order[inds + 1]
end
Numo::Int32.cast(keep)
{
xyxy: detections[:xyxy][keep, true],
confidence: detections[:confidence][keep],
class_id: detections[:class_id][keep]
}
end
# Merges detections that are largely covered by a bigger detection of the same class
# (nmm: presumably "non-maximum merging" - confirm naming with the author).
#
# Boxes are visited from largest to smallest area. Each visited box is kept. If its score
# is above +confidence+, every remaining box of the same class whose intersection with it
# covers more than +overlap_threshold+ of that remaining box's own area is absorbed: the
# kept box takes the highest score of the group and is stretched to the group's outer
# bounds, and the absorbed boxes are dropped. A kept box whose score is not above
# +confidence+ absorbs nothing.
#
# detections        - Hash with :xyxy (one row per box, columns x1, y1, x2, y2),
#                     :confidence and :class_id arrays.
# overlap_threshold - minimum covered fraction of the smaller box required to merge.
# confidence:       - score a kept box must exceed before it may absorb others.
#
# Returns a Hash with the same three keys restricted to the kept boxes, or +detections+
# itself when it holds no boxes.
#
# NOTE(review): merged scores are written into detections[:confidence] in place. Merged
# coordinates are written through the column slices below, so the returned :xyxy only
# reflects them if those slices are views onto detections[:xyxy] - confirm against
# Numo::NArray slicing semantics.
def nmm(detections, overlap_threshold = 0.9, confidence: 0.3)
  coords = detections[:xyxy]
  scores = detections[:confidence]
  class_ids = detections[:class_id]

  return detections if coords.shape[0].zero?

  left = coords[true, 0]
  top = coords[true, 1]
  right = coords[true, 2]
  bottom = coords[true, 3]

  # Areas are computed once up front; they are only read for boxes still waiting in the queue.
  areas = (right - left) * (bottom - top)

  queue = areas.sort_index.reverse
  kept = []

  until queue.size.zero?
    pivot = queue[0]
    kept.push(pivot)

    break if queue.size == 1

    rest = queue[1..]

    inter_left = Numo::SFloat.maximum(left[pivot], left[rest])
    inter_top = Numo::SFloat.maximum(top[pivot], top[rest])
    inter_right = Numo::SFloat.minimum(right[pivot], right[rest])
    inter_bottom = Numo::SFloat.minimum(bottom[pivot], bottom[rest])

    inter_w = Numo::SFloat.maximum(0.0, inter_right - inter_left)
    inter_h = Numo::SFloat.maximum(0.0, inter_bottom - inter_top)

    # Fraction of each remaining box that lies inside the pivot box.
    overlap = (inter_w * inter_h) / areas[rest]

    if scores[pivot] > confidence
      merge_mask = overlap.gt(overlap_threshold) & class_ids[rest].eq(class_ids[pivot])
      merge_inds = merge_mask.where

      if merge_inds.size.positive?
        # Mask positions are relative to `rest`, hence the +1 to index into `queue`.
        absorbed = queue[merge_inds + 1]

        scores[pivot] = [scores[pivot], scores[absorbed].max].max
        left[pivot] = [left[pivot], left[absorbed].min].min
        top[pivot] = [top[pivot], top[absorbed].min].min
        right[pivot] = [right[pivot], right[absorbed].max].max
        bottom[pivot] = [bottom[pivot], bottom[absorbed].max].max
      end

      # Continue with the boxes that were not absorbed by the pivot.
      survivors = (~merge_mask).where
      queue = queue[survivors + 1]
    else
      # Low-score pivot: it is kept as-is and every other box stays in the queue.
      queue = rest
    end
  end

  {
    xyxy: detections[:xyxy][kept, true],
    confidence: detections[:confidence][kept],
    class_id: detections[:class_id][kept]
  }
end
def postprocess_outputs(boxes, logits, transform_info, detections = nil, confidence: 0.3, temperature: 1,
@ -433,16 +498,12 @@ module Templates
end
end
def apply_nms(detections, threshold = 0.5)
def apply_nms_nmm(detections, nms_threshold: 0.5, nmm_threshold: 0.7, confidence: 0.3)
return detections if detections[:xyxy].shape[0].zero?
keep_indices = nms(detections[:xyxy], detections[:confidence], threshold)
nms_result = nms(detections, nms_threshold)
{
xyxy: detections[:xyxy][keep_indices, true],
confidence: detections[:confidence][keep_indices],
class_id: detections[:class_id][keep_indices]
}
nmm(nms_result, nmm_threshold, confidence:)
end
def model

Loading…
Cancel
Save