龙空技术网

揭秘PDF解析方法:OCR-Free的小模型方法

冰镇火锅聊AI 526

前言:

此刻你们对“tfidf算法实现伪代码”大致比较讲究,小伙伴们都需要分析一些“tfidf算法实现伪代码”的相关知识。那么小编在网摘上汇集了一些对于“tfidf算法实现伪代码””的相关资讯,希望你们能喜欢,小伙伴们一起来学习一下吧!

PDF 文件很难转换成其他格式,通常会将大量信息锁定在 AI 应用程序无法访问的格式中。如果我们可以将 PDF 文件或其对应的图像转换为机器可读的结构化或半结构化格式,这将大大缓解这一问题。这也可以显著增强人工智能应用程序的知识库。

在本系列的第一篇文章中,我们介绍了 PDF 解析的主要任务,对现有方法进行了分类,并对每种方法进行了简要介绍。在本系列的第二篇文章中,我们重点介绍了基于管道的方法。

本文是该系列的第三篇,介绍另一种 PDF 解析方法:OCR-Free 的小模型方法。首先介绍一下概述,然后介绍各种有代表性的OCR-Free 小模型的 PDF 解析方案的原理,最后分享我们得到的一些感悟和思考。

请注意,本文中提到的“小模型”与大型多模态模型相比相对较小,通常具有少于 30 亿个参数。

概述

之前介绍的基于流水线的PDF解析方法主要利用OCR引擎进行文本识别,但其计算成本高,对语言和文档类型的灵活性不高,并且可能存在OCR错误,影响后续任务。

因此,应该开发OCR-Free 的方法,如图 1 所示。它们不使用 OCR 来显式识别文本。相反,它们使用神经网络来隐式完成任务。本质上,这些方法采用端到端方法,直接输出 PDF 解析的结果。

图 1:基于无 OCR 小模型的方法。图片由作者提供。

从结构上看,OCR-free 方法相比基于流水线的方法更加简单,OCR-free 方法需要关注的重点是模型结构设计和训练数据的构建。

接下来我们来介绍几个有代表性的基于OCR-Free小模型的PDF解析框架:

Donut: OCR-Free 文档理解转换器。Nougat:基于Donut架构,与PDF文件、公式和表格一起使用时特别有效。Pix2Struct:屏幕截图解析作为视觉语言理解的预训练。Donut

如图 2 所示,Donut是一个端到端模型,旨在全面理解文档图像。其架构简单明了,由基于 Transformer 的视觉编码器和文本解码器模块组成。

图 2:Donut 的架构。来源:Donut。

Donut 不依赖任何与 OCR 相关的模块,而是使用视觉编码器从文档图像中提取特征,然后使用文本解码器直接生成 token 序列,再将输出序列转换为 JSON 等结构化格式。

代码如下:

class DonutModel(PreTrainedModel):    r"""    Donut: an E2E OCR-free Document Understanding Transformer.    The encoder maps an input document image into a set of embeddings,    the decoder predicts a desired token sequence, that can be converted to a structured format,    given a prompt and the encoder output embeddings    """    config_class = DonutConfig    base_model_prefix = "donut"    def __init__(self, config: DonutConfig):        super().__init__(config)        self.config = config        self.encoder = SwinEncoder(            input_size=self.config.input_size,            align_long_axis=self.config.align_long_axis,            window_size=self.config.window_size,            encoder_layer=self.config.encoder_layer,            name_or_path=self.config.name_or_path,        )        self.decoder = BARTDecoder(            max_position_embeddings=self.config.max_position_embeddings,            decoder_layer=self.config.decoder_layer,            name_or_path=self.config.name_or_path,        )    def forward(self, image_tensors: torch.Tensor, decoder_input_ids: torch.Tensor, decoder_labels: torch.Tensor):        """        Calculate a loss given an input image and a desired token sequence,        the model will be trained in a teacher-forcing manner        Args:            image_tensors: (batch_size, num_channels, height, width)            decoder_input_ids: (batch_size, sequence_length, embedding_dim)            decode_labels: (batch_size, sequence_length)        """        encoder_outputs = self.encoder(image_tensors)        decoder_outputs = self.decoder(            input_ids=decoder_input_ids,            encoder_hidden_states=encoder_outputs,            labels=decoder_labels,        )        return decoder_outputs    ...    ...
Encoder

Donut 使用Swin-Transformer作为图像编码器,因为它在初步文档解析研究中表现出色。该图像编码器将输入文档图像转换为一组高维嵌入。这些嵌入将用作文本解码器的输入。

对应代码如下。

class SwinEncoder(nn.Module):    r"""    Donut encoder based on SwinTransformer    Set the initial weights and configuration with a pretrained SwinTransformer and then    modify the detailed configurations as a Donut Encoder    Args:        input_size: Input image size (width, height)        align_long_axis: Whether to rotate image if height is greater than width        window_size: Window size(=patch size) of SwinTransformer        encoder_layer: Number of layers of SwinTransformer encoder        name_or_path: Name of a pretrained model name either registered in huggingface.co. or saved in local.                      otherwise, `swin_base_patch4_window12_384` will be set (using `timm`).    """    def __init__(        self,        input_size: List[int],        align_long_axis: bool,        window_size: int,        encoder_layer: List[int],        name_or_path: Union[str, bytes, os.PathLike] = None,    ):        super().__init__()        self.input_size = input_size        self.align_long_axis = align_long_axis        self.window_size = window_size        self.encoder_layer = encoder_layer        self.to_tensor = transforms.Compose(            [                transforms.ToTensor(),                transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),            ]        )        self.model = SwinTransformer(            img_size=self.input_size,            depths=self.encoder_layer,            window_size=self.window_size,            patch_size=4,            embed_dim=128,            num_heads=[4, 8, 16, 32],            num_classes=0,        )        self.model.norm = None        # weight init with swin        if not name_or_path:            swin_state_dict = timm.create_model("swin_base_patch4_window12_384", pretrained=True).state_dict()            new_swin_state_dict = self.model.state_dict()            for x in new_swin_state_dict:                if x.endswith("relative_position_index") or x.endswith("attn_mask"):                    pass                elif (                    x.endswith("relative_position_bias_table")                    and self.model.layers[0].blocks[0].attn.window_size[0] != 12                ):                    pos_bias = swin_state_dict[x].unsqueeze(0)[0]                    old_len = int(math.sqrt(len(pos_bias)))                    new_len = int(2 * window_size - 1)                    pos_bias = pos_bias.reshape(1, old_len, old_len, -1).permute(0, 3, 1, 2)                    pos_bias = F.interpolate(pos_bias, size=(new_len, new_len), mode="bicubic", align_corners=False)                    new_swin_state_dict[x] = pos_bias.permute(0, 2, 3, 1).reshape(1, new_len ** 2, -1).squeeze(0)                else:                    new_swin_state_dict[x] = swin_state_dict[x]            self.model.load_state_dict(new_swin_state_dict)    def forward(self, x: torch.Tensor) -> torch.Tensor:        """        Args:            x: (batch_size, num_channels, height, width)        """        x = self.model.patch_embed(x)        x = self.model.pos_drop(x)        x = self.model.layers(x)        return x    ...    ...

Donut 使用BART作为解码器。

class BARTDecoder(nn.Module):    """    Donut Decoder based on Multilingual BART    Set the initial weights and configuration with a pretrained multilingual BART model,    and modify the detailed configurations as a Donut decoder    Args:        decoder_layer:            Number of layers of BARTDecoder        max_position_embeddings:            The maximum sequence length to be trained        name_or_path:            Name of a pretrained model name either registered in huggingface.co. or saved in local,            otherwise, `hyunwoongko/asian-bart-ecjk` will be set (using `transformers`)    """    def __init__(        self, decoder_layer: int, max_position_embeddings: int, name_or_path: Union[str, bytes, os.PathLike] = None    ):        super().__init__()        self.decoder_layer = decoder_layer        self.max_position_embeddings = max_position_embeddings        self.tokenizer = XLMRobertaTokenizer.from_pretrained(            "hyunwoongko/asian-bart-ecjk" if not name_or_path else name_or_path        )        self.model = MBartForCausalLM(            config=MBartConfig(                is_decoder=True,                is_encoder_decoder=False,                add_cross_attention=True,                decoder_layers=self.decoder_layer,                max_position_embeddings=self.max_position_embeddings,                vocab_size=len(self.tokenizer),                scale_embedding=True,                add_final_layer_norm=True,            )        )        self.model.forward = self.forward  #  to get cross attentions and utilize `generate` function        self.model.config.is_encoder_decoder = True  # to get cross-attention        self.add_special_tokens(["<sep/>"])  # <sep/> is used for representing a list in a JSON        self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id        self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference        # weight init with asian-bart        if not name_or_path:            bart_state_dict = MBartForCausalLM.from_pretrained("hyunwoongko/asian-bart-ecjk").state_dict()            new_bart_state_dict = self.model.state_dict()            for x in new_bart_state_dict:                if x.endswith("embed_positions.weight") and self.max_position_embeddings != 1024:                    new_bart_state_dict[x] = torch.nn.Parameter(                        self.resize_bart_abs_pos_emb(                            bart_state_dict[x],                            self.max_position_embeddings                            + 2,  #                         )                    )                elif x.endswith("embed_tokens.weight") or x.endswith("lm_head.weight"):                    new_bart_state_dict[x] = bart_state_dict[x][: len(self.tokenizer), :]                else:                    new_bart_state_dict[x] = bart_state_dict[x]            self.model.load_state_dict(new_bart_state_dict)    ...    ...    def forward(        self,        input_ids,        attention_mask: Optional[torch.Tensor] = None,        encoder_hidden_states: Optional[torch.Tensor] = None,        past_key_values: Optional[torch.Tensor] = None,        labels: Optional[torch.Tensor] = None,        use_cache: bool = None,        output_attentions: Optional[torch.Tensor] = None,        output_hidden_states: Optional[torch.Tensor] = None,        return_dict: bool = None,    ):        """        A forward fucntion to get cross attentions and utilize `generate` function        Source:                Args:            input_ids: (batch_size, sequence_length)            attention_mask: (batch_size, sequence_length)            encoder_hidden_states: (batch_size, sequence_length, hidden_size)        Returns:            loss: (1, )            logits: (batch_size, sequence_length, hidden_dim)            hidden_states: (batch_size, sequence_length, hidden_size)            decoder_attentions: (batch_size, num_heads, sequence_length, sequence_length)            cross_attentions: (batch_size, num_heads, sequence_length, sequence_length)        """        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions        output_hidden_states = (            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states        )        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict        outputs = self.model.model.decoder(            input_ids=input_ids,            attention_mask=attention_mask,            encoder_hidden_states=encoder_hidden_states,            past_key_values=past_key_values,            use_cache=use_cache,            output_attentions=output_attentions,            output_hidden_states=output_hidden_states,            return_dict=return_dict,        )        logits = self.model.lm_head(outputs[0])        loss = None        if labels is not None:            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)            loss = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1))        if not return_dict:            output = (logits,) + outputs[1:]            return (loss,) + output if loss is not None else output        return ModelOutput(            loss=loss,            logits=logits,            past_key_values=outputs.past_key_values,            hidden_states=outputs.hidden_states,            decoder_attentions=outputs.attentions,            cross_attentions=outputs.cross_attentions,        )    ...    ...

Donut 使用公开可用的、预先训练的、多语言 BART 模型的权重来初始化解码器模型权重。

文本解码器的输出是生成的标记序列。

训练Training

预训练Pre-training

预训练的目标是最小化下一个标记预测的交叉熵损失。这是通过联合调节图像和先前的上下文来实现的。此任务类似于伪 OCR 任务。该模型本质上是作为视觉语料库(例如文档图像)上的视觉语言模型进行训练的。

使用的训练数据是IIT-CDIP,包含 1100 万张扫描的英文文档图像。同时,使用合成文档生成器 (SynthDoG)生成多语言数据,包括英文、中文、日文和韩文,每种语言生成 50 万张图片。

图 3:使用 SynthDoG 生成的英文、中文、日文和韩文样本。采用启发式随机模式来模仿真实文档。来源:Donut。

生成的示例如图3所示,一个示例由背景、文档、文本、布局几个部分组成。

背景图片来源于 ImageNet 样本文档的纹理源自收集的纸质照片。单词和短语取自维基百科。布局是由一个简单的基于规则的算法产生的,该算法随机排列网格。

此外,还利用各种图像渲染技术来模拟真实文档。

此外,图 4 显示了通过商业 CLOVA OCR API 获取的训练数据的标签。

图 4:训练数据的标签。来源:synthdog-en 数据集的第 0 行。

微调Fine-tuning

微调的首要目的是为了适应下游任务。

例如,在文档分类任务中,解码器被训练生成一个token序列[START class][memo][END class]。该序列可以直接转换成JSON格式,如{"class": "memo"}

Nougat

Nougat是 2023 年 8 月推出的端到端、OCR-Free 的小型模型。它可以直接解析图像内容。它接受从文学作品扫描的图像或从 PDF 转换的图像作为输入,并生成 markdown 作为输出。

模型架构

Nougat是基于 Donut 架构开发的,它通过神经网络隐式识别文本,无需任何 OCR 相关的输入或模块,如图 5 所示。

图 5:遵循 Donut 的端到端架构。Swin Transformer 编码器获取文档图像并将其转换为潜在嵌入,然后以自回归方式将其转换为一系列标记。资料来源:Nougat。

训练数据集的构建

Nougat 的模型并不是特别创新;它的主要重点是构建大型训练数据集,这是一项具有挑战性的任务。

Nougat 通过创建由图像和 markdown 组成的大规模训练数据,实现了经济高效的方法。这是 Nougat 最值得学习的方面。

数据源

由于缺乏包含 PDF 图像和 markdown 对的大规模数据集,Nougat 从三个来源构建了数据集:arXiv、PMC(PubMed Central)和IDL(行业文档库),如图 6 所示。

图 6:Nougat 训练数据的数据源。图片由作者提供,灵感来自Nougat。

总体流程

主要使用ArXiv数据,因为其中包含TeX源代码。处理流程如图7所示。

图 7:Nougat 的数据处理流程,TeX 源和原始 PDF 论文最终转换为图像-Markdown 对。图片由作者提供。

如图 7 所示,主要目标是将现有的资源(即 PDF 论文及其对应的 TeX 源代码)转换为对。每对由每个 PDF 页面的图像及其对应的 Markdown 组成。

获取图像作为输入

获取PDF页面图片的过程比较简单,直接使用PyPDFium2的相关API即可。

def rasterize_paper(    pdf: Union[Path, bytes],    outpath: Optional[Path] = None,    dpi: int = 96,    return_pil=False,    pages=None,) -> Optional[List[io.BytesIO]]:    """    Rasterize a PDF file to PNG images.    Args:        pdf (Path): The path to the PDF file.        outpath (Optional[Path], optional): The output directory. If None, the PIL images will be returned instead. Defaults to None.        dpi (int, optional): The output DPI. Defaults to 96.        return_pil (bool, optional): Whether to return the PIL images instead of writing them to disk. Defaults to False.        pages (Optional[List[int]], optional): The pages to rasterize. If None, all pages will be rasterized. Defaults to None.    Returns:        Optional[List[io.BytesIO]]: The PIL images if `return_pil` is True, otherwise None.    """    pils = []    if outpath is None:        return_pil = True    try:        if isinstance(pdf, (str, Path)):            pdf = pypdfium2.PdfDocument(pdf)        if pages is None:            pages = range(len(pdf))        renderer = pdf.render(            pypdfium2.PdfBitmap.to_pil,            page_indices=pages,            scale=dpi / 72,        )        for i, image in zip(pages, renderer):            if return_pil:                page_bytes = io.BytesIO()                image.save(page_bytes, "bmp")                pils.append(page_bytes)            else:                image.save((outpath / ("%02d.png" % (i + 1))), "png")    except Exception as e:        logging.error(e)    if return_pil:        return pils

获取 Markdown 作为标签

如图 7 所示,为了得到 markdown,我们必须先将 TeX 源代码转换成 HTML 文件。然后我们可以解析这些文件并将其格式化为 markdown。

这涉及两个挑战。

第一个挑战是弄清楚如何对 Markdown 进行分页,因为训练数据由来自每个 PDF 页面的图像组成,并使用相应的 Markdown 作为标签。

由于每篇论文的 LaTeX 源文件尚未重新编译,我们无法像 LaTeX 编译器那样自动确定 PDF 文件的分页符。

为了实现这个目标,需要利用当前可用的资源。策略是启发式地将原始 PDF 页面中的文本与 Markdown 文本进行匹配。

具体来说,首先使用 PDFMiner从 PDF 中提取文本行,然后对文本进行预处理以删除页码和潜在的页眉或页脚。然后,使用 PDF 行作为输入、页码作为标签训练 tfidf_transformer 模型。接下来,应用训练后的模型将 Markdown 分为段落并预测每个段落的页码。

def split_markdown(    doc: str,    pdf_file: str,    figure_info: Optional[List[Dict]] = None,    doc_fig: Dict[str, str] = {},    minlen: int = 3,    min_num_words: int = 22,    doc_paragraph_chars: int = 1000,    min_score: float = 0.75,    staircase: bool = True,) -> Tuple[List[str], Dict]:    ...    ...       if staircase:            # train bag of words            page_target = np.zeros(len(paragraphs))            page_target[num_paragraphs[1:-1] - 1] = 1            page_target = np.cumsum(page_target).astype(int)            model = BagOfWords(paragraphs, target=page_target)            labels = model(doc_paragraphs)            # fit stair case function            x = np.arange(len(labels))            stairs = Staircase(len(labels), labels.max() + 1)            stairs.fit(x, labels)            boundaries = (stairs.get_boundaries().astype(int)).tolist()            boundaries.insert(0, 0)        else:            boundaries = [0] * (len(pdf.pages))    ...    ...

最后进行一些收尾调整。

第二个挑战涉及 PDF 中的图表与 Markdown 文件中的位置不一致。

为了解决这个问题,Nougat 最初使用pdffigures2来提取图表。然后将识别出的标题与 TeX 源代码中的标题进行比较,并根据 Levenshtein distance 进行匹配。这种方法使我们能够确定每个图形或表格的 TeX 源代码和页码。这是因为图 7 的 JSON 结构包含图表标题和相应的页码。

一旦将 Markdown 分成单独的页面,之前提取的图表就会重新插入到每个相应页面的末尾。

def split_markdown(    doc: str,    pdf_file: str,    figure_info: Optional[List[Dict]] = None,    doc_fig: Dict[str, str] = {},    minlen: int = 3,    min_num_words: int = 22,    doc_paragraph_chars: int = 1000,    min_score: float = 0.75,    staircase: bool = True,) -> Tuple[List[str], Dict]:    ...    ...    # Reintroduce figures, tables and footnotes    figure_tex = list(doc_fig.keys()), list(doc_fig.values())    if len(doc_fig) > 0:        iterator = figure_info.values() if type(figure_info) == dict else [figure_info]        for figure_list in iterator:            if not figure_list:                continue            for i, f in enumerate(figure_list):                if "caption" in f:                    fig_string = f["caption"]                elif "text" in f:                    fig_string = f["text"]                else:                    continue                ratios = []                for tex in figure_tex[1]:                    if f["figType"] == "Table":                        tex = tex.partition(r"\end{table}")[2]                    ratios.append(Levenshtein.ratio(tex, fig_string))                k = np.argmax(ratios)                if ratios[k] < 0.8:                    continue                if f["page"] < len(out) and out[f["page"]] != "":                    out[f["page"]] += "\n\n" + remove_pretty_linebreaks(                        figure_tex[1][k].strip()                    )    for i in range(len(out)):        foot_match = re.findall(r"\[FOOTNOTE(.*?)\]\[ENDFOOTNOTE\]", out[i])        for match in foot_match:            out[i] = out[i].replace(                "[FOOTNOTE%s][ENDFOOTNOTE]" % match,                doc_fig.get("FOOTNOTE%s" % match, ""),            )        out[i] = re.sub(r"\[(FIGURE|TABLE)(.*?)\](.*?)\[END\1\]", "", out[i])    return out, meta
Pix2Struct

Pix2Struct是一个预先训练的图像转文本模型,专为纯视觉语言理解而设计。此外,它可以针对许多下游任务进行微调。

模型架构

Pix2Struct是一个基于ViT 的图像编码器-文本解码器。

由于论文中没有说明 Pix2Struct 的架构,并且在网上其他地方也找不到,因此我在此提供了一个基于 ViT 架构的参考图,如图 8 所示。

图 8:Pix2Struct 的架构。受到ViT 的启发。

使用标准 ViT方法(在提取固定大小的块之前将输入图像缩放到预定义的分辨率)可能会产生两个负面影响:

它会扭曲真实的纵横比,这可能会对文档、移动 UI 和图形造成很大差异。将模型转移到具有更高分辨率的下游任务变得具有挑战性,因为模型在预训练期间仅观察特定的分辨率。

因此,Pix2Struct 引入了一个小小的增强功能,允许对输入图像进行保持纵横比的缩放,无论是向上还是向下,如图 9 所示。

图 9:可变分辨率输入与典型固定分辨率输入的比较。来源:Pix2Struct。

预训练任务

Pix2Struct 提出了一个屏幕截图解析目标,需要从网页的屏蔽屏幕截图中预测基于 HTML 的解析。

屏蔽输入鼓励对它们的共现进行联合推理。使用简化的 HTML 作为输出是有利的,因为它提供有关文本、图像和布局的清晰信号。

图 10:从原始网页(左)中采样的输入输出对(右)的玩具插图。来源:Pix2Struct。

如图 10 所示,Pix2Struct 提出的截图解析有效地结合了几种著名的预训练策略的信号:

恢复未被掩蔽的部分。这项任务类似于 OCR,这是理解语言的基本技能。Donut 中还提出了使用合成渲染或 OCR 输出进行 OCR 预训练。在图 10 中,预测<C++>就是这种学习信号的一个例子。恢复被遮盖的部分。此任务类似于 BERT 中的遮盖语言建模。然而,一个关键的区别是视觉背景通常会提供额外的有力线索。例如,<Python>图 10 中的预测就是这种信号的一个示例。从图像中恢复替代文本。这是预训练图像标题策略的常用方法。在这种方法中,模型可以使用网页作为额外上下文。例如,img alt=C++如图 10 所示,预测 就是这种学习信号的一个例子。

Pix2Struct 已预先训练了两种模型变体:

由 282M 个参数组成的基础模型。一个由 13 亿个参数组成的大型模型。预训练数据集

预训练的目标是让 Pix2Struct 具备表示输入图像基本结构的能力。为了实现这一点,Pix2Struct 根据C4 语料库中的 URL 生成自监督的输入图像和目标文本对。

Pix2Struct 收集了 8000 万张截图,每张都配有 HTML 源文件。这约占总文档数量的三分之一。每张截图宽度为 1024 像素,高度可调整以匹配内容的高度。获得的 HTML 源文件将转换为简化的 HTML。

图 11 展示了预训练数据的屏幕截图,并附有真实数据和预测解析。

图 11:预训练数据样本。来源:Pix2Struct。

微调

微调 Pix2Struct 的主要步骤涉及预处理下游数据。这可确保图像输入和文本输出准确代表任务。

图 12:视觉语言理解任务示例。这些任务包括图表 QA (AI2D)、应用程序字幕 (Screen2Words) 和文档 QA (DocVQA)。左侧还包括 Pix2Struct 预训练任务的一个示例,即屏幕截图解析。Pix2Struct 对输入图像中的像素进行编码(上图)并解码输出文本(下图)。来源:Pix2Struct。

图 12 描绘了一些下游任务的示例。

关于预处理:

对于 Screen2Words 字幕任务,可以直接使用输入图像和输出文本。对于 DocVQA 视觉问答任务,Pix2Struct 将问题直接作为标题呈现在原始图像的顶部,尽管多模态模型通常为问题保留一个特殊的文本通道。对于多项选择题答案,例如 AI2D,Pix2Struct 选择将其作为标题中问题的一部分呈现。见解和思考

有代表性的OCR-Free解决方案的介绍就到此结束了,现在我们来谈谈感悟和思考。

关于预训练任务

为了全面理解图像或 PDF 中的布局、文本和语义信息,Donut、Nougat 和 Pix2Struct 设计了类似的训练任务:

Donut:图像 → 类似 JSON 的格式Nougat:图片 → MarkdownPix2Struct:蒙版图像 → 简化 HTML

如果我们的目标是开发自己的OCR-Free PDF 解析工具,那么第一步应该是设计训练任务。考虑所需的输出格式和获取相应训练数据的挑战至关重要。

关于预训练数据

训练数据对于OCR-Free 方法至关重要。

获取 Donut 和 Nougat 的训练数据具有挑战性,因为 (图像,JSON) 和 (图像,Markdown) 对并不容易获得。

相反,Pix2Struct 直接改编自来自公共数据集的网页,使数据获取更加方便。然而,由于 Pix2Struct 的训练数据来自网页,因此可能会引入有害内容。多模态模型对此特别敏感。Pix2Struct 尚未实施措施来解决这些有害内容。

如果我们的目标是开发一个OCR-Free 的 PDF 解析工具,一种策略是使用公共数据逐步构建(输入,输出)对进行训练。

此外,确定输入图像的适当分辨率以及一张图像中包含的 PDF 页面数量也是重要的考虑因素。

关于性能

Donut 和 Pix2Struct 都是通用的预训练模型,支持各种下游任务,因此它们的评估方法都是基于这些任务的基准。

根据 Pix2Struct 的实验,其性能在多个任务上显著优于 Donut,并且在大多数任务上也超过了最新水平(SOTA),如图 13 所示:

图 13:Pix2Struct 在 9 个基准测试中的 8 个上优于之前的视觉方法,其中 6 个获得了 SOTA 结果。来源:Pix2Struct。

尽管如此,图 13 所示的这些任务与我们之前定义的 PDF 解析任务有所不同。在这方面,Nougat 更加专业。

Nougat 主要专注于 Markdown 的端到端制作,因此其评估方案主要包括编辑距离、BLEU、METEOR 和 F-measure,如图 14 所示。

图 14:arXiv 测试集上的结果。来源:Nougat。

此外,Nougat 可以比其他工具更准确地将复杂元素(如公式和表格)解析为 LaTeX 源代码,如图 15 和 16 所示。

图 15:使用 Nougat 解析公式的结果。作者截图。

图 16:使用 Nougat 解析表格和公式的结果。作者截图。

此外,Nougat可以方便地获取表格标题并将其与相应的表格关联起来。

Pipeline-Based vs. OCR-Free

图17比较了两种方法的整体架构和性能。左上角表示基于流水线的方法,左下角表示Donut模型。

图 17:两种方法的整体架构和性能比较。来源:Donut。

如图 17 右侧所示,与基于管道的方法相比,Donut 占用的存储空间更少,准确率更高。不过,Donut 的运行速度较慢。其他无 OCR 解决方案与 Donut 类似。

OCR-Free小模型方法的局限性虽然基于流水线的方法涉及多个模型,但每个模型都是轻量级的,总参数量甚至可能比OCR-Free模型少很多。这个因素导致OCR-Free模型的解析速度较慢,这对大规模部署构成挑战。例如,尽管Nougat是一个小模型,但它的参数量有250MB或350MB。然而,它的生成速度很慢,正如Nougat的论文中所述构建此方法的训练数据集成本高昂。这是因为需要构建大规模图像-文本对。此外,它需要更多 GPU 和更长的训练时间,从而增加了机器成本。此外,端到端方法无法针对特定的坏情况进行优化,导致优化成本更高。在基于流水线的解决方案中,如果表格处理模块表现不佳,则只需优化此模块。但是,对于端到端解决方案,在不改变模型架构的情况下,必须创建新的微调数据。这可能会在其他场景中导致新的坏情况,例如公式识别。结论

本文概述了 PDF 解析中基于OCR-Free 小模型的方法。它以三个代表性模型为例深入研究了这种方法,并提供了详细的介绍并分享了由此得出的见解。

总体而言,使用OCR-Free 的小模型 PDF 解析方法的好处在于其一步到位,避免了中间步骤可能带来的任何潜在损害。然而,其有效性在很大程度上依赖于多模态模型的结构和训练数据的质量。此外,它的训练和推理速度较慢,因此实用性不如基于流水线的方法。此外,该方法的可解释性不如基于流水线的方法强。

尽管还有改进空间,但OCR-Free 方法在表格和公式识别等领域表现良好。这些优势为我们构建自己的 PDF 解析工具提供了宝贵的见解。

参考:

标签: #tfidf算法实现伪代码