adjust detection

pull/402/head
Pete Matsyburka 2 weeks ago
parent e6280e8f5a
commit c1a5c91299

@ -49,7 +49,15 @@ class Pdfium
end end
end end
LineNode = Struct.new(:x, :y, :w, :h, :tilt) LineNode = Struct.new(:x, :y, :w, :h, :tilt) do
def endy
@endy ||= y + h
end
def endx
@endx ||= x + w
end
end
# rubocop:disable Naming/ClassAndModuleCamelCase # rubocop:disable Naming/ClassAndModuleCamelCase
class FPDF_LIBRARY_CONFIG < FFI::Struct class FPDF_LIBRARY_CONFIG < FFI::Struct
@ -480,7 +488,11 @@ class Pdfium
@text_nodes << TextNode.new(char, x, y, node_width, node_height) @text_nodes << TextNode.new(char, x, y, node_width, node_height)
end end
@text_nodes = @text_nodes.sort { |a, b| a.y == b.y ? a.x <=> b.x : a.y <=> b.y } y_threshold = 4.0 / width
@text_nodes = @text_nodes.sort do |a, b|
(a.endy - b.endy).abs < y_threshold ? a.x <=> b.x : a.endy <=> b.endy
end
ensure ensure
Pdfium.FPDFText_ClosePage(text_page) if text_page && !text_page.null? Pdfium.FPDFText_ClosePage(text_page) if text_page && !text_page.null?
end end
@ -549,7 +561,7 @@ class Pdfium
@line_nodes << LineNode.new(norm_x, norm_y, norm_w, norm_h, tilt) @line_nodes << LineNode.new(norm_x, norm_y, norm_w, norm_h, tilt)
end end
@line_nodes = @line_nodes.sort { |a, b| a.y == b.y ? a.x <=> b.x : a.y <=> b.y } @line_nodes = @line_nodes.sort { |a, b| a.endy == b.endy ? a.x <=> b.x : a.endy <=> b.endy }
end end
def close def close

@ -4,7 +4,16 @@ module Templates
module DetectFields module DetectFields
module_function module_function
TextFieldBox = Struct.new(:x, :y, :w, :h, keyword_init: true) TextFieldBox = Struct.new(:x, :y, :w, :h, keyword_init: true) do
def endy
@endy ||= y + h
end
def endx
@endx ||= x + w
end
end
PageNode = Struct.new(:prev, :next, :elem, :page, :attachment_uuid, keyword_init: true) PageNode = Struct.new(:prev, :next, :elem, :page, :attachment_uuid, keyword_init: true)
DATE_REGEXP = / DATE_REGEXP = /
@ -49,6 +58,9 @@ module Templates
\s*[:-]?\s*\z \s*[:-]?\s*\z
/ix /ix
LINEBREAK = ["\n", "\r"].freeze
CHECBOXES = ['☐', '□'].freeze
# rubocop:disable Metrics, Style # rubocop:disable Metrics, Style
def call(io, attachment: nil, confidence: 0.3, temperature: 1, inference: Templates::ImageToFields, def call(io, attachment: nil, confidence: 0.3, temperature: 1, inference: Templates::ImageToFields,
nms: 0.1, split_page: false, aspect_ratio: true, padding: 20, regexp_type: true, &) nms: 0.1, split_page: false, aspect_ratio: true, padding: 20, regexp_type: true, &)
@ -71,6 +83,8 @@ module Templates
fields = inference.call(image, confidence:, nms:, split_page:, fields = inference.call(image, confidence:, nms:, split_page:,
temperature:, aspect_ratio:, padding:) temperature:, aspect_ratio:, padding:)
fields = sort_fields(fields, y_threshold: 10.0 / image.height)
fields = fields.map do |f| fields = fields.map do |f|
{ {
uuid: SecureRandom.uuid, uuid: SecureRandom.uuid,
@ -113,6 +127,8 @@ module Templates
text_fields = extract_text_fields_from_page(page) text_fields = extract_text_fields_from_page(page)
line_fields = extract_line_fields_from_page(page) line_fields = extract_line_fields_from_page(page)
fields = sort_fields(fields, y_threshold: 10.0 / page.height)
fields = increase_confidence_for_overlapping_fields(fields, text_fields) fields = increase_confidence_for_overlapping_fields(fields, text_fields)
fields = increase_confidence_for_overlapping_fields(fields, line_fields) fields = increase_confidence_for_overlapping_fields(fields, line_fields)
@ -153,6 +169,12 @@ module Templates
doc.close doc.close
end end
def sort_fields(fields, y_threshold: 0.01)
fields.sort do |a, b|
(a.endy - b.endy).abs < y_threshold ? a.x <=> b.x : a.endy <=> b.endy
end
end
def print_debug(head_node) def print_debug(head_node)
current_node = head_node current_node = head_node
index = 0 index = 0
@ -189,121 +211,68 @@ module Templates
def build_page_nodes(page, fields, tail_node, attachment_uuid: nil) def build_page_nodes(page, fields, tail_node, attachment_uuid: nil)
field_nodes = [] field_nodes = []
current_text = ''.b
text_nodes = page.text_nodes
text_idx = 0 y_theshold = 10.0 / page.height
field_idx = 0
while text_idx < text_nodes.length || field_idx < fields.length text_nodes = page.text_nodes
text_node = text_nodes[text_idx]
field = fields[field_idx]
process_text_node = false current_field = fields.shift
process_field_node = false
if text_node && field index = 0
text_y_center = text_node.y + (text_node.h / 2.0)
field_y_center = field.y + (field.h / 2.0)
y_threshold = text_node.h / 2.0
vertical_distance = (text_y_center - field_y_center).abs
if vertical_distance < y_threshold prev_node = nil
is_underscore = text_node.content == '_'
is_left_of_field = text_node.x < field.x
if is_underscore && is_left_of_field loop do
text_x_end = text_node.x + text_node.w node = text_nodes[index]
distance = field.x - text_x_end break unless node
proximity_threshold = text_node.w * 3.0
if distance < proximity_threshold loop do
process_field_node = true break unless current_field
else
process_text_node = true
end
elsif is_left_of_field if ((current_field.endy - node.endy).abs < y_theshold &&
process_text_node = true (current_field.x <= node.x || node.content.in?(LINEBREAK))) ||
else current_field.endy < node.y
process_field_node = true field_node = PageNode.new(prev: tail_node, elem: current_field, page: page.page_index, attachment_uuid:)
end tail_node.next = field_node
tail_node = field_node
field_nodes << tail_node
elsif text_node.y < field.y current_field = fields.shift
process_text_node = true
else else
process_field_node = true break
end end
elsif text_node
process_text_node = true
elsif field
process_field_node = true
end end
if process_field_node if tail_node.elem.is_a?(Templates::ImageToFields::Field)
unless current_text.empty? text_node = PageNode.new(prev: tail_node, elem: ''.b, page: page.page_index, attachment_uuid:)
new_text_node = PageNode.new(prev: tail_node, elem: current_text, page: page.page_index, attachment_uuid:) tail_node.next = text_node
tail_node.next = new_text_node
tail_node = new_text_node
current_text = ''.b
end
new_field_node = PageNode.new(prev: tail_node, elem: field, page: page.page_index, attachment_uuid:) tail_node = text_node
tail_node.next = new_field_node end
tail_node = new_field_node
field_nodes << tail_node
while text_idx < text_nodes.length
text_node_to_check = text_nodes[text_idx]
is_part_of_field = false
if text_node_to_check.content == '_'
check_y_center = text_node_to_check.y + (text_node_to_check.h / 2.0)
check_y_dist = (check_y_center - field_y_center).abs
check_y_thresh = text_node_to_check.h / 2.0
if check_y_dist < check_y_thresh
padding = text_node_to_check.w * 3.0
field_x_start = field.x - padding
field_x_end = field.x + field.w + padding
text_x_start = text_node_to_check.x
text_x_end = text_node_to_check.x + text_node_to_check.w
is_part_of_field = true if text_x_start <= field_x_end && field_x_start <= text_x_end
end
end
break unless is_part_of_field
text_idx += 1 if prev_node && (node.endy - prev_node.endy) > y_theshold && LINEBREAK.exclude?(prev_node.content)
end tail_node.elem << "\n"
end
field_idx += 1 if node.content != '_' || !tail_node.elem.ends_with?('___')
elsif process_text_node tail_node.elem << node.content unless CHECBOXES.include?(node.content)
if text_idx > 0 end
prev_text_node = text_nodes[text_idx - 1]
x_gap = text_node.x - (prev_text_node.x + prev_text_node.w) prev_node = node
gap_w = text_node.w > prev_text_node.w ? text_node.w : prev_text_node.w index += 1
end
current_text << ' ' if x_gap > gap_w * 2 loop do
end break unless current_field
current_text << text_node.content field_node = PageNode.new(prev: tail_node, elem: current_field, page: page.page_index, attachment_uuid:)
text_idx += 1 tail_node.next = field_node
end tail_node = field_node
end field_nodes << tail_node
unless current_text.empty? current_field = fields.shift
new_text_node = PageNode.new(prev: tail_node, elem: current_text, page: page.page_index, attachment_uuid:)
tail_node.next = new_text_node
tail_node = new_text_node
end end
[field_nodes, tail_node] [field_nodes, tail_node]
@ -399,8 +368,8 @@ module Templates
x1 = node.x x1 = node.x
y1 = node.y y1 = node.y
x2 = node.x + node.w x2 = node.endx
y2 = node.y + node.h y2 = node.endy
underscore_count = 1 underscore_count = 1
@ -417,8 +386,9 @@ module Templates
break if distance > 0.02 || height_diff > node.h * 0.5 break if distance > 0.02 || height_diff > node.h * 0.5
underscore_count += 1 underscore_count += 1
next_x2 = next_node.x + next_node.w
next_y2 = next_node.y + next_node.h next_x2 = next_node.endx
next_y2 = next_node.endy
x2 = next_x2 x2 = next_x2
y2 = [y2, next_y2].max y2 = [y2, next_y2].max
@ -438,8 +408,8 @@ module Templates
def calculate_iou(box1, box2) def calculate_iou(box1, box2)
x1 = [box1.x, box2.x].max x1 = [box1.x, box2.x].max
y1 = [box1.y, box2.y].max y1 = [box1.y, box2.y].max
x2 = [box1.x + box1.w, box2.x + box2.w].min x2 = [box1.endx, box2.endx].min
y2 = [box1.y + box1.h, box2.y + box2.h].min y2 = [box1.endy, box2.endy].min
intersection_width = [0, x2 - x1].max intersection_width = [0, x2 - x1].max
intersection_height = [0, y2 - y1].max intersection_height = [0, y2 - y1].max
@ -455,8 +425,7 @@ module Templates
end end
def boxes_overlap?(box1, box2) def boxes_overlap?(box1, box2)
!(box1.x + box1.w < box2.x || box2.x + box2.w < box1.x || !(box1.endx < box2.x || box2.endx < box1.x || box1.endy < box2.y || box2.endy < box1.y)
box1.y + box1.h < box2.y || box2.y + box2.h < box1.y)
end end
def increase_confidence_for_overlapping_fields(image_fields, text_fields, by: 1.0) def increase_confidence_for_overlapping_fields(image_fields, text_fields, by: 1.0)
@ -465,12 +434,10 @@ module Templates
image_fields.map do |image_field| image_fields.map do |image_field|
next if image_field.type != 'text' next if image_field.type != 'text'
field_bottom = image_field.y + image_field.h
text_fields.each do |text_field| text_fields.each do |text_field|
break if text_field.y > field_bottom break if text_field.y > image_field.endy
next if text_field.y + text_field.h < image_field.y next if text_field.endy < image_field.y
next unless boxes_overlap?(image_field, text_field) next unless boxes_overlap?(image_field, text_field)
next if calculate_iou(image_field, text_field) < 0.4 next if calculate_iou(image_field, text_field) < 0.4

@ -8,6 +8,10 @@ module Templates
def endy def endy
@endy ||= y + h @endy ||= y + h
end end
def endx
@endx ||= x + w
end
end end
MODEL_PATH = Rails.root.join('tmp/model.onnx') MODEL_PATH = Rails.root.join('tmp/model.onnx')
@ -64,9 +68,7 @@ module Templates
detections = apply_nms(detections, nms) detections = apply_nms(detections, nms)
fields = build_fields_from_detections(detections, image) build_fields_from_detections(detections, image)
sort_fields(fields, y_threshold: 10.0 / image.height)
end end
def build_split_image_regions(image) def build_split_image_regions(image)
@ -302,27 +304,6 @@ module Templates
end end
end end
def sort_fields(fields, y_threshold: 0.01)
sorted_fields = fields.sort { |a, b| a.endy == b.endy ? a.x <=> b.x : a.endy <=> b.endy }
lines = []
current_line = []
sorted_fields.each do |field|
if current_line.blank? || (field.y - current_line.first.y).abs < y_threshold
current_line << field
else
lines << current_line.sort_by(&:x)
current_line = [field]
end
end
lines << current_line.sort_by(&:x) if current_line.present?
lines.flatten
end
def apply_nms(detections, threshold = 0.5) def apply_nms(detections, threshold = 0.5)
return detections if detections[:xyxy].shape[0].zero? return detections if detections[:xyxy].shape[0].zero?

Loading…
Cancel
Save