Core

novelentitymatcher.core.matcher

Classes

Matcher(entities, model='default', threshold=0.7, normalize=True, mode=None, blocking_strategy=None, reranker_model='default', verbose=False, metrics_callback=None)

Unified entity matcher with smart auto-selection.

Automatically chooses the best matching strategy:

- No training data -> zero-shot (embedding similarity)
- < 3 examples/entity -> head-only training (~30s)
- >= 3 examples/entity -> full training (~3min)
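A minimal zero-shot sketch of the auto-selection path. The entity dicts follow the shape used elsewhere on this page, but the `matcher.match(...)` call is an assumption about the public API, which this section does not list:

from novelentitymatcher.core.matcher import Matcher

entities = [
    {"id": "DE", "name": "Germany", "aliases": ["Deutschland"]},
    {"id": "FR", "name": "France"},
]

matcher = Matcher(entities, threshold=0.7)

# With no training data, the runtime state detects zero-shot mode and the
# matcher falls back to embedding similarity. `matcher.match(...)` is an
# assumed entry point; it is not documented in this section.
results = matcher.match("Deutschland", top_k=1)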

Source code in src/novelentitymatcher/core/matcher.py
def __init__(
    self,
    entities: list[dict[str, Any]],
    model: str = "default",
    threshold: float = 0.7,
    normalize: bool = True,
    mode: str | None = None,
    blocking_strategy: Any | None = None,
    reranker_model: str = "default",
    verbose: bool = False,
    metrics_callback: Callable | None = None,
):
    validate_entities(entities)
    validate_threshold(threshold)

    env_verbose = (
        os.getenv("NOVEL_ENTITY_MATCHER_VERBOSE", "false").lower() == "true"
    )
    verbose = verbose or env_verbose

    configure_logging(verbose=verbose)
    self.logger = get_logger(__name__)

    self.entities = entities
    self._runtime_state = MatcherRuntimeState.create(
        model=model,
        threshold=threshold,
        mode=mode,
    )
    self.model_name = self._runtime_state.model_name
    self._requested_model = self._runtime_state.requested_model
    self._training_model_name = self._runtime_state.training_model_name
    self._bert_model_name = self._runtime_state.bert_model_name
    self.threshold = self._runtime_state.threshold
    self.normalize = normalize
    self.mode = mode
    self.blocking_strategy = blocking_strategy
    self.reranker_model = reranker_model
    self._verbose = verbose
    self._metrics_callback = metrics_callback

    self._async_executor: AsyncExecutor | None = None
    self._async_fit_lock = asyncio.Lock()

    self._training_mode = self._runtime_state.training_mode
    self._components = MatcherComponentFactory(self)
    self._has_training_data = self._runtime_state.has_training_data
    self._active_matcher: Any | None = None
    self._detected_mode: str | None = self._runtime_state.detected_mode

    self._hybrid_engine = _HybridEngine(self)
    self._batch_engine = _BatchEngine(self)
    self._diagnosis_engine = _DiagnosisEngine(self)

Functions

novelentitymatcher.core.classifier

Classes

SetFitClassifier(labels, model_name='sentence-transformers/paraphrase-mpnet-base-v2', num_epochs=4, batch_size=16, weight_decay=0.01, head_c=1.0, num_iterations=5, pca_dims=None, skip_body_training=False)

Wrapper for SetFit training and prediction.

Source code in src/novelentitymatcher/core/classifier.py
def __init__(
    self,
    labels: list[str],
    model_name: str = "sentence-transformers/paraphrase-mpnet-base-v2",
    num_epochs: int = 4,
    batch_size: int = 16,
    weight_decay: float = 0.01,
    head_c: float = 1.0,
    num_iterations: int = 5,
    pca_dims: int | None = None,
    skip_body_training: bool = False,
):
    self.labels = labels
    self.model_name = model_name
    self.num_epochs = num_epochs
    self.batch_size = batch_size
    self.weight_decay = weight_decay
    self.head_c = head_c
    self.num_iterations = num_iterations
    self.pca_dims = pca_dims
    self.skip_body_training = skip_body_training
    self.model: Any | None = None
    self.model_head: LogisticRegression | None = None
    self.class_centroids: dict[str, np.ndarray] = {}
    self.pca: Any | None = None
    self.is_trained = False
    self.logger = get_logger(__name__)
    self._use_sentence_transformer_fallback = False
Functions
train(training_data, num_epochs=None, batch_size=None, show_progress=True)

Train the classifier.

Parameters:

    training_data (list[dict], required): List of training examples with 'text' and 'label' keys
    num_epochs (int | None, default: None): Number of training epochs (overrides default)
    batch_size (int | None, default: None): Batch size for training (overrides default)
    show_progress (bool, default: True): Whether to show progress bar during training
Source code in src/novelentitymatcher/core/classifier.py
def train(
    self,
    training_data: list[dict],
    num_epochs: int | None = None,
    batch_size: int | None = None,
    show_progress: bool = True,
):
    """Train the classifier.

    Args:
        training_data: List of training examples with 'text' and 'label' keys
        num_epochs: Number of training epochs (overrides default)
        batch_size: Batch size for training (overrides default)
        show_progress: Whether to show progress bar during training
    """
    suppress_third_party_loggers()

    epochs = num_epochs or self.num_epochs
    batch = batch_size or self.batch_size

    texts = [item["text"] for item in training_data]
    labels_arr = np.array([item["label"] for item in training_data])

    if self.skip_body_training:
        self.model = get_cached_sentence_transformer(self.model_name)
        self._use_sentence_transformer_fallback = True
        embeddings = self.model.encode(texts, show_progress_bar=False)
        self._train_fallback_head(embeddings, labels_arr, training_data)
    else:
        try:
            self.model = get_cached_setfit_model(
                self.model_name, labels=self.labels
            )
            self._use_sentence_transformer_fallback = False
            Trainer, TrainingArguments = _load_setfit_trainer_classes()
        except ImportError as exc:
            self.logger.warning(
                "SetFit training unavailable for %s; falling back to "
                "sentence-transformer embeddings + logistic head: %s",
                self.model_name,
                exc,
            )
            self.model = get_cached_sentence_transformer(self.model_name)
            self._use_sentence_transformer_fallback = True
            embeddings = self.model.encode(texts, show_progress_bar=False)
            self._train_fallback_head(embeddings, labels_arr, training_data)
        else:
            dataset = Dataset.from_list(training_data)

            args = TrainingArguments(
                output_dir=tempfile.mkdtemp(prefix="novelentitymatcher-setfit-"),
                num_epochs=epochs,
                batch_size=batch,
                body_learning_rate=2e-5,
                head_learning_rate=1e-3,
                save_strategy="no",
                report_to="none",
                logging_dir=None,
                l2_weight=self.weight_decay,
                num_iterations=self.num_iterations,
            )

            trainer = Trainer(
                model=self.model,
                args=args,
                train_dataset=dataset,
            )

            if show_progress:
                try:
                    from tqdm.auto import tqdm

                    with tqdm(total=epochs, desc="Training", unit="epoch"):
                        trainer.train()
                except ImportError:
                    trainer.train()
            else:
                trainer.train()

            embeddings = self.model.model_body.encode(
                texts, show_progress_bar=False
            )
            self._train_logistic_head(embeddings, labels_arr, training_data)

    self.is_trained = True
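A minimal training sketch for SetFitClassifier; labels and examples are illustrative, and only the documented train() signature is used:

from novelentitymatcher.core.classifier import SetFitClassifier

clf = SetFitClassifier(labels=["DE", "FR"])
training_data = [
    {"text": "Germany", "label": "DE"},
    {"text": "Deutschland", "label": "DE"},
    {"text": "France", "label": "FR"},
]
clf.train(training_data, num_epochs=2, show_progress=False)
assert clf.is_trained  # set to True once training completes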

Functions

novelentitymatcher.core.normalizer

Classes

TextNormalizer(lowercase=True, remove_accents=False, remove_punctuation=False)

Text normalization utilities for entity matching.

Source code in src/novelentitymatcher/core/normalizer.py
def __init__(
    self,
    lowercase: bool = True,
    remove_accents: bool = False,
    remove_punctuation: bool = False,
):
    self.lowercase = lowercase
    self.remove_accents = remove_accents
    self.remove_punctuation = remove_punctuation
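A short sketch of the normalizer. The normalize() method is the one HierarchicalMatcher calls below; the exact output shown is an assumption based on the option names:

from novelentitymatcher.core.normalizer import TextNormalizer

normalizer = TextNormalizer(lowercase=True, remove_accents=True)
# Assumed output given the option names; verify against your installed version.
print(normalizer.normalize("Café MÜNCHEN"))  # expected: "cafe munchen"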

novelentitymatcher.core.reranker

Cross-encoder reranking for semantic entity matching.

Classes

CrossEncoderReranker(model='bge-m3', backend=None, device=None, batch_size=32)

User-facing API for cross-encoder reranking.

Provides precise reranking of candidate entities using cross-encoder models. Typically used after initial retrieval with bi-encoder models.

Example

    from novelentitymatcher import EmbeddingMatcher, CrossEncoderReranker

    # Initial retrieval
    retriever = EmbeddingMatcher(entities, model_name="bge-base")
    retriever.build_index()
    candidates = retriever.match(query, top_k=50)

    # Rerank top candidates
    reranker = CrossEncoderReranker(model="bge-m3")
    final_results = reranker.rerank(query, candidates, top_k=5)

Parameters:

    model (str, default: 'bge-m3'): Model alias or full model name
    backend (default: None): Custom backend implementation (defaults to STReranker)
    device (str | None, default: None): Device to run model on (None for auto-detection)
    batch_size (int, default: 32): Batch size for inference
Source code in src/novelentitymatcher/core/reranker.py
def __init__(
    self,
    model: str = "bge-m3",
    backend=None,
    device: str | None = None,
    batch_size: int = 32,
):
    """
    Initialize the reranker.

    Args:
        model: Model alias or full model name
        backend: Custom backend implementation (defaults to STReranker)
        device: Device to run model on (None for auto-detection)
        batch_size: Batch size for inference
    """
    self.model_name = resolve_reranker_alias(model)

    if backend is None:
        backend = STReranker(
            model_name=self.model_name,
            device=device,
            batch_size=batch_size,
        )

    self.backend = backend
    self.device = device
    self.batch_size = batch_size
Functions
rerank(query, candidates, top_k=5, text_field='text')

Rerank candidates using cross-encoder.

Parameters:

    query (str, required): Query text
    candidates (list[dict[str, Any]], required): List of candidate dictionaries
    top_k (int, default: 5): Number of top results to return
    text_field (str, default: 'text'): Field name containing text to score

Returns:

    list[dict[str, Any]]: Reranked list of candidates (top_k only) with added 'cross_encoder_score' field

Source code in src/novelentitymatcher/core/reranker.py
def rerank(
    self,
    query: str,
    candidates: list[dict[str, Any]],
    top_k: int = 5,
    text_field: str = "text",
) -> list[dict[str, Any]]:
    """
    Rerank candidates using cross-encoder.

    Args:
        query: Query text
        candidates: List of candidate dictionaries
        top_k: Number of top results to return
        text_field: Field name containing text to score

    Returns:
        Reranked list of candidates (top_k only) with added 'cross_encoder_score' field
    """
    if not candidates:
        return []

    return self.backend.rerank(
        query, candidates, top_k=top_k, text_field=text_field
    )
rerank_batch(queries, candidates_list, top_k=5, text_field='text')

Batch reranking for multiple queries.

Parameters:

    queries (list[str], required): List of query texts
    candidates_list (list[list[dict[str, Any]]], required): List of candidate lists (one per query)
    top_k (int, default: 5): Number of top results to return per query
    text_field (str, default: 'text'): Field name containing text to score

Returns:

    list[list[dict[str, Any]]]: List of reranked candidate lists

Source code in src/novelentitymatcher/core/reranker.py
def rerank_batch(
    self,
    queries: list[str],
    candidates_list: list[list[dict[str, Any]]],
    top_k: int = 5,
    text_field: str = "text",
) -> list[list[dict[str, Any]]]:
    """
    Batch reranking for multiple queries.

    Args:
        queries: List of query texts
        candidates_list: List of candidate lists (one per query)
        top_k: Number of top results to return per query
        text_field: Field name containing text to score

    Returns:
        List of reranked candidate lists
    """
    if len(queries) != len(candidates_list):
        raise ValueError("queries and candidates_list must have the same length")
    return [
        self.backend.rerank(query, cands, top_k=top_k, text_field=text_field)
        for query, cands in zip(queries, candidates_list, strict=False)
    ]
score(query, docs)

Score query-document pairs.

Parameters:

    query (str, required): Query text
    docs (list[str], required): List of document texts

Returns:

    list[float]: List of scores (one per document)

Source code in src/novelentitymatcher/core/reranker.py
def score(self, query: str, docs: list[str]) -> list[float]:
    """
    Score query-document pairs.

    Args:
        query: Query text
        docs: List of document texts

    Returns:
        List of scores (one per document)
    """
    return self.backend.score(query, docs)

novelentitymatcher.core.hierarchy

Hierarchical entity matching with multi-parent support.

This module provides:

- HierarchyIndex: Graph-based hierarchy representation
- HierarchicalScoring: Depth-aware confidence scoring
- HierarchicalMatcher: User-facing API for hierarchical matching

Classes

HierarchyIndex(entities)

Graph-based index for hierarchical entity relationships.

Supports:

- Multi-parent hierarchies (DAG structure)
- Weighted edges for relationship strength
- Fast ancestor/descendant queries
- Path finding and depth calculation

Parameters:

    entities (list[dict[str, Any]], required): List of entity dicts with optional 'hierarchy' key.
        hierarchy format: {
            'parents': ['parent_id1', 'parent_id2'],
            'children': ['child_id1', 'child_id2'],
            'level': int,
            'weights': {'parent_id': float}
        }
Source code in src/novelentitymatcher/core/hierarchy.py
def __init__(self, entities: list[dict[str, Any]]):
    """
    Build hierarchy index from entity definitions.

    Args:
        entities: List of entity dicts with optional 'hierarchy' key
                 hierarchy format: {
                     'parents': ['parent_id1', 'parent_id2'],
                     'children': ['child_id1', 'child_id2'],
                     'level': int,
                     'weights': {'parent_id': float}
                 }
    """
    self.entities = {e["id"]: e for e in entities}
    self.graph: Any = nx.DiGraph()
    self._build_graph()
    self._cache: dict[str, Any] = {}
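A small construction-and-traversal sketch, assuming _build_graph consumes the 'parents' lists from the documented hierarchy format (entity IDs are illustrative):

from novelentitymatcher.core.hierarchy import HierarchyIndex

entities = [
    {"id": "food", "name": "Food"},
    {"id": "fruit", "name": "Fruit", "hierarchy": {"parents": ["food"]}},
    {"id": "apple", "name": "Apple", "hierarchy": {"parents": ["fruit"]}},
]
index = HierarchyIndex(entities)

index.get_ancestors("apple")                   # ["fruit", "food"] (order may vary)
index.get_relationship_depth("apple", "food")  # 2 (grandparent)
index.is_ancestor("food", "apple")             # True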
Functions
get_ancestors(entity_id, max_depth=None)

Get all ancestor entities for a given entity.

Parameters:

    entity_id (str, required): Entity to find ancestors for
    max_depth (int | None, default: None): Maximum depth to traverse (None = unlimited)

Returns:

    list[str]: List of ancestor entity IDs

Source code in src/novelentitymatcher/core/hierarchy.py
def get_ancestors(self, entity_id: str, max_depth: int | None = None) -> list[str]:
    """
    Get all ancestor entities for a given entity.

    Args:
        entity_id: Entity to find ancestors for
        max_depth: Maximum depth to traverse (None = unlimited)

    Returns:
        List of ancestor entity IDs
    """
    return self._bfs_traverse(entity_id, max_depth, self.graph.predecessors)
get_descendants(entity_id, max_depth=None)

Get all descendant entities for a given entity.

Parameters:

Name Type Description Default
entity_id str

Entity to find descendants for

required
max_depth int | None

Maximum depth to traverse (None = unlimited)

None

Returns:

Type Description
list[str]

List of descendant entity IDs

Source code in src/novelentitymatcher/core/hierarchy.py
def get_descendants(
    self, entity_id: str, max_depth: int | None = None
) -> list[str]:
    """
    Get all descendant entities for a given entity.

    Args:
        entity_id: Entity to find descendants for
        max_depth: Maximum depth to traverse (None = unlimited)

    Returns:
        List of descendant entity IDs
    """
    return self._bfs_traverse(entity_id, max_depth, self.graph.successors)
get_relationship_depth(entity_a, entity_b)

Calculate the depth of relationship between two entities.

Parameters:

    entity_a (str, required): First entity ID
    entity_b (str, required): Second entity ID

Returns:

    int: Depth (0 = same entity, 1 = direct parent/child, 2 = grandparent, etc.).
        Returns -1 if no relationship found

Source code in src/novelentitymatcher/core/hierarchy.py
def get_relationship_depth(self, entity_a: str, entity_b: str) -> int:
    """
    Calculate the depth of relationship between two entities.

    Args:
        entity_a: First entity ID
        entity_b: Second entity ID

    Returns:
        Depth (0 = same entity, 1 = direct parent/child, 2 = grandparent, etc.)
        Returns -1 if no relationship found
    """
    if entity_a == entity_b:
        return 0

    if entity_a not in self.graph or entity_b not in self.graph:
        return -1

    try:
        # Try to find shortest path in the directed graph
        path = nx.shortest_path(self.graph, entity_a, entity_b)
        return len(path) - 1
    except nx.NetworkXNoPath:
        # Try reverse direction (child to parent)
        try:
            path = nx.shortest_path(self.graph, entity_b, entity_a)
            return len(path) - 1
        except nx.NetworkXNoPath:
            return -1
get_path(from_entity, to_entity)

Get shortest path between two entities in the hierarchy.

Parameters:

    from_entity (str, required): Starting entity ID
    to_entity (str, required): Ending entity ID

Returns:

    list[str]: List of entity IDs representing the path (inclusive).
        Returns empty list if no path exists

Source code in src/novelentitymatcher/core/hierarchy.py
def get_path(self, from_entity: str, to_entity: str) -> list[str]:
    """
    Get shortest path between two entities in the hierarchy.

    Args:
        from_entity: Starting entity ID
        to_entity: Ending entity ID

    Returns:
        List of entity IDs representing the path (inclusive)
        Returns empty list if no path exists
    """
    try:
        return nx.shortest_path(self.graph, from_entity, to_entity)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        return []
is_ancestor(ancestor_id, descendant_id)

Check if ancestor_id is an ancestor of descendant_id.

Parameters:

    ancestor_id (str, required): Potential ancestor
    descendant_id (str, required): Potential descendant

Returns:

    bool: True if ancestor_id is an ancestor of descendant_id

Source code in src/novelentitymatcher/core/hierarchy.py
def is_ancestor(self, ancestor_id: str, descendant_id: str) -> bool:
    """
    Check if ancestor_id is an ancestor of descendant_id.

    Args:
        ancestor_id: Potential ancestor
        descendant_id: Potential descendant

    Returns:
        True if ancestor_id is an ancestor of descendant_id
    """
    if ancestor_id == descendant_id:
        return False

    ancestors = self.get_ancestors(descendant_id)
    return ancestor_id in ancestors

HierarchicalScoring(hierarchy_index, alpha=0.7, beta=0.3)

Calculate hierarchy-aware confidence scores.

Combines:

- Semantic similarity (cosine similarity of embeddings)
- Hierarchical proximity boost (based on relationship type)
- Depth penalty (deeper relationships = lower scores)

Parameters:

    hierarchy_index (HierarchyIndex, required): HierarchyIndex for graph operations
    alpha (float, default: 0.7): Weight for semantic similarity (0-1)
    beta (float, default: 0.3): Weight for hierarchical boost (0-1)
Source code in src/novelentitymatcher/core/hierarchy.py
def __init__(
    self, hierarchy_index: HierarchyIndex, alpha: float = 0.7, beta: float = 0.3
):
    """
    Initialize hierarchical scorer.

    Args:
        hierarchy_index: HierarchyIndex for graph operations
        alpha: Weight for semantic similarity (0-1)
        beta: Weight for hierarchical boost (0-1)
    """
    self.hierarchy = hierarchy_index
    self.alpha = alpha
    self.beta = beta
Functions
compute_score(query_embedding, entity_embedding, entity_id, relationship_type='self', depth=0)

Compute hierarchical score combining semantic and hierarchical features.

Formula:

    final_score = (semantic_similarity * alpha + hierarchical_boost * beta) * depth_penalty
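A worked instance with the default weights. The boost and penalty values below are illustrative; the real tables live in _get_hierarchical_boost and DEPTH_PENALTIES:

# Assume semantic_similarity = 0.80, hierarchical_boost = 0.90 for a parent
# match, and depth_penalty = 0.80 (illustrative values only).
alpha, beta = 0.7, 0.3
final_score = (0.80 * alpha + 0.90 * beta) * 0.80
# (0.56 + 0.27) * 0.80 = 0.664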

Parameters:

    query_embedding (ndarray, required): Query text embedding
    entity_embedding (ndarray, required): Entity text embedding
    entity_id (str, required): Entity identifier
    relationship_type (str, default: 'self'): "self", "parent", "child", "ancestor", "descendant"
    depth (int, default: 0): Relationship depth (0=self, 1=direct, etc.)

Returns:

    float: Final hierarchical score (0-1)

Source code in src/novelentitymatcher/core/hierarchy.py
def compute_score(
    self,
    query_embedding: np.ndarray,
    entity_embedding: np.ndarray,
    entity_id: str,
    relationship_type: str = "self",
    depth: int = 0,
) -> float:
    """
    Compute hierarchical score combining semantic and hierarchical features.

    Formula:
        final_score = (
            semantic_similarity * alpha +
            hierarchical_boost * beta
        ) * depth_penalty

    Args:
        query_embedding: Query text embedding
        entity_embedding: Entity text embedding
        entity_id: Entity identifier
        relationship_type: "self", "parent", "child", "ancestor", "descendant"
        depth: Relationship depth (0=self, 1=direct, etc.)

    Returns:
        Final hierarchical score (0-1)
    """
    # Compute semantic similarity
    semantic_score = self._compute_semantic_similarity(
        query_embedding, entity_embedding
    )

    # Get hierarchical boost for this relationship type
    hierarchical_boost = self._get_hierarchical_boost(relationship_type)

    # Get depth penalty
    depth_penalty = self.DEPTH_PENALTIES.get(depth, 0.4)

    # Combine scores
    final_score = (
        semantic_score * self.alpha + hierarchical_boost * self.beta
    ) * depth_penalty

    return float(final_score)

HierarchicalMatcher(entities, embedding_model='BAAI/bge-base-en-v1.5', alpha=0.7, beta=0.3, normalize=True)

Hierarchical entity matching with multi-parent support.

Combines semantic similarity (via EmbeddingMatcher) with hierarchy-aware scoring to enable flexible granularity matching.

Features:

- Match at any level in hierarchy (self, ancestors, descendants)
- Multi-parent hierarchy support
- Depth-aware confidence scores
- Flexible granularity matching
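A minimal end-to-end sketch using only methods documented below; the entities, hierarchy edges, and query are illustrative:

from novelentitymatcher.core.hierarchy import HierarchicalMatcher

matcher = HierarchicalMatcher(
    [
        {"id": "food", "name": "Food"},
        {"id": "fruit", "name": "Fruit", "hierarchy": {"parents": ["food"]}},
        {"id": "apple", "name": "Apple", "aliases": ["Granny Smith"],
         "hierarchy": {"parents": ["fruit"]}},
    ],
    alpha=0.7,
    beta=0.3,
)
matcher.build_index()  # required before matching
for m in matcher.match("green apple", top_k=3, match_level="all"):
    print(m["id"], m["relationship"], round(m["score"], 3))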

Parameters:

    entities (list[dict[str, Any]], required): List of entity dicts with optional 'hierarchy' key
    embedding_model (str, default: 'BAAI/bge-base-en-v1.5'): Sentence transformer model name
    alpha (float, default: 0.7): Weight for semantic similarity (0-1)
    beta (float, default: 0.3): Weight for hierarchical boost (0-1)
    normalize (bool, default: True): Whether to apply text normalization
Source code in src/novelentitymatcher/core/hierarchy.py
def __init__(
    self,
    entities: list[dict[str, Any]],
    embedding_model: str = "BAAI/bge-base-en-v1.5",
    alpha: float = 0.7,
    beta: float = 0.3,
    normalize: bool = True,
):
    """
    Initialize hierarchical matcher.

    Args:
        entities: List of entity dicts with optional 'hierarchy' key
        embedding_model: Sentence transformer model name
        alpha: Weight for semantic similarity (0-1)
        beta: Weight for hierarchical boost (0-1)
        normalize: Whether to apply text normalization
    """
    self.entities = entities
    self.entities_dict = {e["id"]: e for e in entities}
    self.embedding_model = embedding_model
    self.normalize = normalize

    # Initialize text normalizer
    self.normalizer = TextNormalizer() if normalize else None

    # Build hierarchy index
    self.hierarchy_index = HierarchyIndex(entities)

    # Initialize scorer
    self.scorer = HierarchicalScoring(self.hierarchy_index, alpha=alpha, beta=beta)

    # Will be initialized in build_index()
    self.embedding_matcher: Any = None
    self.entity_embeddings: dict[str, Any] = {}
    self.entity_texts: dict[str, str] = {}
Functions
build_index()

Build embedding index for all entities.

Must be called before matching.

Source code in src/novelentitymatcher/core/hierarchy.py
def build_index(self):
    """
    Build embedding index for all entities.

    Must be called before matching.
    """
    # Prepare entity texts (name + aliases)
    for entity in self.entities:
        entity_id = entity["id"]
        texts = [entity["name"]]

        if "aliases" in entity:
            texts.extend(entity["aliases"])

        # Apply normalization if enabled
        if self.normalizer:
            texts = [self.normalizer.normalize(t) for t in texts]

        # Store combined text (join with space)
        self.entity_texts[entity_id] = " ".join(texts)

    # Create EmbeddingMatcher for semantic similarity
    embedding_entities = [
        {"id": eid, "name": text} for eid, text in self.entity_texts.items()
    ]

    self.embedding_matcher = EmbeddingMatcher(
        entities=embedding_entities,
        model_name=self.embedding_model,
        normalize=False,  # Already normalized if needed
    )

    self.embedding_matcher.build_index()

    # Cache embeddings for scoring - create dict mapping entity_id to embedding
    self.entity_embeddings = {
        entity_id: self.embedding_matcher.embeddings[idx]
        for idx, entity_id in enumerate(self.embedding_matcher.entity_ids)
    }
match(query, top_k=5, match_level='all', max_depth=3)

Match query considering hierarchical relationships.

Parameters:

    query (str, required): Query text
    top_k (int, default: 5): Number of results to return
    match_level (str, default: 'all'): "self", "ancestors", "descendants", "all"
    max_depth (int, default: 3): Maximum depth to traverse for hierarchical matches

Returns:

    list[dict[str, Any]]: List of matches with:

    - id: Entity ID
    - score: Final hierarchical score
    - relationship: "self", "parent", "child", "ancestor", "descendant"
    - depth: Relationship depth
    - semantic_score: Raw embedding similarity
    - hierarchical_boost: Applied hierarchical boost
Source code in src/novelentitymatcher/core/hierarchy.py
def match(
    self, query: str, top_k: int = 5, match_level: str = "all", max_depth: int = 3
) -> list[dict[str, Any]]:
    """
    Match query considering hierarchical relationships.

    Args:
        query: Query text
        top_k: Number of results to return
        match_level: "self", "ancestors", "descendants", "all"
        max_depth: Maximum depth to traverse for hierarchical matches

    Returns:
        List of matches with:
        - id: Entity ID
        - score: Final hierarchical score
        - relationship: "self", "parent", "child", "ancestor", "descendant"
        - depth: Relationship depth
        - semantic_score: Raw embedding similarity
        - hierarchical_boost: Applied hierarchical boost
    """
    if self.embedding_matcher is None:
        raise RuntimeError("Must call build_index() before matching")

    # Normalize query if needed
    if self.normalizer:
        query = self.normalizer.normalize(query)

    # Get query embedding
    query_emb = self.embedding_matcher.model.encode(query, convert_to_numpy=True)

    # Collect candidates based on match_level
    candidates = []

    # Get base matches from embedding matcher (self-level)
    base_matches = self.embedding_matcher.match(query, top_k=top_k * 2)

    for base_match in base_matches:
        entity_id = base_match["id"]
        entity_emb = self.entity_embeddings[entity_id]

        # Compute hierarchical score for self-match
        score = self.scorer.compute_score(
            query_emb, entity_emb, entity_id, relationship_type="self", depth=0
        )

        candidates.append(
            {
                "id": entity_id,
                "score": score,
                "relationship": "self",
                "depth": 0,
                "semantic_score": base_match["score"],
                "hierarchical_boost": 0.0,
            }
        )

    # Add hierarchical matches if requested
    if match_level in ["ancestors", "all"]:
        for base_match in base_matches[:top_k]:  # Only check top matches
            entity_id = base_match["id"]
            ancestors = self.hierarchy_index.get_ancestors(entity_id, max_depth)

            for ancestor_id in ancestors:
                if ancestor_id not in self.entity_embeddings:
                    continue

                # Calculate depth
                depth = self.hierarchy_index.get_relationship_depth(
                    entity_id, ancestor_id
                )

                ancestor_emb = self.entity_embeddings[ancestor_id]

                score = self.scorer.compute_score(
                    query_emb,
                    ancestor_emb,
                    ancestor_id,
                    relationship_type="ancestor" if depth > 1 else "parent",
                    depth=depth,
                )

                candidates.append(
                    {
                        "id": ancestor_id,
                        "score": score,
                        "relationship": "parent" if depth == 1 else "ancestor",
                        "depth": depth,
                        "semantic_score": float(
                            cosine_similarity(
                                query_emb.reshape(1, -1),
                                ancestor_emb.reshape(1, -1),
                            )[0][0]
                        ),
                        "hierarchical_boost": self.scorer._get_hierarchical_boost(
                            "parent" if depth == 1 else "ancestor"
                        ),
                    }
                )

    if match_level in ["descendants", "all"]:
        for base_match in base_matches[:top_k]:
            entity_id = base_match["id"]
            descendants = self.hierarchy_index.get_descendants(entity_id, max_depth)

            for descendant_id in descendants:
                if descendant_id not in self.entity_embeddings:
                    continue

                depth = self.hierarchy_index.get_relationship_depth(
                    entity_id, descendant_id
                )

                descendant_emb = self.entity_embeddings[descendant_id]

                score = self.scorer.compute_score(
                    query_emb,
                    descendant_emb,
                    descendant_id,
                    relationship_type="descendant" if depth > 1 else "child",
                    depth=depth,
                )

                candidates.append(
                    {
                        "id": descendant_id,
                        "score": score,
                        "relationship": "child" if depth == 1 else "descendant",
                        "depth": depth,
                        "semantic_score": float(
                            cosine_similarity(
                                query_emb.reshape(1, -1),
                                descendant_emb.reshape(1, -1),
                            )[0][0]
                        ),
                        "hierarchical_boost": self.scorer._get_hierarchical_boost(
                            "child" if depth == 1 else "descendant"
                        ),
                    }
                )

    # Remove duplicates (keep highest score)
    seen: dict[str, dict[str, Any]] = {}
    for candidate in candidates:
        cid = candidate["id"]
        if cid not in seen or candidate["score"] > seen[cid]["score"]:
            seen[cid] = candidate

    # Sort by score and return top_k
    results = sorted(seen.values(), key=lambda x: x["score"], reverse=True)
    return results[:top_k]
get_ancestors(entity_id, max_depth=None)

Get all ancestors of an entity with metadata.

Parameters:

    entity_id (str, required): Entity to find ancestors for
    max_depth (int | None, default: None): Maximum depth to traverse

Returns:

    list[dict[str, Any]]: List of ancestor entities with metadata

Source code in src/novelentitymatcher/core/hierarchy.py
def get_ancestors(
    self, entity_id: str, max_depth: int | None = None
) -> list[dict[str, Any]]:
    """
    Get all ancestors of an entity with metadata.

    Args:
        entity_id: Entity to find ancestors for
        max_depth: Maximum depth to traverse

    Returns:
        List of ancestor entities with metadata
    """
    ancestor_ids = self.hierarchy_index.get_ancestors(entity_id, max_depth)

    return [
        {
            "id": aid,
            "name": self.entities_dict[aid].get("name", aid),
            "depth": self.hierarchy_index.get_relationship_depth(entity_id, aid),
        }
        for aid in ancestor_ids
        if aid in self.entities_dict
    ]
get_descendants(entity_id, max_depth=None)

Get all descendants of an entity with metadata.

Parameters:

    entity_id (str, required): Entity to find descendants for
    max_depth (int | None, default: None): Maximum depth to traverse

Returns:

    list[dict[str, Any]]: List of descendant entities with metadata

Source code in src/novelentitymatcher/core/hierarchy.py
def get_descendants(
    self, entity_id: str, max_depth: int | None = None
) -> list[dict[str, Any]]:
    """
    Get all descendants of an entity with metadata.

    Args:
        entity_id: Entity to find descendants for
        max_depth: Maximum depth to traverse

    Returns:
        List of descendant entities with metadata
    """
    descendant_ids = self.hierarchy_index.get_descendants(entity_id, max_depth)

    return [
        {
            "id": did,
            "name": self.entities_dict[did].get("name", did),
            "depth": self.hierarchy_index.get_relationship_depth(entity_id, did),
        }
        for did in descendant_ids
        if did in self.entities_dict
    ]
get_hierarchy_path(entity_id, to_entity=None)

Get path from entity_id to root or to_entity.

Parameters:

    entity_id (str, required): Starting entity
    to_entity (str | None, default: None): Ending entity (None = path to root)

Returns:

    list[dict[str, Any]]: List of entities representing the path

Source code in src/novelentitymatcher/core/hierarchy.py
def get_hierarchy_path(
    self, entity_id: str, to_entity: str | None = None
) -> list[dict[str, Any]]:
    """
    Get path from entity_id to root or to_entity.

    Args:
        entity_id: Starting entity
        to_entity: Ending entity (None = path to root)

    Returns:
        List of entities representing the path
    """
    if to_entity:
        # Try direct path first
        path_ids = self.hierarchy_index.get_path(entity_id, to_entity)

        # If no direct path, try reverse (going up the hierarchy)
        if not path_ids:
            path_ids = self.hierarchy_index.get_path(to_entity, entity_id)
            path_ids = list(reversed(path_ids))
    else:
        # Path to root (farthest ancestor)
        path_ids = [entity_id]
        current = entity_id
        while True:
            ancestors = self.hierarchy_index.get_ancestors(current, max_depth=1)
            if not ancestors:
                break
            current = ancestors[0]
            path_ids.append(current)

    return [
        {"id": pid, "name": self.entities_dict[pid].get("name", pid)}
        for pid in path_ids
        if pid in self.entities_dict
    ]

novelentitymatcher.core.blocking

Blocking strategies for efficient candidate filtering.

Classes

BlockingStrategy

Bases: ABC

Abstract base class for blocking strategies.

Functions
block(query, entities, top_k) abstractmethod

Return top_k candidate entities for the query.

Parameters:

    query (str, required): Query text
    entities (list[dict[str, Any]], required): List of all entities
    top_k (int, required): Maximum number of candidates to return

Returns:

    list[dict[str, Any]]: List of candidate entities (top_k or fewer)

Source code in src/novelentitymatcher/core/blocking.py
@abstractmethod
def block(
    self, query: str, entities: list[dict[str, Any]], top_k: int
) -> list[dict[str, Any]]:
    """
    Return top_k candidate entities for the query.

    Args:
        query: Query text
        entities: List of all entities
        top_k: Maximum number of candidates to return

    Returns:
        List of candidate entities (top_k or fewer)
    """

NoOpBlocking

Bases: BlockingStrategy

Pass-through blocking for small datasets.

Returns all entities up to top_k without any filtering.

Functions
block(query, entities, top_k)

Return all entities or top_k if smaller.

Source code in src/novelentitymatcher/core/blocking.py
def block(
    self, query: str, entities: list[dict[str, Any]], top_k: int
) -> list[dict[str, Any]]:
    """Return all entities or top_k if smaller."""
    if len(entities) <= top_k:
        return entities
    return entities[:top_k]

BM25Blocking(k1=1.5, b=0.75)

Bases: BlockingStrategy

Fast lexical blocking using BM25.

Uses the BM25 algorithm for efficient lexical matching. Good for keyword-heavy queries and proper nouns.

Parameters:

    k1 (float, default: 1.5): BM25 k1 parameter (term frequency saturation)
    b (float, default: 0.75): BM25 b parameter (length normalization)
Source code in src/novelentitymatcher/core/blocking.py
def __init__(self, k1: float = 1.5, b: float = 0.75):
    """
    Initialize BM25 blocking.

    Args:
        k1: BM25 k1 parameter (term frequency saturation)
        b: BM25 b parameter (length normalization)
    """
    self.k1 = k1
    self.b = b
    self.bm25: BM25Okapi | None = None
    self.cached_entities: list[dict[str, Any]] | None = None
    self._entity_hash: str | None = None
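A quick usage sketch; block() builds (or rebuilds) the BM25 index on demand, so no explicit build_index() call is needed (entities are illustrative):

from novelentitymatcher.core.blocking import BM25Blocking

blocking = BM25Blocking(k1=1.5, b=0.75)
entities = [
    {"id": "1", "name": "Acme Corporation"},
    {"id": "2", "name": "Acme Labs"},
    {"id": "3", "name": "Globex"},
]
candidates = blocking.block("acme corp", entities, top_k=2)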
Functions
build_index(entities)

Build BM25 index from entities.

Source code in src/novelentitymatcher/core/blocking.py
def build_index(self, entities: list[dict[str, Any]]):
    """Build BM25 index from entities."""
    self.cached_entities = entities
    self._entity_hash = _compute_entity_hash(entities)

    tokenized_corpus = [
        self._tokenize(e.get("text", e.get("name", ""))) for e in entities
    ]
    self.bm25 = BM25Okapi(tokenized_corpus, k1=self.k1, b=self.b)
block(query, entities, top_k)

Return top_k candidates using BM25 scores.

Source code in src/novelentitymatcher/core/blocking.py
def block(
    self, query: str, entities: list[dict[str, Any]], top_k: int
) -> list[dict[str, Any]]:
    """Return top_k candidates using BM25 scores."""
    current_hash = _compute_entity_hash(entities)

    if self.bm25 is None or self._entity_hash != current_hash:
        self.build_index(entities)

    tokenized_query = self._tokenize(query)
    assert self.bm25 is not None
    scores = self.bm25.get_scores(tokenized_query)

    # Get top_k indices
    top_k = min(top_k, len(scores))
    top_indices = np.argsort(scores)[-top_k:][::-1]

    return [entities[i] for i in top_indices]

TFIDFBlocking()

Bases: BlockingStrategy

TF-IDF based blocking.

Uses TF-IDF vectorization for lexical matching. Good for document-level similarity.

Optimized with:

- Vocabulary caching across rebuilds
- Efficient content-based hashing (MD5)
- Sparse matrix operations via sklearn

Source code in src/novelentitymatcher/core/blocking.py
def __init__(self):
    """Initialize TF-IDF blocking."""
    self.vectorizer: TfidfVectorizer | None = None
    self.matrix: Any | None = None
    self.cached_entities: list[dict[str, Any]] | None = None
    self._entity_hash: str | None = None
    self._vocabulary: dict[str, int] | None = None
Functions
build_index(entities)

Build TF-IDF index from entities.

Source code in src/novelentitymatcher/core/blocking.py
def build_index(self, entities: list[dict[str, Any]]):
    """Build TF-IDF index from entities."""
    self.cached_entities = entities
    self._entity_hash = _compute_entity_hash(entities)

    texts = [e.get("text", e.get("name", "")) for e in entities]

    if self._vocabulary is None:
        self.vectorizer = TfidfVectorizer()
        self.matrix = self.vectorizer.fit_transform(texts)
        self._vocabulary = self.vectorizer.vocabulary_
    else:
        self.vectorizer = TfidfVectorizer(vocabulary=self._vocabulary)
        self.matrix = self.vectorizer.fit_transform(texts)
block(query, entities, top_k)

Return top_k candidates using TF-IDF scores.

Source code in src/novelentitymatcher/core/blocking.py
def block(
    self, query: str, entities: list[dict[str, Any]], top_k: int
) -> list[dict[str, Any]]:
    """Return top_k candidates using TF-IDF scores."""
    current_hash = _compute_entity_hash(entities)

    if self.vectorizer is None or self._entity_hash != current_hash:
        self.build_index(entities)

    assert self.vectorizer is not None
    query_vec = self.vectorizer.transform([query])
    assert self.matrix is not None
    scores = (self.matrix @ query_vec.T).toarray().flatten()

    top_k = min(top_k, len(scores))
    top_indices = np.argpartition(scores, -top_k)[-top_k:]
    top_indices = top_indices[np.argsort(scores[top_indices])[::-1]]

    return [entities[i] for i in top_indices]

FuzzyBlocking(score_cutoff=70)

Bases: BlockingStrategy

Fuzzy string matching blocking.

Uses RapidFuzz for approximate string matching. Good for catching typos and variations.

Parameters:

    score_cutoff (int, default: 70): Minimum similarity score (0-100)
Source code in src/novelentitymatcher/core/blocking.py
def __init__(self, score_cutoff: int = 70):
    """
    Initialize fuzzy blocking.

    Args:
        score_cutoff: Minimum similarity score (0-100)
    """
    self.score_cutoff = score_cutoff
Functions
block(query, entities, top_k)

Return top_k candidates using fuzzy matching.

Source code in src/novelentitymatcher/core/blocking.py
def block(
    self, query: str, entities: list[dict[str, Any]], top_k: int
) -> list[dict[str, Any]]:
    """Return top_k candidates using fuzzy matching."""
    texts = [e.get("text", e.get("name", "")) for e in entities]

    # Extract top matches with indices
    # process.extract returns list of (match, score, index) tuples
    results = process.extract(
        query, texts, scorer=fuzz.token_sort_ratio, limit=top_k
    )

    # Filter by score cutoff, preserving indices
    filtered = [
        (text, score, idx)
        for text, score, idx in results
        if score >= self.score_cutoff
    ]

    # Return matching entities using correct indices
    return [entities[idx] for _, _, idx in filtered]

novelentitymatcher.core.embedding_matcher

Classes

EmbeddingMatcher(entities, model_name='sentence-transformers/paraphrase-mpnet-base-v2', threshold=0.7, normalize=True, embedding_dim=None, cache=None)

Embedding-based similarity matching without training.

Source code in src/novelentitymatcher/core/embedding_matcher.py
def __init__(
    self,
    entities: list[dict[str, Any]],
    model_name: str = "sentence-transformers/paraphrase-mpnet-base-v2",
    threshold: float = 0.7,
    normalize: bool = True,
    embedding_dim: int | None = None,
    cache: ModelCache | None = None,
):
    validate_entities(entities)
    validate_model_name(model_name)

    self.entities = entities
    self.model_name = model_name
    self.threshold = validate_threshold(threshold)
    self.normalize = normalize
    self.embedding_dim = embedding_dim

    self.normalizer = TextNormalizer() if normalize else None
    self.cache = cache if cache is not None else get_default_cache()
    self.model: EmbeddingModel | None = None
    self.entity_texts: list[str] = []
    self.entity_ids: list[str] = []
    self.embeddings: np.ndarray | None = None
    self._async_executor: Any | None = None
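A zero-shot sketch; build_index() and match() are the same calls shown in the CrossEncoderReranker example above (entities are illustrative):

from novelentitymatcher import EmbeddingMatcher

entities = [
    {"id": "DE", "name": "Germany"},
    {"id": "FR", "name": "France"},
]
matcher = EmbeddingMatcher(entities, threshold=0.7)
matcher.build_index()  # encodes entity texts into matcher.embeddings
matches = matcher.match("Deutschland", top_k=1)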

Functions

novelentitymatcher.core.bert_classifier

BERT-based classifier using transformers library.

This module provides BERTClassifier, a drop-in alternative to SetFitClassifier that uses fine-tuned BERT models for text classification. BERT classifiers offer superior accuracy on complex pattern-driven tasks at a higher computational cost.

Classes

BERTClassifier(labels, model_name='distilbert-base-uncased', num_epochs=3, batch_size=16, learning_rate=2e-05, max_length=128, use_fp16=True)

BERT-based text classifier using transformers library.

This classifier provides a drop-in alternative to SetFitClassifier with identical interface. It uses fine-tuned BERT models for classification, offering superior accuracy for complex pattern-driven tasks.

Example

    from novelentitymatcher.core.bert_classifier import BERTClassifier

    labels = ["DE", "FR", "US"]
    clf = BERTClassifier(labels=labels, model_name="distilbert-base-uncased")
    training_data = [
        {"text": "Germany", "label": "DE"},
        {"text": "France", "label": "FR"},
        {"text": "USA", "label": "US"},
    ]
    clf.train(training_data, num_epochs=3)
    prediction = clf.predict("Deutschland")   # "DE"
    proba = clf.predict_proba("Deutschland")  # [0.02, 0.01, 0.97]

Parameters:

    labels (list[str], required): List of class labels for classification.
    model_name (str, default: 'distilbert-base-uncased'): HuggingFace model name or path.
    num_epochs (int, default: 3): Number of training epochs.
    batch_size (int, default: 16): Training batch size.
    learning_rate (float, default: 2e-05): Learning rate for training.
    max_length (int, default: 128): Maximum sequence length for tokenization.
    use_fp16 (bool, default: True): Whether to use mixed precision training (faster, less memory). Only works on GPU.
Source code in src/novelentitymatcher/core/bert_classifier.py
def __init__(
    self,
    labels: list[str],
    model_name: str = "distilbert-base-uncased",
    num_epochs: int = 3,
    batch_size: int = 16,
    learning_rate: float = 2e-5,
    max_length: int = 128,
    use_fp16: bool = True,
):
    """Initialize BERTClassifier.

    Args:
        labels: List of class labels for classification.
        model_name: HuggingFace model name or path. Default: "distilbert-base-uncased".
        num_epochs: Number of training epochs. Default: 3.
        batch_size: Training batch size. Default: 16.
        learning_rate: Learning rate for training. Default: 2e-5.
        max_length: Maximum sequence length for tokenization. Default: 128.
        use_fp16: Whether to use mixed precision training (faster, less memory).
            Only works on GPU. Default: True.
    """
    if not TRANSFORMERS_AVAILABLE:
        raise ImportError(
            "transformers is required for BERTClassifier. "
            "Install with: pip install transformers torch"
        )

    self.labels = labels
    self.label2id = {label: idx for idx, label in enumerate(labels)}
    self.id2label = {idx: label for label, idx in self.label2id.items()}
    self.model_name = model_name
    self.num_epochs = num_epochs
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.max_length = max_length
    self.use_fp16 = use_fp16

    self.model: Any | None = None
    self.tokenizer: Any | None = None
    self.is_trained = False
    self.logger = get_logger(__name__)
Functions
train(training_data, num_epochs=None, batch_size=None, show_progress=True)

Train the BERT classifier.

Parameters:

    training_data (list[dict], required): List of training examples with 'text' and 'label' keys.
    num_epochs (int | None, default: None): Number of training epochs (overrides default).
    batch_size (int | None, default: None): Batch size for training (overrides default).
    show_progress (bool, default: True): Whether to show progress bar during training.

Raises:

    TrainingError: If training fails or data is invalid.

Source code in src/novelentitymatcher/core/bert_classifier.py
def train(
    self,
    training_data: list[dict],
    num_epochs: int | None = None,
    batch_size: int | None = None,
    show_progress: bool = True,
):
    """Train the BERT classifier.

    Args:
        training_data: List of training examples with 'text' and 'label' keys.
        num_epochs: Number of training epochs (overrides default).
        batch_size: Batch size for training (overrides default).
        show_progress: Whether to show progress bar during training.

    Raises:
        TrainingError: If training fails or data is invalid.
    """
    # Suppress third-party library logs
    suppress_third_party_loggers()

    epochs = num_epochs or self.num_epochs
    batch = batch_size or self.batch_size

    # Initialize tokenizer and model
    try:
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, use_fast=True
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=len(self.labels),
            id2label=self.id2label,
            label2id=self.label2id,
        )
    except (OSError, ValueError, KeyError, RuntimeError) as e:
        raise TrainingError(
            f"Failed to load model/tokenizer: {e}",
            details={"model_name": self.model_name},
        ) from e

    # Prepare dataset
    try:
        dataset = Dataset.from_list(training_data)

        # Tokenize data
        tokenizer = self.tokenizer

        def tokenize_function(examples):
            return tokenizer(
                examples["text"],
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
            )

        tokenized_dataset = dataset.map(tokenize_function, batched=True)

        # Convert string labels to numeric IDs
        def format_labels(example):
            example["label"] = self.label2id[example["label"]]
            return example

        tokenized_dataset = tokenized_dataset.map(format_labels)

        # Remove text column as it's not needed for training
        tokenized_dataset = tokenized_dataset.remove_columns(["text"])
        tokenized_dataset = tokenized_dataset.rename_column("label", "labels")

        # Set format for PyTorch
        tokenized_dataset.set_format("torch")

    except (OSError, ValueError, KeyError, RuntimeError) as e:
        raise TrainingError(
            f"Failed to prepare training data: {e}",
            details={"num_examples": len(training_data)},
        ) from e

    # Determine if we should use fp16 (disable for MPS due to compatibility)
    use_fp16 = self.use_fp16
    if use_fp16:
        try:
            import torch

            # Disable fp16 on MPS (Apple Silicon) due to PyTorch version requirements
            if torch.backends.mps.is_available():
                import warnings

                warnings.warn(
                    "Disabling fp16 on MPS (Apple Silicon) due to compatibility. "
                    "This may slightly slow down training but will not affect accuracy.",
                    stacklevel=2,
                )
                use_fp16 = False
        except ImportError:
            use_fp16 = False

    # Training arguments
    training_args = TrainingArguments(
        output_dir=f".tmp/bert_classifier_{id(self)}",
        num_train_epochs=epochs,
        per_device_train_batch_size=batch,
        learning_rate=self.learning_rate,
        weight_decay=0.01,
        logging_dir=None,  # Suppress transformer logs
        logging_steps=50,
        save_strategy="no",  # Don't save checkpoints during training
        report_to="none",  # Disable wandb/tensorboard
        fp16=use_fp16,
        load_best_model_at_end=False,
    )

    # Initialize trainer
    trainer = Trainer(
        model=self.model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    # Train with optional progress tracking
    use_tqdm = False
    if show_progress:
        try:
            from tqdm.auto import tqdm

            use_tqdm = True
        except ImportError:
            # tqdm not available, training will be silent
            pass

    if use_tqdm:
        # Wrap training with tqdm progress bar
        with tqdm(total=epochs, desc="Training BERT", unit="epoch") as pbar:
            # Store original train method
            original_train = trainer.train

            # Wrap train method to update progress bar
            def train_with_progress(*args_train, **kwargs_train):
                result = original_train(*args_train, **kwargs_train)
                pbar.update(epochs)
                return result

            trainer.train = train_with_progress
            trainer.train()
    else:
        # Silent training
        trainer.train()

    self.is_trained = True
predict(texts)

Predict labels for input text(s).

Parameters:

    texts (str | list[str], required): Single text string or list of text strings.

Returns:

    str | list[str]: Predicted label(s). If input is a single string, returns a single label.
        If input is a list, returns a list of labels.

Raises:

    TrainingError: If model is not trained yet.

Source code in src/novelentitymatcher/core/bert_classifier.py
def predict(self, texts: str | list[str]) -> str | list[str]:
    """Predict labels for input text(s).

    Args:
        texts: Single text string or list of text strings.

    Returns:
        Predicted label(s). If input is single string, returns single label.
        If input is list, returns list of labels.

    Raises:
        TrainingError: If model is not trained yet.
    """
    if not self.is_trained or self.model is None or self.tokenizer is None:
        raise TrainingError(
            "Model not trained. Call train() first.",
            details={"model_name": self.model_name},
        )

    single_input = isinstance(texts, str)
    if single_input:
        texts_list: list[str] = [texts]  # type: ignore[list-item]
    else:
        texts_list = texts  # type: ignore[assignment]

    # Tokenize
    tokenizer = self.tokenizer
    inputs = tokenizer(
        texts_list,
        padding=True,
        truncation=True,
        max_length=self.max_length,
        return_tensors="pt",
    )

    # Move to same device as model
    device = next(self.model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Predict
    with torch.no_grad():
        outputs = self.model(**inputs)
        predictions = outputs.logits.argmax(dim=-1)

    # Convert to labels
    predicted_labels = [self.id2label[pred.item()] for pred in predictions]

    if single_input:
        return predicted_labels[0]
    return predicted_labels
predict_proba(text)

Get prediction probabilities for all labels.

Parameters:

    text (str, required): Input text string.

Returns:

    ndarray: NumPy array of probabilities for each label, in the same order as self.labels.

Raises:

    TrainingError: If model is not trained yet.

Source code in src/novelentitymatcher/core/bert_classifier.py
def predict_proba(self, text: str) -> np.ndarray:
    """Get prediction probabilities for all labels.

    Args:
        text: Input text string.

    Returns:
        NumPy array of probabilities for each label, in same order as self.labels.

    Raises:
        TrainingError: If model is not trained yet.
    """
    if not self.is_trained or self.model is None or self.tokenizer is None:
        raise TrainingError(
            "Model not trained. Call train() first.",
            details={"model_name": self.model_name},
        )

    # Tokenize
    inputs = self.tokenizer(
        [text],
        padding=True,
        truncation=True,
        max_length=self.max_length,
        return_tensors="pt",
    )

    # Move to same device as model
    device = next(self.model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Predict with probabilities
    with torch.no_grad():
        outputs = self.model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

    return probs.cpu().numpy()[0]
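
A short sketch of recovering the top label from the probability vector; because the array follows the order of self.labels, argmax indexes directly into it (clf is the hypothetical trained classifier from the previous example):

import numpy as np

probs = clf.predict_proba("Paris")  # shape: (len(clf.labels),)
best = int(np.argmax(probs))
print(clf.labels[best], float(probs[best]))
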
save(path)

Save the trained model and tokenizer.

Parameters:

path (str, required): Directory path to save the model.

Raises:

TrainingError: If the model is not trained yet.

Source code in src/novelentitymatcher/core/bert_classifier.py
def save(self, path: str):
    """Save the trained model and tokenizer.

    Args:
        path: Directory path to save the model.

    Raises:
        TrainingError: If model is not trained yet.
    """
    if not self.is_trained or self.model is None or self.tokenizer is None:
        raise TrainingError(
            "Model not trained. Call train() first.",
            details={"model_name": self.model_name},
        )

    save_path = Path(path)
    save_path.mkdir(parents=True, exist_ok=True)

    self.model.save_pretrained(save_path)
    self.tokenizer.save_pretrained(save_path)

    # Save labels
    labels_path = save_path / "labels.txt"
    with open(labels_path, "w") as f:
        f.write("\n".join(self.labels))
load(path) classmethod

Load a trained BERTClassifier from disk.

Parameters:

path (str, required): Directory path containing the saved model.

Returns:

BERTClassifier: Loaded BERTClassifier instance.

Source code in src/novelentitymatcher/core/bert_classifier.py
@classmethod
def load(cls, path: str) -> "BERTClassifier":
    """Load a trained BERTClassifier from disk.

    Args:
        path: Directory path containing the saved model.

    Returns:
        Loaded BERTClassifier instance.
    """
    load_path = Path(path)

    # Load labels
    labels_path = load_path / "labels.txt"
    if not labels_path.exists():
        raise FileNotFoundError(f"Labels file not found at {labels_path}")

    with open(labels_path) as f:
        labels = f.read().splitlines()

    # Initialize classifier
    clf = cls(labels=labels)

    # Load model and tokenizer
    clf.tokenizer = AutoTokenizer.from_pretrained(load_path)
    clf.model = AutoModelForSequenceClassification.from_pretrained(load_path)
    clf.is_trained = True

    return clf
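
A save/load round-trip sketch under an assumed directory path; save() writes the model, tokenizer, and labels.txt, and load() restores all three and marks the instance as trained:

clf.save("models/my-classifier")  # hypothetical path

restored = BERTClassifier.load("models/my-classifier")
assert restored.is_trained
assert restored.labels == clf.labels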

Functions

novelentitymatcher.core.matching_strategy

Matching strategy pattern for Matcher mode selection.

Classes

StrategyConfig(threshold, model_name, training_mode, normalize=True) dataclass

Configuration for matching strategies.

Encapsulates threshold, model settings, and training mode that were previously managed in _EntityMatcher.
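
A minimal construction sketch; the field values are illustrative, and the mode string is an assumption (see get_strategy() below for how modes are resolved):

config = StrategyConfig(
    threshold=0.7,
    model_name="sentence-transformers/paraphrase-mpnet-base-v2",
    training_mode="zero_shot",  # assumed mode string
    normalize=True,
)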

MatchingStrategy(matcher)

Bases: ABC

Abstract base class for matching strategies.

Source code in src/novelentitymatcher/core/matching_strategy.py
def __init__(self, matcher: MatcherFacade):
    self._matcher = matcher
Functions
match(texts, top_k=1, threshold_override=None, **kwargs) abstractmethod

Execute matching with this strategy.

Source code in src/novelentitymatcher/core/matching_strategy.py
@abstractmethod
def match(
    self,
    texts: TextInput,
    top_k: int = 1,
    threshold_override: float | None = None,
    **kwargs,
) -> Any:
    """Execute matching with this strategy."""
match_async(texts, top_k=1, threshold_override=None, **kwargs) abstractmethod async

Execute async matching with this strategy.

Source code in src/novelentitymatcher/core/matching_strategy.py
@abstractmethod
async def match_async(
    self,
    texts: TextInput,
    top_k: int = 1,
    threshold_override: float | None = None,
    **kwargs,
) -> Any:
    """Execute async matching with this strategy."""
build_index() abstractmethod

Build any required index for this strategy.

Source code in src/novelentitymatcher/core/matching_strategy.py
@abstractmethod
def build_index(self) -> None:
    """Build any required index for this strategy."""
get_reference_corpus() abstractmethod

Get reference corpus for this strategy.

Source code in src/novelentitymatcher/core/matching_strategy.py
@abstractmethod
def get_reference_corpus(self) -> dict:
    """Get reference corpus for this strategy."""

ZeroShotStrategy(matcher)

Bases: MatchingStrategy

Strategy for zero-shot (embedding-only) matching.

Source code in src/novelentitymatcher/core/matching_strategy.py
def __init__(self, matcher: MatcherFacade):
    self._matcher = matcher

HeadOnlyFullStrategy(matcher)

Bases: MatchingStrategy

Strategy for head-only and full training modes.

Source code in src/novelentitymatcher/core/matching_strategy.py
def __init__(self, matcher: MatcherFacade):
    self._matcher = matcher

BertStrategy(matcher)

Bases: MatchingStrategy

Strategy for BERT-based matching.

Source code in src/novelentitymatcher/core/matching_strategy.py
def __init__(self, matcher: MatcherFacade):
    self._matcher = matcher

HybridStrategy(matcher)

Bases: MatchingStrategy

Strategy for hybrid blocking + retrieval matching.

Source code in src/novelentitymatcher/core/matching_strategy.py
def __init__(self, matcher: MatcherFacade):
    self._matcher = matcher

MatcherFacade(embedding_matcher, entity_matcher, bert_matcher, hybrid_matcher, config)

Facade providing access to all matcher components for strategies.

Source code in src/novelentitymatcher/core/matching_strategy.py
def __init__(
    self,
    embedding_matcher: EmbeddingMatcher,
    entity_matcher: EntityMatcher,
    bert_matcher: BERTClassifier,
    hybrid_matcher: HybridMatcher,
    config: StrategyConfig,
):
    self.embedding_matcher = embedding_matcher
    self.entity_matcher = entity_matcher
    self.bert_matcher = bert_matcher
    self.hybrid_matcher = hybrid_matcher
    self.threshold = config.threshold
    self.model_name = config.model_name
    self._training_mode = config.training_mode
    self._config = config
Functions
get_strategy(mode=None)

Get strategy instance for the given or current mode.

Source code in src/novelentitymatcher/core/matching_strategy.py
def get_strategy(self, mode: str | None = None) -> MatchingStrategy:
    """Get strategy instance for the given or current mode."""
    effective_mode = mode or self._training_mode
    strategy_cls = get_strategy(effective_mode)
    return strategy_cls(self)

Functions

get_strategy(mode)

Get strategy class for the given mode.

Source code in src/novelentitymatcher/core/matching_strategy.py
def get_strategy(mode: str) -> type[MatchingStrategy]:
    """Get strategy class for the given mode."""
    if mode not in _STRATEGY_MAP:
        from ..exceptions import ModeError

        raise ModeError(f"Unknown mode: {mode}", invalid_mode=mode)
    return _STRATEGY_MAP[mode]
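
A dispatch sketch; the keys of _STRATEGY_MAP are not shown on this page, so the mode string below is an assumption:

strategy_cls = get_strategy("hybrid")  # assumed key; raises ModeError if unknown
strategy = strategy_cls(facade)        # facade: a MatcherFacade instance
results = strategy.match("iPhone 15 case", top_k=5)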

novelentitymatcher.core.hybrid

Hybrid matching pipeline with blocking, retrieval, and reranking.

Classes

HybridMatcher(entities, blocking_strategy=None, retriever_model='BAAI/bge-base-en-v1.5', reranker_model='BAAI/bge-reranker-v2-m3', normalize=True)

Three-stage waterfall pipeline for semantic entity matching.

Combines fast blocking, semantic retrieval, and precise reranking for accurate and efficient matching.

Pipeline Stages
  1. Blocking (BM25/TF-IDF/Fuzzy) - Fast lexical filtering
  2. Bi-Encoder Retrieval - Semantic similarity search
  3. Cross-Encoder Reranking - Precise cross-attention scoring
Example

from novelentitymatcher import HybridMatcher
from novelentitymatcher.core.blocking import BM25Blocking

matcher = HybridMatcher(
    entities=products,
    blocking_strategy=BM25Blocking(),
    retriever_model="bge-base",
    reranker_model="bge-m3",
)

results = matcher.match(
    "iPhone 15 case",
    blocking_top_k=1000,
    retrieval_top_k=50,
    final_top_k=5,
)

Parameters:

entities (list[dict[str, Any]], required): List of entity dictionaries.
blocking_strategy (BlockingStrategy | None, default None): Blocking strategy (defaults to NoOpBlocking).
retriever_model (str, default 'BAAI/bge-base-en-v1.5'): Model name for bi-encoder retrieval.
reranker_model (str, default 'BAAI/bge-reranker-v2-m3'): Model name for cross-encoder reranking.
normalize (bool, default True): Whether to normalize text (lowercase, remove accents, etc.).
Source code in src/novelentitymatcher/core/hybrid.py
def __init__(
    self,
    entities: list[dict[str, Any]],
    blocking_strategy: BlockingStrategy | None = None,
    retriever_model: str = "BAAI/bge-base-en-v1.5",
    reranker_model: str = "BAAI/bge-reranker-v2-m3",
    normalize: bool = True,
):
    """
    Initialize the hybrid matcher.

    Args:
        entities: List of entity dictionaries
        blocking_strategy: Blocking strategy (defaults to NoOpBlocking)
        retriever_model: Model name for bi-encoder retrieval
        reranker_model: Model name for cross-encoder reranking
        normalize: Whether to normalize text (lowercase, remove accents, etc.)
    """
    # Stage 1: Blocking
    self.blocker = blocking_strategy or NoOpBlocking()

    # Stage 2: Bi-Encoder Retrieval
    self.retriever = EmbeddingMatcher(
        entities=entities,
        model_name=retriever_model,
        normalize=normalize,
    )
    self.retriever.build_index()

    # Stage 3: Cross-Encoder Reranking
    self.reranker = CrossEncoderReranker(model=reranker_model)
Functions
match(query, blocking_top_k=1000, retrieval_top_k=50, final_top_k=5)

Match query using three-stage waterfall pipeline.

Parameters:

query (str, required): Search query.
blocking_top_k (int, default 1000): Number of candidates after the blocking stage.
retrieval_top_k (int, default 50): Number of candidates after the retrieval stage.
final_top_k (int, default 5): Number of final results after reranking.

Returns:

list[dict[str, Any]]: List of matched entities with scores (bi-encoder and cross-encoder).

Source code in src/novelentitymatcher/core/hybrid.py
def match(
    self,
    query: str,
    blocking_top_k: int = 1000,
    retrieval_top_k: int = 50,
    final_top_k: int = 5,
) -> list[dict[str, Any]]:
    """
    Match query using three-stage waterfall pipeline.

    Args:
        query: Search query
        blocking_top_k: Number of candidates after blocking stage
        retrieval_top_k: Number of candidates after retrieval stage
        final_top_k: Number of final results after reranking

    Returns:
        List of matched entities with scores (bi-encoder and cross-encoder)
    """
    # Stage 1: Blocking - Fast lexical filtering
    candidates = self.blocker.block(
        query, self.retriever.entities, top_k=blocking_top_k
    )

    # Early exit if no candidates from blocking
    if not candidates:
        return []

    # Stage 2: Bi-Encoder Retrieval - Semantic similarity
    retrieved = self.retriever.match(
        query,
        candidates=candidates,
        top_k=retrieval_top_k,
    )

    # Ensure retrieved is a list (handle single result case)
    if retrieved is None:
        return []
    if not isinstance(retrieved, list):
        retrieved = [retrieved]

    # Filter out None results
    retrieved = [r for r in retrieved if r is not None]

    # Stage 3: Cross-Encoder Reranking - Precise scoring
    if not retrieved:
        return []

    final = self.reranker.rerank(query, retrieved, top_k=final_top_k)

    return final
match_bulk(queries, blocking_top_k=1000, retrieval_top_k=50, final_top_k=5, n_jobs=-1, chunk_size=None)

Batch matching for multiple queries.

Batches bi-encoder encoding across all queries (single model.encode call instead of one per query), then computes per-query similarity against blocked candidates.

Parameters:

queries (list[str], required): List of search queries.
blocking_top_k (int, default 1000): Number of candidates after the blocking stage.
retrieval_top_k (int, default 50): Number of candidates after the retrieval stage.
final_top_k (int, default 5): Number of final results after reranking.
n_jobs (int, default -1): Ignored (kept for backwards compatibility).
chunk_size (int | None, default None): Ignored (kept for backwards compatibility).

Returns:

list[list[dict[str, Any]]]: List of matched entity lists (one per query).

Source code in src/novelentitymatcher/core/hybrid.py
def match_bulk(
    self,
    queries: list[str],
    blocking_top_k: int = 1000,
    retrieval_top_k: int = 50,
    final_top_k: int = 5,
    n_jobs: int = -1,
    chunk_size: int | None = None,
) -> list[list[dict[str, Any]]]:
    """
    Batch matching for multiple queries.

    Batches bi-encoder encoding across all queries (single model.encode call
    instead of one per query), then computes per-query similarity against
    blocked candidates.

    Args:
        queries: List of search queries
        blocking_top_k: Number of candidates after blocking stage
        retrieval_top_k: Number of candidates after retrieval stage
        final_top_k: Number of final results after reranking
        n_jobs: Ignored (kept for backwards compatibility).
        chunk_size: Ignored (kept for backwards compatibility).

    Returns:
        List of matched entity lists (one per query)
    """
    if not queries:
        return []

    # Stage 1: Blocking - per-query lexical filtering
    all_candidates: list[list[dict[str, Any]]] = []
    for query in queries:
        candidates = self.blocker.block(
            query, self.retriever.entities, top_k=blocking_top_k
        )
        all_candidates.append(candidates or [])

    # Stage 2: Bi-Encoder Retrieval - batched encoding
    query_embeddings = self.retriever.model.encode(queries)  # type: ignore[union-attr]
    if isinstance(query_embeddings, list):
        query_embeddings = np.array(query_embeddings)

    entity_lookup = {e["id"]: e for e in self.retriever.entities}
    all_retrieved: list[list[dict[str, Any]]] = []

    for i in range(len(queries)):
        candidates = all_candidates[i]
        if not candidates:
            all_retrieved.append([])
            continue

        candidate_ids = {c["id"] for c in candidates}
        candidate_indices = [
            j
            for j, eid in enumerate(self.retriever.entity_ids)
            if eid in candidate_ids
        ]
        if not candidate_indices:
            all_retrieved.append([])
            continue

        candidate_embeddings = self.retriever.embeddings[candidate_indices]  # type: ignore[index]
        query_emb = query_embeddings[i : i + 1]
        similarities = cosine_similarity(query_emb, candidate_embeddings)[0]

        sorted_indices = np.argsort(similarities)[::-1]
        seen_ids: set[str] = set()
        retrieved: list[dict[str, Any]] = []
        for idx in sorted_indices:
            score = similarities[idx]
            if score < self.retriever.threshold:
                continue
            entity_id = self.retriever.entity_ids[candidate_indices[idx]]
            if entity_id in seen_ids:
                continue
            seen_ids.add(entity_id)
            entity = entity_lookup.get(entity_id, {})
            retrieved.append(
                {
                    "id": entity_id,
                    "score": float(score),
                    "text": entity.get(
                        "name",
                        self.retriever.entity_texts[candidate_indices[idx]],
                    ),
                }
            )
            if len(retrieved) >= retrieval_top_k:
                break

        all_retrieved.append(retrieved)

    # Stage 3: Cross-Encoder Reranking - per-query
    results: list[list[dict[str, Any]]] = []
    for query, retrieved in zip(queries, all_retrieved, strict=True):
        if not retrieved:
            results.append([])
        else:
            results.append(
                self.reranker.rerank(query, retrieved, top_k=final_top_k)
            )

    return results
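
A batch-usage sketch mirroring the single-query example above; it assumes the result dicts keep the "id" and "score" keys produced in the retrieval stage:

queries = ["iPhone 15 case", "usb-c charger"]
all_results = matcher.match_bulk(queries, final_top_k=3)

for query, hits in zip(queries, all_results):
    print(query, [(h["id"], round(h["score"], 3)) for h in hits])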

novelentitymatcher.core.matcher_engines

Classes

novelentitymatcher.core.matcher_entity

Classes

Functions

novelentitymatcher.core.matcher_components

Classes

MatcherComponentFactory(owner)

Lazy matcher-component construction behind the public Matcher facade.

Source code in src/novelentitymatcher/core/matcher_components.py
def __init__(self, owner: Matcher):
    self._owner = owner
    self._embedding_matcher: Any = None
    self._entity_matcher: Any = None
    self._bert_matcher: Any = None
    self._hybrid_matcher: Any = None

novelentitymatcher.core.matcher_runtime

Classes

MatcherRuntimeState(requested_model, model_name, training_model_name, bert_model_name, threshold, training_mode, detected_mode=None, has_training_data=False) dataclass

Centralized matcher configuration and mutable runtime state.

Functions

novelentitymatcher.core.matcher_shared

Classes

Functions

extract_top_prediction_metadata(match_results, single_input)

Normalize matcher output into top-1 predictions and confidences.

Novel class detection only needs the best prediction per input. This keeps a stable shape even when the underlying matcher returns dicts, lists, strings, or None values.

Source code in src/novelentitymatcher/core/matcher_shared.py
def extract_top_prediction_metadata(
    match_results: Any, single_input: bool
) -> tuple[list[str], np.ndarray]:
    """
    Normalize matcher output into top-1 predictions and confidences.

    Novel class detection only needs the best prediction per input. This keeps a
    stable shape even when the underlying matcher returns dicts, lists, strings,
    or ``None`` values.
    """

    def _from_result(result: Any) -> tuple[str, float]:
        if result is None:
            return "unknown", 0.0
        if isinstance(result, dict):
            return result.get("id", "unknown"), float(result.get("score", 0.0))
        if isinstance(result, list):
            if not result:
                return "unknown", 0.0
            first = result[0]
            if isinstance(first, dict):
                return first.get("id", "unknown"), float(first.get("score", 0.0))
            if first is None:
                return "unknown", 0.0
            return str(first), 1.0
        return str(result), 1.0

    if single_input:
        prediction, confidence = _from_result(match_results)
        return [prediction], np.array([confidence], dtype=float)

    predictions: list[str] = []
    confidences: list[float] = []
    for result in match_results:
        prediction, confidence = _from_result(result)
        predictions.append(prediction)
        confidences.append(confidence)

    return predictions, np.array(confidences, dtype=float)
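
The normalization rules can be exercised directly; this sketch covers each input shape handled in the source above:

preds, confs = extract_top_prediction_metadata(
    [
        {"id": "ENT-1", "score": 0.91},    # dict -> top id and score
        [{"id": "ENT-2", "score": 0.45}],  # list of dicts -> first element
        "ENT-3",                           # bare string -> confidence 1.0
        None,                              # missing -> ("unknown", 0.0)
    ],
    single_input=False,
)
# preds == ["ENT-1", "ENT-2", "ENT-3", "unknown"]
# confs == array([0.91, 0.45, 1.0, 0.0])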

novelentitymatcher.core.async_utils

Classes

AsyncExecutor(max_workers=None)

Manages async execution of sync operations.

Runs CPU-bound or blocking sync operations in a thread pool, allowing async code to proceed without blocking the event loop.

Parameters:

max_workers (int | None, default None): Maximum number of worker threads. Defaults to CPU_COUNT * 2, capped at 32 for I/O-bound workloads.
Source code in src/novelentitymatcher/core/async_utils.py
def __init__(self, max_workers: int | None = None):
    """
    Initialize the async executor.

    Args:
        max_workers: Maximum number of worker threads. Defaults to CPU_COUNT * 2,
            capped at 32 for I/O bound workloads.
    """
    if max_workers is None:
        max_workers = min(32, (os.cpu_count() or 1) * 2)
    self._executor: ThreadPoolExecutor | None = ThreadPoolExecutor(
        max_workers=max_workers
    )
    self._is_shutdown = False
Functions
run_in_thread(func, *args, **kwargs) async

Run a sync function in a thread pool.

Parameters:

func (Callable, required): Synchronous function to execute.
*args: Positional arguments to pass to func.
**kwargs: Keyword arguments to pass to func.

Returns:

Any: The return value of func.

Source code in src/novelentitymatcher/core/async_utils.py
async def run_in_thread(self, func: Callable, *args, **kwargs) -> Any:
    """
    Run a sync function in a thread pool.

    Args:
        func: Synchronous function to execute
        *args: Positional arguments to pass to func
        **kwargs: Keyword arguments to pass to func

    Returns:
        The return value of func
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        self._require_executor(), functools.partial(func, *args, **kwargs)
    )
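
A minimal event-loop sketch; the blocking function here is a stand-in:

import asyncio
import time

def slow_square(x: int) -> int:
    time.sleep(0.1)  # stand-in for blocking work
    return x * x

async def main() -> None:
    executor = AsyncExecutor()
    try:
        result = await executor.run_in_thread(slow_square, 7)
        print(result)  # 49
    finally:
        executor.shutdown()

asyncio.run(main())
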
run_in_thread_batch(func, items, batch_size=32) async

Run sync function on batches concurrently.

Splits items into batches and runs func on each batch in parallel, then flattens the results.

Parameters:

func (Callable, required): Function that takes a list and returns a list.
items (list[Any], required): Items to process in batches.
batch_size (int, default 32): Size of each batch.

Returns:

list[Any]: Flattened list of results from all batches.

Source code in src/novelentitymatcher/core/async_utils.py
async def run_in_thread_batch(
    self, func: Callable, items: list[Any], batch_size: int = 32
) -> list[Any]:
    """
    Run sync function on batches concurrently.

    Splits items into batches and runs func on each batch in parallel,
    then flattens the results.

    Args:
        func: Function that takes a list and returns a list
        items: Items to process in batches
        batch_size: Size of each batch

    Returns:
        Flattened list of results from all batches
    """
    batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
    tasks = [self.run_in_thread(func, batch) for batch in batches]
    results = await asyncio.gather(*tasks)
    return [item for batch in results for item in batch]
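
A batching sketch; per the contract above, func must accept a list and return a list (the batch function here is a stand-in):

def measure_batch(texts: list[str]) -> list[int]:
    return [len(t) for t in texts]  # stand-in for a real batch operation

async def measure_all(executor: AsyncExecutor, texts: list[str]) -> list[int]:
    # Splits texts into batches of 8, runs them concurrently, flattens results.
    return await executor.run_in_thread_batch(measure_batch, texts, batch_size=8)
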
shutdown()

Clean up resources by shutting down the thread pool. Idempotent.

Source code in src/novelentitymatcher/core/async_utils.py
def shutdown(self):
    """Clean up resources by shutting down the thread pool. Idempotent."""
    if self._executor is not None:
        self._executor.shutdown(wait=True)
        self._executor = None
    self._is_shutdown = True