Comments (3)
You can run inference on the base model, which has not been fine-tuned to any JSON schema, to get an OCR prediction just like in the pre-training task.
Here is a code snippet that should get you started:
```python
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image

model_path = "donut-base"
processor = DonutProcessor.from_pretrained(model_path)
model = VisionEncoderDecoderModel.from_pretrained(model_path)

# scale = 1  # optionally change input image scale
# h, w = 1263 // scale, 893 // scale
# processor.image_processor.size = {"height": h, "width": w}
# model.config.encoder.image_size = [h, w]

max_new_tokens = 1024  # increase to get more text

image = Image.open("../test-document.png").convert("RGB")
task_prompt = "<s_iitcdip>"  # pre-training prompt, can be reused for OCR
decoder_input_ids = processor.tokenizer(
    task_prompt, add_special_tokens=False, return_tensors="pt"
).input_ids
pixel_values = processor(image, return_tensors="pt").pixel_values

predict_ids = model.generate(
    pixel_values,
    decoder_input_ids=decoder_input_ids,
    max_new_tokens=max_new_tokens,
    use_cache=True,
    bad_words_ids=[[0, 1, 2, 3, 57522]],  # suppress special/unknown token ids
    eos_token_id=2,
    pad_token_id=1,
)
predict_seq = processor.tokenizer.batch_decode(predict_ids)
print(predict_seq)
```
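The sequences that `batch_decode` returns still carry the task prompt and special tokens. A small cleanup helper can strip them before further processing (a sketch; the token names match the snippet above, and the sample string is illustrative):

```python
import re

def clean_donut_output(seq: str) -> str:
    """Strip Donut task/special tokens from a decoded sequence."""
    seq = seq.replace("</s>", "").replace("<pad>", "")
    # drop any remaining task tokens such as <s_iitcdip>
    return re.sub(r"<.*?>", "", seq).strip()

# example with a decoded sequence shaped like Donut's output
raw = "<s_iitcdip> Invoice Date 02/01/2024 Invoice Number 168808</s>"
print(clean_donut_output(raw))  # → Invoice Date 02/01/2024 Invoice Number 168808
```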
from donut.
@felixvor thanks for sharing the code, it works. There's one problem though: my documents contain some handwritten text, and the model isn't able to pick it up.
Here's my code:
```python
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

image = Image.open("/kaggle/input/invoice/inv-h-4.png").convert("RGB")

# prepare decoder inputs
task_prompt = "<s_iitcdip>"
decoder_input_ids = processor.tokenizer(
    task_prompt, add_special_tokens=False, return_tensors="pt"
).input_ids
pixel_values = processor(image, return_tensors="pt").pixel_values

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

# decode the generated sequences to text
decoded_results = processor.tokenizer.batch_decode(outputs.sequences)
print(decoded_results)
```
And the output looks as follows:
['<s_iitcdip> Invoice Date 02/01/2024 Invoice Number 168808 Invoice Luliano Chemtrade Logistics 6300 Oldfield Rd Niagara Falls, ON L2G 3J8 To ensure proper application, please reference this invoice number on your remittance advice. PLEASE REMIT PAYMENT TO: Allied Universal Security Services of Canada 5580 Explorer Drive Suite 300 Mississauga, ON L4W 4Y 1 Total Amount Due: (CAD) $1,466.56 Terms: Net 30 Days Service Location: 653438 CHEMTRADE LOGISTICS 6300 oldfield rd Niagara Falls, ON L2G 3J8 640658 Billing Period: 01/01/2024 - 01/31/2024 Petrol: Chemtrade - Niagara Patrol: Chemtrade - Niagara - Stat Subtotal Sales Tax Subtotal Total for - CHEMTRADE LOGISTICS : PATROL - CHEMTRADE LOGISTICS Descripti G/L or PO#(SOO)////// [long degenerate run of '/' characters elided] //////</s>']
When I process the same document with the fine-tuned version naver-clova-ix/donut-base-finetuned-docvqa and ask a question related to the handwritten text, it gives the correct answer. Here's my working DocVQA code:
```python
import re
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
from PIL import Image  # was missing; Image is used below

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# load a document image from the DocVQA dataset
# dataset = load_dataset("hf-internal-testing/example-documents", split="test")
# image = dataset[0]["image"]
image = Image.open("/kaggle/input/invoice/inv-h-4.png").convert("RGB")

# prepare decoder inputs
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
question = "When is the GR number?"
prompt = task_prompt.replace("{user_input}", question)
decoder_input_ids = processor.tokenizer(
    prompt, add_special_tokens=False, return_tensors="pt"
).input_ids

# batched variant:
# questions = ["when is the PO#?", "what is the vendor number?"]
# prompts = [task_prompt.replace("{user_input}", q) for q in questions]
# decoder_input_ids = processor.tokenizer(
#     prompts, add_special_tokens=False, padding=True, return_tensors="pt"
# ).input_ids

pixel_values = processor(image, return_tensors="pt").pixel_values

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
    processor.tokenizer.pad_token, ""
)
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
print(processor.token2json(sequence))
```
Is there any way to parse the complete text, including the handwritten parts, using the base model, or to get the complete text output from the fine-tuned model?
In my opinion the strength of Donut is not its OCR generation but the possibility to fine-tune it on specific tasks. At the moment I can't think of a straightforward way to use the QA model for OCR generation. Maybe it could work somehow, but I don't think it would be efficient.
However, there are dedicated OCR transformers out there you could check out (for example TrOCR). And if you are looking for the highest-quality OCR, and if it's permissible in your use case, you could check out cloud solutions like Google Document AI or Azure's document parsers, which do a great job with handwriting in my experience.
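For handwriting specifically, a TrOCR checkpoint such as `microsoft/trocr-base-handwritten` is a common starting point. A minimal sketch (the function name and image path are illustrative; note that TrOCR expects a cropped single text line, not a full page):

```python
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image

def ocr_handwritten_line(
    image_path: str, checkpoint: str = "microsoft/trocr-base-handwritten"
) -> str:
    """Run TrOCR on one cropped text-line image and return the decoded string."""
    processor = TrOCRProcessor.from_pretrained(checkpoint)
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    image = Image.open(image_path).convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values, max_new_tokens=64)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# e.g. print(ocr_handwritten_line("line.png"))
```

For a full page you would first segment the document into lines or words (with a layout/detection model of your choice) and feed each crop through TrOCR.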
Related Issues (20)
- Does synthdog data has MiT or afl-3.0 license? HOT 1
- Error "A configuraton of type donut cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed" when run inference after finetuned docvqa without pushing to hugging face? HOT 1
- custom json schema - ASAP HOT 2
- Multi GPU support for fine tuning
- Official support for confidence values
- Classification inference
- Update donut-python Python Package to be compatible with latest versions of transformers
- Does the sub task change during donut inference?
- Not getting prediction correctly using the model trained on the custom dataset (similar format as CORD-V2 dataset) HOT 6
- not work this app.py
- Can synthdog insert text for a specified bbox? HOT 1
- Where is the fine-tuned model?
- Why is the output of the intermediate verification empty after training?
- Donut generate ONLY <s><s>...<s></s> HOT 5
- Performance of the model HOT 1
- How to improve OCR accuracy for Japanese characters? HOT 2
- Early Stopping
- How many documents(invoices) are required for training model for document information extraction?
- What should be the configuration of the machine to train the model?