
Commit

pre-commit done
BeachWang committed Jan 10, 2025
1 parent 48135ab commit 3da5ac5
Showing 3 changed files with 19 additions and 2 deletions.
1 change: 1 addition & 0 deletions configs/config_all.yaml
@@ -255,6 +255,7 @@ process:
sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95}
- generate_qa_from_text_mapper: # mapper to generate question and answer pairs from text.
hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' # Model name on huggingface to generate question and answer pairs.
max_num: null # The maximum number of returned QA samples for each text. No limit if it is None.
output_pattern: null # Regular expression pattern to extract questions and answers from model response.
enable_vllm: false # Whether to use vllm for inference acceleration.
model_params: {} # Parameters for initializing the model.
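For reference, a minimal sketch of how the config fields above map onto the operator's constructor. The argument names and the default model ID come from this diff; the concrete values and the invocation itself are illustrative, not part of the commit, and loading the model requires downloading it from Hugging Face.

from data_juicer.ops.mapper.generate_qa_from_text_mapper import GenerateQAFromTextMapper

# Keep at most 2 QA pairs per input text; None (the default) means no limit.
op = GenerateQAFromTextMapper(
    hf_model='alibaba-pai/pai-qwen1_5-7b-doc2qa',
    max_num=2,
    enable_vllm=False,
    sampling_params={'temperature': 0.9, 'top_p': 0.95},
)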
10 changes: 10 additions & 0 deletions data_juicer/ops/mapper/generate_qa_from_text_mapper.py
@@ -2,6 +2,7 @@
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
@@ -35,6 +36,7 @@ class GenerateQAFromTextMapper(Mapper):

def __init__(self,
hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
max_num: Optional[PositiveInt] = None,
*,
output_pattern: Optional[str] = None,
enable_vllm: bool = False,
@@ -45,6 +47,8 @@ def __init__(self,
Initialization method.
:param hf_model: Huggingface model ID.
:param max_num: The maximum number of returned QA samples for each text.
No limit if it is None.
:param output_pattern: Regular expression pattern to extract
questions and answers from model response.
:param enable_vllm: Whether to use vllm for inference acceleration.
Expand All @@ -69,6 +73,8 @@ def __init__(self,

super().__init__(**kwargs)

self.max_num = max_num

if output_pattern is None:
self.output_pattern = r'Human:(.*?)Assistant:(.*?)(?=Human|$)' # noqa: E501
else:
@@ -131,6 +137,10 @@ def process_batched(self, samples, rank=None):
output = response[0]['generated_text']

qa_list = self.parse_output(output)

if self.max_num is not None:
qa_list = qa_list[:self.max_num]

if len(qa_list) > 0:
for q, a in qa_list:
for input_k in input_keys:
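To make the new behavior concrete, here is a standalone sketch (not the project's code) of what the added lines do: QA pairs are parsed from the model response with the default output_pattern shown above, and, when max_num is set, the list is truncated to at most that many pairs. The parsing below only loosely mirrors the mapper's parse_output.

import re

# Default pattern from the diff above.
PATTERN = r'Human:(.*?)Assistant:(.*?)(?=Human|$)'

def parse_and_truncate(output, max_num=None):
    # Extract (question, answer) pairs from the generated text.
    qa_list = [(q.strip(), a.strip())
               for q, a in re.findall(PATTERN, output, re.DOTALL)]
    # New in this commit: cap the number of returned pairs when max_num is set.
    if max_num is not None:
        qa_list = qa_list[:max_num]
    return qa_list

output = 'Human: What is 1+1? Assistant: 2. Human: What is 2+2? Assistant: 4.'
print(parse_and_truncate(output))             # two pairs
print(parse_and_truncate(output, max_num=1))  # only the first pair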
10 changes: 8 additions & 2 deletions tests/ops/mapper/test_generate_qa_from_text_mapper.py
@@ -19,11 +19,13 @@ def _run_op(self,
enable_vllm=False,
model_params=None,
sampling_params=None,
num_proc=1):
num_proc=1,
max_num=None):

op = GenerateQAFromTextMapper(enable_vllm=enable_vllm,
model_params=model_params,
sampling_params=sampling_params)
sampling_params=sampling_params,
max_num=max_num)

samples = [{
self.text_key:
@@ -45,6 +47,10 @@ def test(self):
sampling_params = {'max_new_tokens': 200}
self._run_op(sampling_params=sampling_params)

def test_max_num(self):
sampling_params = {'max_new_tokens': 200}
self._run_op(sampling_params=sampling_params, max_num=1)

def test_multi_process(self):
sampling_params = {'max_new_tokens': 200}
self._run_op(sampling_params=sampling_params, num_proc=2)
