# blob: 1dd85d12f633c05bcae7f903868f3fdbda232834
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import
import apache_beam as beam
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain_text_splitters import SentenceTransformersTokenTextSplitter
from apache_beam.transforms import DoFn
from apache_beam.transforms import PTransform
from enum import Enum
# Names exported by ``from <module> import *``: the transform and its strategy enum.
__all__ = ['ChunksGeneration', 'ChunkingStrategy']
class ChunkingStrategy(Enum):
    """Selects which langchain text splitter ``ChunksGeneration`` uses."""

    # Split with ``CharacterTextSplitter``.
    SPLIT_BY_CHARACTER = 0
    # Split with ``RecursiveCharacterTextSplitter``.
    RECURSIVE_SPLIT_BY_CHARACTER = 1
    # Split with ``SentenceTransformersTokenTextSplitter``.
    SPLIT_BY_TOKENS = 2
class ChunksGeneration(PTransform):
    """A ``PTransform`` that splits the ``'text'`` field of each element into chunks.

    Input elements are expected to be dicts (or dict-like objects supporting
    ``copy()``, item access and ``del``) with a ``'text'`` key. Each input
    element yields one output element per chunk. The output is a copy of the
    input with ``'text'`` replaced by the chunk and a 1-based ``'section_id'``
    added.
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        chunking_strategy: ChunkingStrategy,
    ):
        """
        Args:
          chunk_size: Maximum chunk length, measured with ``len`` (characters),
            for the character-based strategies. It is not passed to the
            splitter when using ``ChunkingStrategy.SPLIT_BY_TOKENS``.
          chunk_overlap: Overlap between two adjacent chunks. It is passed to
            the splitter for every strategy.
          chunking_strategy: A ``ChunkingStrategy`` member selecting how the
            text is split.

        Raises:
          ValueError: If ``chunking_strategy`` is not a ``ChunkingStrategy``
            member. Checked here so pipeline construction fails fast instead
            of the job failing later on a worker.
        """
        super().__init__()
        if not isinstance(chunking_strategy, ChunkingStrategy):
            raise ValueError(f"Invalid chunking strategy: {chunking_strategy}")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.chunking_strategy = chunking_strategy

    def expand(self, pcoll):
        """Applies the chunking ``DoFn`` to ``pcoll``."""
        return pcoll | "Generate text chunks" >> beam.ParDo(
            _GenerateChunksFn(
                self.chunk_size, self.chunk_overlap, self.chunking_strategy))
class _GenerateChunksFn(DoFn):
    """``DoFn`` that splits each element's ``'text'`` field into chunks.

    For every chunk it emits a new element holding the input element's other
    entries, with ``'text'`` set to the chunk and ``'section_id'`` set to the
    chunk's 1-based position.
    """

    def __init__(
        self,
        chunk_size: int,
        chunk_overlap: int,
        chunking_strategy: ChunkingStrategy,
    ):
        super().__init__()
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.chunking_strategy = chunking_strategy
        # The splitter is built in setup(), not here. That way it is not
        # pickled along with the DoFn, and it is built once per DoFn instance
        # instead of once per element. The token strategy constructs it from
        # a named sentence-transformers model.
        self._text_splitter = None

    def setup(self):
        """Builds the text splitter once per DoFn instance."""
        self._text_splitter = self._create_text_splitter()

    def _create_text_splitter(self):
        """Returns the langchain splitter selected by ``self.chunking_strategy``.

        Raises:
          ValueError: If ``self.chunking_strategy`` is not a known strategy.
        """
        if self.chunking_strategy == ChunkingStrategy.RECURSIVE_SPLIT_BY_CHARACTER:
            return RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
                length_function=len,
                is_separator_regex=False,
            )
        if self.chunking_strategy == ChunkingStrategy.SPLIT_BY_CHARACTER:
            return CharacterTextSplitter(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
                length_function=len,
                is_separator_regex=False,
            )
        if self.chunking_strategy == ChunkingStrategy.SPLIT_BY_TOKENS:
            # chunk_size is deliberately not forwarded, so the splitter's own
            # tokens-per-chunk default applies (presumably the model's max
            # sequence length; confirm against the langchain docs).
            return SentenceTransformersTokenTextSplitter(
                chunk_overlap=self.chunk_overlap,
                model_name='all-MiniLM-L6-v2',
            )
        raise ValueError(f"Invalid chunking strategy: {self.chunking_strategy}")

    def process(self, element, *args, **kwargs):
        """Yields one output element per chunk of ``element['text']``.

        Args:
          element: A dict-like element with a ``'text'`` key.

        Yields:
          A fresh copy of ``element`` (same type, via ``copy()``) whose
          ``'text'`` is the chunk and whose ``'section_id'`` is the chunk's
          1-based index.

        Raises:
          KeyError: If ``element`` has no ``'text'`` key.
          ValueError: If the configured strategy is unknown.
        """
        if self._text_splitter is None:
            # Beam calls setup() before process(). This covers callers that
            # invoke process() directly, e.g. in unit tests.
            self.setup()
        texts = self._text_splitter.split_text(element['text'])
        base = element.copy()
        del base['text']
        for section_id, section in enumerate(texts, start=1):
            # Emit a new object per chunk. Beam forbids mutating an element
            # after it has been output, so reusing one dict across yields
            # would let later chunks overwrite earlier ones downstream.
            chunk = base.copy()
            chunk['text'] = section
            chunk['section_id'] = section_id
            yield chunk