feat: support synthesizing masked fill_in_blank QA pairs #173
Changes from all commits: 8919dc9, 41d5327, 40a04d6, 524cc79, 59b7551
# Generate Masked Fill-in-blank QAs

In this module, we generate fill-in-blank QAs from unstructured corpora by randomly masking core entities in a knowledge graph. The key is that a rule-based validator can automatically verify the answers to these questions. For example:

> **Question:** Hematogenous long-bone osteomyelitis is an infection of the bone, primarily affecting the long bones, and often results from blood-borne pathogens. This condition is characterized by several key symptoms, including ___ and swelling. ___ is a prominent symptom in both primary and recurrent cases of hematogenous long-bone osteomyelitis, manifesting as persistent discomfort in the affected area.
>
> **Answer:** pain

Because the answers to these questions can be easily verified, they are well suited for RLVR (Reinforcement Learning with Verifiable Rewards).

For more details, please see our paper "Knowledge-to-Verification: Exploring RLVR for LLMs in Knowledge-Intensive Domains". It has been accepted to the ACL 2026 Main Conference, and we will update the link soon.
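The rule-based verification described above can be sketched as a minimal normalizing string check (the function name and normalization rules here are illustrative, not the project's actual validator API):

```python
import re

def verify_fill_in_blank(predicted: str, ground_truth: str) -> bool:
    """Rule-based check: the prediction must match the masked entity,
    ignoring case, surrounding quotes, and extra whitespace."""
    def norm(s: str) -> str:
        return re.sub(r"\s+", " ", s.strip().strip("'\"")).lower()
    return norm(predicted) == norm(ground_truth)

print(verify_fill_in_blank("  Pain ", "pain"))   # True
print(verify_fill_in_blank("swelling", "pain"))  # False
```

Because the check is a deterministic function of the model output and the masked entity, it can serve directly as a binary reward signal during RLVR training.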
To run the pipeline:

```bash
python3 -m graphgen.run \
  --config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
```
```yaml
global_params:
  working_dir: cache
  graph_backend: networkx  # graph database backend, support: kuzu, networkx
  kv_backend: json_kv      # key-value store backend, support: rocksdb, json_kv

nodes:
  - id: read_files  # id is unique in the pipeline, and can be referenced by other steps
    op_name: read
    type: source
    dependencies: []
    params:
      input_path:
        - examples/input_examples/jsonl_demo.jsonl  # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples

  - id: chunk_documents
    op_name: chunk
    type: map_batch
    dependencies:
      - read_files
    execution_params:
      replicas: 4
    params:
      chunk_size: 1024    # chunk size for text splitting
      chunk_overlap: 100  # chunk overlap for text splitting

  - id: build_kg
    op_name: build_kg
    type: map_batch
    dependencies:
      - chunk_documents
    execution_params:
      replicas: 1
      batch_size: 128

  - id: partition
    op_name: partition
    type: aggregate
    dependencies:
      - build_kg
    params:
      method: quintuple

  - id: generate
    op_name: generate
    type: map_batch
    dependencies:
      - partition
    execution_params:
      replicas: 1
      batch_size: 128
      save_output: true  # save output
    params:
      method: masked_fill_in_blank  # other methods: atomic, aggregated, multi_hop, cot, vqa
      data_format: QA_pairs         # other formats: Alpaca, Sharegpt, ChatML
```
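The `nodes` list above forms a small dependency DAG (read → chunk → build_kg → partition → generate). As a quick sanity check, illustrative only and not part of graphgen, the dependency structure can be topologically sorted to confirm every referenced id exists and there are no cycles:

```python
from graphlib import TopologicalSorter

# Dependency edges mirrored from the config above: node id -> ids it depends on.
pipeline = {
    "read_files": [],
    "chunk_documents": ["read_files"],
    "build_kg": ["chunk_documents"],
    "partition": ["build_kg"],
    "generate": ["partition"],
}

# static_order() raises CycleError if the config ever introduces a cycle.
order = list(TopologicalSorter(pipeline).static_order())
print(order)  # dependencies always come before their dependents
```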
```python
import random
import re
from typing import Any, Optional

from graphgen.bases import BaseGenerator
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger

random.seed(42)


class MaskedFillInBlankGenerator(BaseGenerator):
    """
    Masked Fill-in-blank Generator follows a TWO-STEP process:
    1. rephrase: Rephrase the input nodes and edges into a coherent text that maintains the original meaning.
    2. mask: Randomly select a node from the input nodes, then mask that node's name in the rephrased text.
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the prompt for the REPHRASE step.
        :param batch: a (nodes, edges) tuple
        :return: the rephrasing prompt
        """
        nodes, edges = batch
        entities_str = "\n".join(
            [
                f"{index + 1}. {node[0]}: {node[1]['description']}"
                for index, node in enumerate(nodes)
            ]
        )
        relations_str = "\n".join(
            [
                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
                for index, edge in enumerate(edges)
            ]
        )
        language = detect_main_language(entities_str + relations_str)

        # TODO: configure add_context
        # if add_context:
        #     original_ids = [
        #         node["source_id"].split("<SEP>")[0] for node in _process_nodes
        #     ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
        #     original_ids = list(set(original_ids))
        #     original_text = await text_chunks_storage.get_by_ids(original_ids)
        #     original_text = "\n".join(
        #         [
        #             f"{index + 1}. {text['content']}"
        #             for index, text in enumerate(original_text)
        #         ]
        #     )

        prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
            entities=entities_str, relationships=relations_str
        )
        return prompt

    @staticmethod
    def parse_rephrased_text(response: str) -> Optional[str]:
        """
        Parse the rephrased text from the response.
        :param response: raw LLM response
        :return: the rephrased text, or None if parsing fails
        """
        rephrased_match = re.search(
            r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
        )
        if rephrased_match:
            rephrased_text = rephrased_match.group(1).strip()
        else:
            logger.warning("Failed to parse rephrased text from response: %s", response)
            return None
        return rephrased_text.strip('"').strip("'")

    @staticmethod
    def parse_response(response: str) -> dict:
        # Unused in this generator: generate() builds the QA pairs directly.
        pass

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> list[dict]:
        """
        Generate QAs based on a given batch.
        :param batch: a (nodes, edges) tuple
        :return: QA pairs
        """
        rephrasing_prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(rephrasing_prompt)
        context = self.parse_rephrased_text(response)
        if not context:
            return []

        nodes, edges = batch

        assert len(nodes) == 3, (
            "MaskedFillInBlankGenerator currently only supports quintuples that have 3 nodes, "
            f"but got {len(nodes)} nodes."
        )
        assert len(edges) == 2, (
            "MaskedFillInBlankGenerator currently only supports quintuples that have 2 edges, "
            f"but got {len(edges)} edges."
        )

        node1, node2, node3 = nodes
        mask_node = random.choice([node1, node2, node3])
        mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")
        mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)

        match = re.search(mask_pattern, context)
        if match:
            gth = match.group(0)
            masked_context = mask_pattern.sub("___", context)
        else:
            logger.debug(
                "Regex match failed!\n"
                "Expected node name: %s\n"
                "Actual context: %s\n",
                mask_node_name,
                context,
            )
            return []

        logger.debug("masked_context: %s", masked_context)
        qa_pairs = {
            "question": masked_context,
            "answer": gth,
        }
        return [qa_pairs]
```
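The mask step above boils down to a case-insensitive, regex-escaped literal substitution. A standalone sketch of that behavior, independent of graphgen (function name is illustrative):

```python
import re

def mask_entity(context: str, entity_name: str):
    """Replace every case-insensitive occurrence of entity_name with '___'.
    Returns (masked_text, matched_surface_form), or None if the entity
    never appears verbatim in the context."""
    # re.escape guards against entity names containing regex metacharacters
    pattern = re.compile(re.escape(entity_name.strip("'\" \n\r\t")), re.IGNORECASE)
    match = pattern.search(context)
    if not match:
        return None
    return pattern.sub("___", context), match.group(0)

masked, gold = mask_entity("Pain is a symptom; chronic pain persists.", "pain")
print(masked)  # ___ is a symptom; chronic ___ persists.
print(gold)    # Pain
```

Note that the recorded answer is the surface form actually found in the context (`Pain`), not the raw node name, so the QA pair stays consistent with the masked text.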
```python
import random
from collections import deque
from typing import Any, Iterable, Set

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community

random.seed(42)


class QuintuplePartitioner(BasePartitioner):
    """
    Quintuple Partitioner that partitions the graph into multiple distinct quintuples
    (node, edge, node, edge, node).
    1. Automatically ignores isolated nodes.
    2. Within each connected component, yields quintuples in BFS order.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        nodes = [n[0] for n in g.get_all_nodes()]
        random.shuffle(nodes)

        visited_nodes: Set[str] = set()
        used_edges: Set[frozenset[str]] = set()

        for seed in nodes:
            if seed in visited_nodes:
                continue

            # start BFS in a connected component
            queue = deque([seed])
            visited_nodes.add(seed)

            while queue:
                u = queue.popleft()

                # collect all neighbors connected to node u via unused edges
                available_neighbors = []
                for v in g.get_neighbors(u):
                    edge_key = frozenset((u, v))
                    if edge_key not in used_edges:
                        available_neighbors.append(v)

                    # standard BFS queue maintenance
                    if v not in visited_nodes:
                        visited_nodes.add(v)
                        queue.append(v)

                random.shuffle(available_neighbors)

                # every two neighbors paired with the center node u create one quintuple.
                # Note: if available_neighbors has an odd length, the remaining edge
                # stays unused for now; it may be matched into a quintuple later
                # when its other endpoint is processed as a center node.
                for i in range(0, len(available_neighbors) // 2 * 2, 2):
                    v1 = available_neighbors[i]
                    v2 = available_neighbors[i + 1]

                    edge1 = frozenset((u, v1))
                    edge2 = frozenset((u, v2))

                    used_edges.add(edge1)
                    used_edges.add(edge2)

                    v1_s, v2_s = sorted((v1, v2))

                    yield Community(
                        id=f"{v1_s}-{u}-{v2_s}",
                        nodes=[v1_s, u, v2_s],
                        edges=[tuple(sorted((v1_s, u))), tuple(sorted((u, v2_s)))],
                    )
```
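The pairing logic can be seen on a toy graph without graphgen's storage abstractions. This sketch (a simplified stand-in, not the class above) runs the same BFS over a plain adjacency dict; on a star graph with four leaves, the four edges around the center pair up into exactly two quintuples:

```python
import random
from collections import deque

def quintuples(adj):
    """Toy version of the pairing logic: pair up unused edges around each
    BFS center into (leaf, center, leaf) quintuples."""
    rng = random.Random(0)
    visited, used, out = set(), set(), []
    for seed in sorted(adj):
        if seed in visited:
            continue
        queue = deque([seed])
        visited.add(seed)
        while queue:
            u = queue.popleft()
            avail = []
            for v in adj[u]:
                if frozenset((u, v)) not in used:
                    avail.append(v)
                if v not in visited:
                    visited.add(v)
                    queue.append(v)
            rng.shuffle(avail)
            # consume neighbors two at a time; an odd leftover edge stays unused
            for i in range(0, len(avail) // 2 * 2, 2):
                used.add(frozenset((u, avail[i])))
                used.add(frozenset((u, avail[i + 1])))
                out.append((avail[i], u, avail[i + 1]))
    return out

star = {"c": ["a", "b", "d", "e"], "a": ["c"], "b": ["c"], "d": ["c"], "e": ["c"]}
print(quintuples(star))  # two (leaf, "c", leaf) quintuples
```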
```python
import random
from collections import deque
from typing import Any, Iterable, Set

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community

random.seed(42)


class TriplePartitioner(BasePartitioner):
    """
    Triple Partitioner that partitions the graph into multiple distinct triples (node, edge, node).
    1. Automatically ignores isolated nodes.
    2. Within each connected component, yields triples in BFS order.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        nodes = [n[0] for n in g.get_all_nodes()]
        random.shuffle(nodes)

        visited_nodes: Set[str] = set()
        used_edges: Set[frozenset[str]] = set()

        for seed in nodes:
            if seed in visited_nodes:
                continue

            # start BFS in a connected component
            queue = deque([seed])
            visited_nodes.add(seed)

            while queue:
                u = queue.popleft()

                for v in g.get_neighbors(u):
                    edge_key = frozenset((u, v))

                    # if this edge has not been used, a new triple has been found
                    if edge_key not in used_edges:
                        used_edges.add(edge_key)

                        # sort the endpoints to ensure the uniqueness of the ID
                        u_sorted, v_sorted = sorted((u, v))
                        yield Community(
                            id=f"{u_sorted}-{v_sorted}",
                            nodes=[u_sorted, v_sorted],
                            edges=[(u_sorted, v_sorted)],
                        )

                    # continue the BFS
                    if v not in visited_nodes:
                        visited_nodes.add(v)
                        queue.append(v)
```
Contributor review comment:

> Setting a global random seed with `random.seed(42)` is generally discouraged, as it affects the entire application's random number generation, which can lead to unexpected behavior in other parts of the code. For reproducibility, it's better to create a local `random.Random` instance within your class, for example in the `__init__` method, and use that for random operations like `random.choice` on line 103.
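The reviewer's suggestion can be sketched as follows. The class and method names here are illustrative stand-ins, not the merged fix; the point is that a per-instance `random.Random` gives reproducible choices without mutating the module-level random state:

```python
import random

class EntityMasker:
    """Simplified stand-in showing a local RNG instead of random.seed(42)."""

    def __init__(self, seed: int = 42):
        # per-instance RNG: reproducible, and other code using the `random`
        # module is unaffected
        self._rng = random.Random(seed)

    def pick_mask_node(self, nodes):
        return self._rng.choice(nodes)

masker = EntityMasker(seed=42)
print(masker.pick_mask_node(["node1", "node2", "node3"]))
```

Two instances built with the same seed make identical choices, so test runs stay deterministic even if unrelated code calls `random.random()` in between.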