adjust detection

pull/402/head
Pete Matsyburka 2 weeks ago
parent e6280e8f5a
commit c1a5c91299

@ -49,7 +49,15 @@ class Pdfium
end
end
LineNode = Struct.new(:x, :y, :w, :h, :tilt)
LineNode = Struct.new(:x, :y, :w, :h, :tilt) do
def endy
@endy ||= y + h
end
def endx
@endx ||= x + w
end
end
# rubocop:disable Naming/ClassAndModuleCamelCase
class FPDF_LIBRARY_CONFIG < FFI::Struct
@ -480,7 +488,11 @@ class Pdfium
@text_nodes << TextNode.new(char, x, y, node_width, node_height)
end
@text_nodes = @text_nodes.sort { |a, b| a.y == b.y ? a.x <=> b.x : a.y <=> b.y }
y_threshold = 4.0 / width
@text_nodes = @text_nodes.sort do |a, b|
(a.endy - b.endy).abs < y_threshold ? a.x <=> b.x : a.endy <=> b.endy
end
ensure
Pdfium.FPDFText_ClosePage(text_page) if text_page && !text_page.null?
end
@ -549,7 +561,7 @@ class Pdfium
@line_nodes << LineNode.new(norm_x, norm_y, norm_w, norm_h, tilt)
end
@line_nodes = @line_nodes.sort { |a, b| a.y == b.y ? a.x <=> b.x : a.y <=> b.y }
@line_nodes = @line_nodes.sort { |a, b| a.endy == b.endy ? a.x <=> b.x : a.endy <=> b.endy }
end
def close

@ -4,7 +4,16 @@ module Templates
module DetectFields
module_function
TextFieldBox = Struct.new(:x, :y, :w, :h, keyword_init: true)
TextFieldBox = Struct.new(:x, :y, :w, :h, keyword_init: true) do
def endy
@endy ||= y + h
end
def endx
@endx ||= x + w
end
end
PageNode = Struct.new(:prev, :next, :elem, :page, :attachment_uuid, keyword_init: true)
DATE_REGEXP = /
@ -49,6 +58,9 @@ module Templates
\s*[:-]?\s*\z
/ix
LINEBREAK = ["\n", "\r"].freeze
CHECBOXES = ['☐', '□'].freeze
# rubocop:disable Metrics, Style
def call(io, attachment: nil, confidence: 0.3, temperature: 1, inference: Templates::ImageToFields,
nms: 0.1, split_page: false, aspect_ratio: true, padding: 20, regexp_type: true, &)
@ -71,6 +83,8 @@ module Templates
fields = inference.call(image, confidence:, nms:, split_page:,
temperature:, aspect_ratio:, padding:)
fields = sort_fields(fields, y_threshold: 10.0 / image.height)
fields = fields.map do |f|
{
uuid: SecureRandom.uuid,
@ -113,6 +127,8 @@ module Templates
text_fields = extract_text_fields_from_page(page)
line_fields = extract_line_fields_from_page(page)
fields = sort_fields(fields, y_threshold: 10.0 / page.height)
fields = increase_confidence_for_overlapping_fields(fields, text_fields)
fields = increase_confidence_for_overlapping_fields(fields, line_fields)
@ -153,6 +169,12 @@ module Templates
doc.close
end
def sort_fields(fields, y_threshold: 0.01)
fields.sort do |a, b|
(a.endy - b.endy).abs < y_threshold ? a.x <=> b.x : a.endy <=> b.endy
end
end
def print_debug(head_node)
current_node = head_node
index = 0
@ -189,121 +211,68 @@ module Templates
def build_page_nodes(page, fields, tail_node, attachment_uuid: nil)
field_nodes = []
current_text = ''.b
text_nodes = page.text_nodes
y_theshold = 10.0 / page.height
text_idx = 0
field_idx = 0
text_nodes = page.text_nodes
while text_idx < text_nodes.length || field_idx < fields.length
text_node = text_nodes[text_idx]
field = fields[field_idx]
current_field = fields.shift
process_text_node = false
process_field_node = false
index = 0
if text_node && field
text_y_center = text_node.y + (text_node.h / 2.0)
field_y_center = field.y + (field.h / 2.0)
y_threshold = text_node.h / 2.0
vertical_distance = (text_y_center - field_y_center).abs
prev_node = nil
if vertical_distance < y_threshold
is_underscore = text_node.content == '_'
is_left_of_field = text_node.x < field.x
loop do
node = text_nodes[index]
if is_underscore && is_left_of_field
text_x_end = text_node.x + text_node.w
break unless node
distance = field.x - text_x_end
proximity_threshold = text_node.w * 3.0
loop do
break unless current_field
if ((current_field.endy - node.endy).abs < y_theshold &&
(current_field.x <= node.x || node.content.in?(LINEBREAK))) ||
current_field.endy < node.y
field_node = PageNode.new(prev: tail_node, elem: current_field, page: page.page_index, attachment_uuid:)
tail_node.next = field_node
tail_node = field_node
field_nodes << tail_node
if distance < proximity_threshold
process_field_node = true
current_field = fields.shift
else
process_text_node = true
break
end
elsif is_left_of_field
process_text_node = true
else
process_field_node = true
end
elsif text_node.y < field.y
process_text_node = true
else
process_field_node = true
end
if tail_node.elem.is_a?(Templates::ImageToFields::Field)
text_node = PageNode.new(prev: tail_node, elem: ''.b, page: page.page_index, attachment_uuid:)
tail_node.next = text_node
elsif text_node
process_text_node = true
elsif field
process_field_node = true
tail_node = text_node
end
if process_field_node
unless current_text.empty?
new_text_node = PageNode.new(prev: tail_node, elem: current_text, page: page.page_index, attachment_uuid:)
tail_node.next = new_text_node
tail_node = new_text_node
current_text = ''.b
if prev_node && (node.endy - prev_node.endy) > y_theshold && LINEBREAK.exclude?(prev_node.content)
tail_node.elem << "\n"
end
new_field_node = PageNode.new(prev: tail_node, elem: field, page: page.page_index, attachment_uuid:)
tail_node.next = new_field_node
tail_node = new_field_node
field_nodes << tail_node
while text_idx < text_nodes.length
text_node_to_check = text_nodes[text_idx]
is_part_of_field = false
if text_node_to_check.content == '_'
check_y_center = text_node_to_check.y + (text_node_to_check.h / 2.0)
check_y_dist = (check_y_center - field_y_center).abs
check_y_thresh = text_node_to_check.h / 2.0
if check_y_dist < check_y_thresh
padding = text_node_to_check.w * 3.0
field_x_start = field.x - padding
field_x_end = field.x + field.w + padding
text_x_start = text_node_to_check.x
text_x_end = text_node_to_check.x + text_node_to_check.w
is_part_of_field = true if text_x_start <= field_x_end && field_x_start <= text_x_end
end
if node.content != '_' || !tail_node.elem.ends_with?('___')
tail_node.elem << node.content unless CHECBOXES.include?(node.content)
end
break unless is_part_of_field
prev_node = node
text_idx += 1
index += 1
end
field_idx += 1
elsif process_text_node
if text_idx > 0
prev_text_node = text_nodes[text_idx - 1]
x_gap = text_node.x - (prev_text_node.x + prev_text_node.w)
gap_w = text_node.w > prev_text_node.w ? text_node.w : prev_text_node.w
current_text << ' ' if x_gap > gap_w * 2
end
loop do
break unless current_field
current_text << text_node.content
text_idx += 1
end
end
field_node = PageNode.new(prev: tail_node, elem: current_field, page: page.page_index, attachment_uuid:)
tail_node.next = field_node
tail_node = field_node
field_nodes << tail_node
unless current_text.empty?
new_text_node = PageNode.new(prev: tail_node, elem: current_text, page: page.page_index, attachment_uuid:)
tail_node.next = new_text_node
tail_node = new_text_node
current_field = fields.shift
end
[field_nodes, tail_node]
@ -399,8 +368,8 @@ module Templates
x1 = node.x
y1 = node.y
x2 = node.x + node.w
y2 = node.y + node.h
x2 = node.endx
y2 = node.endy
underscore_count = 1
@ -417,8 +386,9 @@ module Templates
break if distance > 0.02 || height_diff > node.h * 0.5
underscore_count += 1
next_x2 = next_node.x + next_node.w
next_y2 = next_node.y + next_node.h
next_x2 = next_node.endx
next_y2 = next_node.endy
x2 = next_x2
y2 = [y2, next_y2].max
@ -438,8 +408,8 @@ module Templates
def calculate_iou(box1, box2)
x1 = [box1.x, box2.x].max
y1 = [box1.y, box2.y].max
x2 = [box1.x + box1.w, box2.x + box2.w].min
y2 = [box1.y + box1.h, box2.y + box2.h].min
x2 = [box1.endx, box2.endx].min
y2 = [box1.endy, box2.endy].min
intersection_width = [0, x2 - x1].max
intersection_height = [0, y2 - y1].max
@ -455,8 +425,7 @@ module Templates
end
def boxes_overlap?(box1, box2)
!(box1.x + box1.w < box2.x || box2.x + box2.w < box1.x ||
box1.y + box1.h < box2.y || box2.y + box2.h < box1.y)
!(box1.endx < box2.x || box2.endx < box1.x || box1.endy < box2.y || box2.endy < box1.y)
end
def increase_confidence_for_overlapping_fields(image_fields, text_fields, by: 1.0)
@ -465,12 +434,10 @@ module Templates
image_fields.map do |image_field|
next if image_field.type != 'text'
field_bottom = image_field.y + image_field.h
text_fields.each do |text_field|
break if text_field.y > field_bottom
break if text_field.y > image_field.endy
next if text_field.y + text_field.h < image_field.y
next if text_field.endy < image_field.y
next unless boxes_overlap?(image_field, text_field)
next if calculate_iou(image_field, text_field) < 0.4

@ -8,6 +8,10 @@ module Templates
def endy
@endy ||= y + h
end
def endx
@endx ||= x + w
end
end
MODEL_PATH = Rails.root.join('tmp/model.onnx')
@ -64,9 +68,7 @@ module Templates
detections = apply_nms(detections, nms)
fields = build_fields_from_detections(detections, image)
sort_fields(fields, y_threshold: 10.0 / image.height)
build_fields_from_detections(detections, image)
end
def build_split_image_regions(image)
@ -302,27 +304,6 @@ module Templates
end
end
def sort_fields(fields, y_threshold: 0.01)
sorted_fields = fields.sort { |a, b| a.endy == b.endy ? a.x <=> b.x : a.endy <=> b.endy }
lines = []
current_line = []
sorted_fields.each do |field|
if current_line.blank? || (field.y - current_line.first.y).abs < y_threshold
current_line << field
else
lines << current_line.sort_by(&:x)
current_line = [field]
end
end
lines << current_line.sort_by(&:x) if current_line.present?
lines.flatten
end
def apply_nms(detections, threshold = 0.5)
return detections if detections[:xyxy].shape[0].zero?

Loading…
Cancel
Save