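# app.py — MinerU PDF/image-to-Markdown Gradio demo with a "RAG QA" tab on top.
# Part 1: a minimal RAG pipeline (chunk -> embed -> retrieve -> answer -> judge)
# backed by Gemini via its OpenAI-compatible endpoint. Part 2: the MinerU UI.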
import numpy as np
import traceback
import torch
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from openai import OpenAI
import time
import os
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    raise ValueError("❌ Error: GEMINI_API_KEY not found. Please configure it in Settings -> Secrets.")
client = OpenAI(
api_key=api_key,
base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
)
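# Build an in-memory knowledge base: split the markdown by headers, re-split into
# ~1000-character chunks with 200-character overlap, and embed every chunk.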
def md_to_kb_safe(md_text, embedding_model_name="sentence-transformers/all-MiniLM-L6-v2"):
try:
headers_to_split_on = [("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")]
splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
md_chunks = splitter.split_text(md_text)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, length_function=len)
final_chunks = text_splitter.split_documents(md_chunks)
texts = [doc.page_content for doc in final_chunks]
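        # Prefer the GPU only when it is present and has headroom (<2 GB already allocated).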
device = "cuda" if torch.cuda.is_available() and torch.cuda.memory_allocated() < 2_000_000_000 else "cpu"
embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name, model_kwargs={"device": device})
vectors = embedding_model.embed_documents(texts)
kb = [{"text": texts[i], "vector": vectors[i]} for i in range(len(texts))]
return {"success": True, "num_chunks": len(final_chunks), "kb": kb, "embed_model": embedding_model}
except Exception as e:
return {"success": False, "error": str(e), "traceback": traceback.format_exc()}
def cosine_similarity(v1, v2):
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
def semantic_search(query, embed_model, kb, top_k=3):
t0 = time.time()
q_vec = np.array(embed_model.embed_query(query))
scores = [(cosine_similarity(q_vec, item["vector"]), item["text"]) for item in kb]
scores.sort(reverse=True, key=lambda x: x[0])
return scores[:top_k], time.time() - t0
def build_context(results):
ctx = ""
for i, (score, chunk) in enumerate(results):
ctx += f"=== Context {i+1} ===\n{chunk}\n\n"
return ctx
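# RAG answer: retrieve the top-3 chunks, inline them into a context-only prompt, and
# instruct the model to refuse when the context does not contain the answer.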
def rag_answer(query, embed_model, kb):
t0 = time.time()
results, t_semantic = semantic_search(query, embed_model, kb, top_k=3)
context = build_context(results)
prompt = f"""Use ONLY the information in the following context.
{context}
Question: {query}
If the answer is not in the context, respond EXACTLY with:
"I do not have enough information to answer that."
"""
response = client.chat.completions.create(
model="gemini-2.5-pro",
temperature=0,
messages=[
{"role": "system", "content": "Answer strictly using the context."},
{"role": "user", "content": prompt}
]
)
answer = response.choices[0].message.content
return answer, t_semantic, time.time() - t0
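# LLM-as-judge: score the RAG answer against the ground truth on a 1 / 0.5 / 0 scale.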
def evaluate_ai(response, true_answer):
    t0 = time.time()
    eval_prompt = f"""
AI Response: {response}
Ground Truth: {true_answer}
Rules:
- 1 = very close to true answer
- 0.5 = partially correct
- 0 = incorrect
Respond with ONLY the score.
"""
    completion = client.chat.completions.create(
        model="gemini-2.5-pro",
        temperature=0,
        messages=[
            {"role": "system", "content": "You are an evaluation system."},
            {"role": "user", "content": eval_prompt}
        ]
    )
    return completion.choices[0].message.content, time.time() - t0
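# Gradio entry point for the "RAG QA" tab: markdown text in, (answer, score, timings) out.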
def run_rag_pipeline(md_text_input, query, true_answer):
    kb_result = md_to_kb_safe(md_text_input)
    if not kb_result["success"]:
        return f"Error creating KB:\n{kb_result['error']}", None, None
    kb = kb_result["kb"]
    embed_model = kb_result["embed_model"]
    answer, t_semantic, t_rag = rag_answer(query, embed_model, kb)
    if true_answer and true_answer.strip():
        score, t_eval = evaluate_ai(answer, true_answer)
    else:
        # The UI marks the ground truth as optional, so skip the judge when it is empty.
        score, t_eval = "N/A (no ground truth provided)", 0.0
    timings = f"Semantic Search: {t_semantic:.2f}s | LLM Answer: {t_rag:.2f}s | Evaluation: {t_eval:.2f}s"
    return answer, score, timings
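# Example (hypothetical inputs, assuming GEMINI_API_KEY is configured):
#   answer, score, timings = run_rag_pipeline(
#       "# Solar System\nMars has two moons, Phobos and Deimos.",
#       "How many moons does Mars have?",
#       "Two",
#   )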
import asyncio
import base64
import re
import zipfile
from pathlib import Path
import click
import gradio as gr
from gradio_pdf import PDF
from loguru import logger
from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
from mineru.utils.cli_parser import arg_parse
from mineru.utils.hash_utils import str_sha256
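# --- MinerU PDF extraction demo (parsing glue + Gradio UI) ---
# parse_pdf: run MinerU on a single document and return its output dir and file stem.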
async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language, backend, url):
os.makedirs(output_dir, exist_ok=True)
try:
file_name = f'{safe_stem(Path(doc_path).stem)}_{time.strftime("%y%m%d_%H%M%S")}'
pdf_data = read_fn(doc_path)
if is_ocr:
parse_method = 'ocr'
else:
parse_method = 'auto'
if backend.startswith("vlm"):
parse_method = "vlm"
local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
await aio_do_parse(
output_dir=output_dir,
pdf_file_names=[file_name],
pdf_bytes_list=[pdf_data],
p_lang_list=[language],
parse_method=parse_method,
end_page_id=end_page_id,
formula_enable=formula_enable,
table_enable=table_enable,
backend=backend,
server_url=url,
)
return local_md_dir, file_name
    except Exception as e:
        logger.exception(e)
        # Return a 2-tuple so the caller's `local_md_dir, file_name = ...` unpack still works.
        return None, None
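# Zip the whole output directory so the user can download markdown, images, and JSON together.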
def compress_directory_to_zip(directory_path, output_zip_path):
try:
with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(directory_path):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, directory_path)
zipf.write(file_path, arcname)
return 0
except Exception as e:
logger.exception(e)
return -1
def image_to_base64(image_path):
with open(image_path, 'rb') as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
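# Inline local images referenced by the markdown as base64 data URIs so the
# gr.Markdown preview can render them without serving the image files.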
def replace_image_with_base64(markdown_text, image_dir_path):
pattern = r'\!\[(?:[^\]]*)\]\(([^)]+)\)'
def replace(match):
relative_path = match.group(1)
full_path = os.path.join(image_dir_path, relative_path)
base64_image = image_to_base64(full_path)
return f'![{relative_path}](data:image/jpeg;base64,{base64_image})'
return re.sub(pattern, replace, markdown_text)
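# Full conversion flow: normalize the upload to PDF, parse with MinerU, zip the
# results, and return both raw and image-inlined markdown plus the layout PDF.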
async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True, table_enable=True, language="ch", backend="pipeline", url=None):
file_path = to_pdf(file_path)
local_md_dir, file_name = await parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language, backend, url)
archive_zip_path = os.path.join('./output', str_sha256(local_md_dir) + '.zip')
zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
if zip_archive_success == 0:
logger.info('Compression successful')
else:
logger.error('Compression failed')
md_path = os.path.join(local_md_dir, file_name + '.md')
with open(md_path, 'r', encoding='utf-8') as f:
txt_content = f.read()
md_content = replace_image_with_base64(txt_content, local_md_dir)
new_pdf_path = os.path.join(local_md_dir, file_name + '_layout.pdf')
return md_content, txt_content, archive_zip_path, new_pdf_path
async def to_markdown_safe(file_path, end_pages=10, is_ocr=False,
formula_enable=True, table_enable=True,
language="ch", backend="pipeline", url=None):
try:
return await to_markdown(file_path, end_pages, is_ocr,
formula_enable, table_enable,
language, backend, url)
except Exception as e:
err_msg = traceback.format_exc()
logger.error(f"Error in to_markdown: {err_msg}")
return f"Error: {str(e)}", err_msg, None, None
latex_delimiters_type_a = [
{'left': '$$', 'right': '$$', 'display': True},
{'left': '$', 'right': '$', 'display': False},
]
latex_delimiters_type_b = [
{'left': '\\(', 'right': '\\)', 'display': False},
{'left': '\\[', 'right': '\\]', 'display': True},
]
latex_delimiters_type_all = latex_delimiters_type_a + latex_delimiters_type_b
header = """
<html><head><link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css"><style>.link-block{border:1px solid transparent;border-radius:24px;background-color:rgba(54,54,54,1);cursor:pointer!important}.link-block:hover{background-color:rgba(54,54,54,0.75)!important;cursor:pointer!important}.external-link{display:inline-flex;align-items:center;height:36px;line-height:36px;padding:0 16px;cursor:pointer!important}.external-link,.external-link:hover{cursor:pointer!important}a{text-decoration:none}</style></head><body><div style="
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
text-align: center;
background: linear-gradient(45deg, #007bff 0%, #0056b3 100%);
padding: 24px;
gap: 24px;
border-radius: 8px;
"><div style="
display: flex;
flex-direction: column;
align-items: center;
gap: 16px;
"><div style="display: flex; flex-direction: column; gap: 8px"><h1 style="
font-size: 48px;
color: #fafafa;
margin: 0;
font-family: 'Trebuchet MS', 'Lucida Sans Unicode',
'Lucida Grande', 'Lucida Sans', Arial, sans-serif;
">MinerU 2.5:PDF Extraction Demo</h1></div></div><p style="
margin: 0;
line-height: 1.6rem;
font-size: 16px;
color: #fafafa;
opacity: 0.8;
">A one-stop,open-source,high-quality data extraction tool that supports converting PDF to Markdown and JSON.<br></p><style>.link-block{display:inline-block}.link-block+.link-block{margin-left:20px}</style><div class="column has-text-centered"><div class="publication-links"><!--Code Link.--><span class="link-block"><a href="https://github.com/opendatalab/MinerU"class="external-link button is-normal is-rounded is-dark"style="text-decoration: none; cursor: pointer"><span class="icon"style="margin-right: 4px"><i class="fab fa-github"style="color: white; margin-right: 4px"></i></span><span style="color: white">Code</span></a></span><!--Code Link.--><span class="link-block"><a href="https://huggingface.co/opendatalab/MinerU2.5-2509-1.2B"class="external-link button is-normal is-rounded is-dark"style="text-decoration: none; cursor: pointer"><span class="icon"style="margin-right: 4px"><i class="fas fa-archive"style="color: white; margin-right: 4px"></i></span><span style="color: white">Model</span></a></span><!--arXiv Link.--><span class="link-block"><a href="https://arxiv.org/abs/2409.18839"class="external-link button is-normal is-rounded is-dark"style="text-decoration: none; cursor: pointer"><span class="icon"style="margin-right: 8px"><i class="fas fa-file"style="color: white"></i></span><span style="color: white">Paper</span></a></span><!--Homepage Link.--><span class="link-block"><a href="https://mineru.net/home?source=online"class="external-link button is-normal is-rounded is-dark"style="text-decoration: none; cursor: pointer"><span class="icon"style="margin-right: 8px"><i class="fas fa-home"style="color: white"></i></span><span style="color: white">Homepage</span></a></span><!--Client Link.--><span class="link-block"><a href="https://mineru.net/client?source=online"class="external-link button is-normal is-rounded is-dark"style="text-decoration: none; cursor: pointer"><span class="icon"style="margin-right: 8px"><i class="fas fa-download"style="color: white"></i></span><span style="color: white">Download</span></a></span></div></div><!--New Demo Links--></div></body></html>
"""
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', # noqa: E126
'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
east_slavic_lang = ["ru", "be", "uk"]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', # noqa: E126
'sa', 'bgc'
]
other_lang = ['ch', 'ch_lite', 'ch_server', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka', "el", "th"]
add_lang = ['latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']
all_lang = [*other_lang, *add_lang]
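# Helpers: sanitize the file stem, then write the upload back out as a PDF
# (read_fn yields PDF bytes even for image uploads).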
def safe_stem(file_path):
stem = Path(file_path).stem
return re.sub(r'[^\w.]', '_', stem)
def to_pdf(file_path):
if file_path is None:
return None
pdf_bytes = read_fn(file_path)
unique_filename = f'{safe_stem(file_path)}.pdf'
tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
with open(tmp_file_path, 'wb') as tmp_pdf_file:
tmp_pdf_file.write(pdf_bytes)
return tmp_file_path
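# Toggle the server-URL row and the OCR options row to match the selected backend.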
def update_interface(backend_choice):
    if backend_choice in ["vlm-transformers", "vlm-vllm-async-engine"]:
        return gr.update(visible=False), gr.update(visible=False)
    elif backend_choice in ["vlm-http-client"]:
        return gr.update(visible=True), gr.update(visible=False)
    elif backend_choice in ["pipeline"]:
        return gr.update(visible=False), gr.update(visible=True)
    else:
        # Unknown backend: return no-op updates instead of None, which gradio rejects.
        return gr.update(), gr.update()
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option(
'--enable-example',
'example_enable',
type=bool,
help="Enable example files for input."
"The example files to be input need to be placed in the `example` folder within the directory where the command is currently executed.",
default=True,
)
@click.option(
'--enable-vllm-engine',
'vllm_engine_enable',
type=bool,
help="Enable vLLM engine backend for faster processing.",
default=False,
)
@click.option(
'--enable-api',
'api_enable',
type=bool,
help="Enable gradio API for serving the application.",
default=True,
)
@click.option(
'--max-convert-pages',
'max_convert_pages',
type=int,
help="Set the maximum number of pages to convert from PDF to Markdown.",
default=1000,
)
@click.option(
'--server-name',
'server_name',
type=str,
help="Set the server name for the Gradio app.",
default=None,
)
@click.option(
'--server-port',
'server_port',
type=int,
help="Set the server port for the Gradio app.",
default=None,
)
@click.option(
'--latex-delimiters-type',
'latex_delimiters_type',
type=click.Choice(['a', 'b', 'all']),
help="Set the type of LaTeX delimiters to use in Markdown rendering:"
"'a' for type '$', 'b' for type '()[]', 'all' for both types.",
default='all',
)
def main(ctx,
example_enable, vllm_engine_enable, api_enable, max_convert_pages,
server_name, server_port, latex_delimiters_type, **kwargs
):
kwargs.update(arg_parse(ctx))
if latex_delimiters_type == 'a':
latex_delimiters = latex_delimiters_type_a
elif latex_delimiters_type == 'b':
latex_delimiters = latex_delimiters_type_b
elif latex_delimiters_type == 'all':
latex_delimiters = latex_delimiters_type_all
else:
raise ValueError(f"Invalid latex delimiters type: {latex_delimiters_type}.")
if vllm_engine_enable:
try:
print("Start init vLLM engine...")
from mineru.backend.vlm.vlm_analyze import ModelSingleton
model_singleton = ModelSingleton()
predictor = model_singleton.get_model(
"vllm-async-engine",
None,
None,
**kwargs
)
print("vLLM engine init successfully.")
except Exception as e:
logger.exception(e)
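    # Accepted upload extensions come straight from MinerU's supported PDF and image suffixes.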
suffixes = [f".{suffix}" for suffix in pdf_suffixes + image_suffixes]
with gr.Blocks() as demo:
gr.HTML(header)
with gr.Row():
with gr.Column(variant='panel', scale=5):
with gr.Row():
input_file = gr.File(label='Please upload a PDF or image', file_types=suffixes)
with gr.Row():
                    max_pages = gr.Slider(1, max_convert_pages, max_convert_pages // 2, step=1, label='Max convert pages')
with gr.Row():
if vllm_engine_enable:
drop_list = ["pipeline", "vlm-vllm-async-engine"]
preferred_option = "vlm-vllm-async-engine"
else:
drop_list = ["pipeline", "vlm-transformers", "vlm-http-client"]
preferred_option = "pipeline"
backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
with gr.Row(visible=False) as client_options:
url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
with gr.Row(equal_height=True):
with gr.Column():
gr.Markdown("**Recognition Options:**")
formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
table_enable = gr.Checkbox(label='Enable table recognition', value=True)
with gr.Column(visible=False) as ocr_options:
language = gr.Dropdown(all_lang, label='Language', value='ch')
is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
with gr.Row():
change_bu = gr.Button('Convert')
clear_bu = gr.ClearButton(value='Clear')
pdf_show = PDF(label='PDF preview', interactive=False, visible=True, height=800)
if example_enable:
example_root = os.path.join(os.getcwd(), 'examples')
if os.path.exists(example_root):
with gr.Accordion('Examples:'):
gr.Examples(
examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
_.endswith(tuple(suffixes))],
inputs=input_file
)
with gr.Column(variant='panel', scale=5):
                output_file = gr.File(label='Convert result', interactive=False)
with gr.Tabs():
with gr.Tab('Markdown rendering'):
md = gr.Markdown(label='Markdown rendering', height=1100, show_copy_button=True,
latex_delimiters=latex_delimiters,
line_breaks=True)
with gr.Tab('Markdown text'):
md_text = gr.TextArea(lines=45, show_copy_button=True)
with gr.Tab("RAG QA"):
rag_md_text = gr.TextArea(label="Paste Markdown here", lines=15)
rag_query = gr.Textbox(label="Your Question")
rag_true = gr.Textbox(label="Ground Truth Answer (optional)")
rag_run = gr.Button("Run RAG")
rag_answer_out = gr.TextArea(label="RAG Answer", lines=15, interactive=False)
rag_score_out = gr.Textbox(label="Evaluation Score")
rag_timing_out = gr.Textbox(label="Timings")
rag_run.click(
fn=run_rag_pipeline,
inputs=[rag_md_text, rag_query, rag_true],
outputs=[rag_answer_out, rag_score_out, rag_timing_out]
)
backend.change(
fn=update_interface,
inputs=[backend],
outputs=[client_options, ocr_options],
api_name=False
)
demo.load(
fn=update_interface,
inputs=[backend],
outputs=[client_options, ocr_options],
api_name=False
)
clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])
if api_enable:
api_name = None
else:
api_name = False
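        # With --enable-api off, api_name=False hides these handlers from the auto-generated gradio API.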
input_file.change(fn=to_pdf, inputs=input_file, outputs=pdf_show, api_name=api_name)
change_bu.click(
            fn=to_markdown_safe,  # gradio awaits coroutine handlers natively; no per-call event loop needed
inputs=[input_file, max_pages, is_ocr, formula_enable, table_enable, language, backend, url],
outputs=[md, md_text, output_file, pdf_show],
api_name=api_name
)
demo.launch(server_name=server_name, server_port=server_port, show_api=api_enable, height=1200)
if __name__ == "__main__":
main()