Skip to content

Commit 53af270

Browse files
authored
fix: add batching support for BanCompetitors to handle long input text (#272)
* fix: add batching support for BanCompetitors to handle long input text
* Fix pre-commit hooks
* Use normal arguments to ban competitors
1 parent ca87fdc commit 53af270

File tree

2 files changed

+74
-3
lines changed

2 files changed

+74
-3
lines changed

llm_guard/input_scanners/ban_competitors.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3-
from typing import Sequence
3+
import copy
4+
from typing import Any, Sequence
45

56
from presidio_anonymizer.core.text_replace_builder import TextReplaceBuilder
67

78
from llm_guard.model import Model
89
from llm_guard.transformers_helpers import get_tokenizer_and_model_for_ner
9-
from llm_guard.util import get_logger, lazy_load_dep
10+
from llm_guard.util import get_logger, lazy_load_dep, split_text_to_word_chunks
1011

1112
from .base import Scanner
1213

@@ -39,6 +40,8 @@ def __init__(
3940
redact: bool = True,
4041
model: Model | None = None,
4142
use_onnx: bool = False,
43+
chunk_size: int = 512,
44+
chunk_overlap_size: int = 40,
4245
) -> None:
4346
"""
4447
Initialize BanCompetitors object.
@@ -59,6 +62,8 @@ def __init__(
5962
self._competitors = competitors
6063
self._threshold = threshold
6164
self._redact = redact
65+
self.chunk_length = chunk_size
66+
self.text_overlap_length = chunk_overlap_size
6267

6368
tf_tokenizer, tf_model = get_tokenizer_and_model_for_ner(
6469
model=model,
@@ -73,7 +78,7 @@ def __init__(
7378
def scan(self, prompt: str) -> tuple[str, bool, float]:
7479
is_detected = False
7580
text_replace_builder = TextReplaceBuilder(original_text=prompt)
76-
entities = self._ner_pipeline(prompt)
81+
entities = self._get_ner_results_for_text(prompt)
7782
assert isinstance(entities, list)
7883
entities = sorted(entities, key=lambda x: x["end"], reverse=True)
7984

@@ -112,3 +117,60 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:
112117
LOGGER.debug("None of the competitors were detected")
113118

114119
return prompt, True, -1.0
120+
121+
def _get_ner_results_for_text(self, text: str) -> list[dict]:
    """Run NER model inference on ``text``, chunking long inputs.

    If the approximate token count of the text exceeds the model's maximum
    context length, the text is split into overlapping character chunks;
    each chunk is scanned separately, entity offsets are shifted back into
    the coordinates of the original text, and exact duplicate predictions
    are removed before returning.

    :param text: The text to run inference on
    :type text: str
    :return: List of entity predictions on the word level
    :rtype: List[dict]
    """
    assert self._ner_pipeline is not None
    assert self._ner_pipeline.tokenizer is not None

    max_model_tokens = self._ner_pipeline.tokenizer.model_max_length

    # Rough conversion between units: 1 word ~ 2 tokens ~ 4 characters.
    # The configured chunk sizes are in tokens; translate them to characters.
    approx_token_count = len(text.split()) * 2
    chunk_chars = self.chunk_length // 2 * 4
    overlap_chars = self.text_overlap_length // 2 * 4

    if approx_token_count <= max_model_tokens:
        # Fits in a single model pass; no chunking required.
        results = self._ner_pipeline(text)  # type: ignore
    else:
        LOGGER.info(
            "splitting the text into chunks",
            length=approx_token_count,
            model_max_length=max_model_tokens,
        )
        results: list[dict[str, Any]] = []

        # Scan each overlapping chunk and map predictions back onto the
        # original text by adding the chunk's start offset.
        for span in split_text_to_word_chunks(len(text), chunk_chars, overlap_chars):
            span_preds = self._ner_pipeline(text[span.start : span.end])
            assert isinstance(span_preds, list)
            results.extend(
                {
                    **copy.deepcopy(pred),
                    "start": pred["start"] + span.start,
                    "end": pred["end"] + span.start,
                }
                for pred in span_preds
            )

    # Overlapping chunks can produce identical predictions; keep unique ones.
    return [dict(items) for items in {tuple(p.items()) for p in results}]

tests/input_scanners/test_ban_competitors.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@
1515
False,
1616
1.0,
1717
),
18+
(
19+
["Test LLC"],
20+
"Make an SQL insert statement to add a new user to our database. Name is John Doe. Email is [email protected] but also possible to contact him with [email protected] email. Phone number is 555-123-4567 and the IP address is 192.168.1.100. And credit card number is 4567-8901-2345-6789. He works in Test LLC \n"
21+
* 6,
22+
"Make an SQL insert statement to add a new user to our database. Name is John Doe. Email is [email protected] but also possible to contact him with [email protected] email. Phone number is 555-123-4567 and the IP address is 192.168.1.100. And credit card number is 4567-8901-2345-6789. He works in [REDACTED] \n"
23+
* 6,
24+
False,
25+
1.0,
26+
),
1827
],
1928
)
2029
def test_scan(

0 commit comments

Comments
 (0)