BERT Classifier Guide¶
This guide covers the BERT-based classifier implementation in novel_entity_matcher, which provides superior accuracy for complex pattern-driven text classification tasks.
What is BERT Classifier?¶
BERT (Bidirectional Encoder Representations from Transformers) is a transformer-based model that can be fine-tuned for text classification. The BERTClassifier class provides a drop-in alternative to SetFitClassifier with identical interface but different performance characteristics:
- Accuracy: Superior for complex, pattern-driven tasks (often 3-5% better than SetFit)
- Speed: Slower inference (requires full transformer pass)
- Data efficiency: Works well with smaller datasets (8-16 examples per class)
- Compute: Higher computational cost (GPU recommended)
When to Use BERT vs SetFit¶
Use BERT when:¶
- High-stakes accuracy is critical (legal, medical, financial applications)
- Complex pattern recognition needed (sarcasm, nuanced sentiment, idioms)
- Data-rich scenarios (100+ examples per class recommended)
- GPU resources available (training is faster with GPU)
- Inference speed is not critical (can afford slower predictions)
Use SetFit when:¶
- Real-time, high-throughput applications (need fast inference)
- Limited compute resources (CPU-only environments)
- Simpler classification tasks (straightforward text patterns)
- Cached embeddings beneficial (repeated queries)
Model Selection¶
Recommended BERT Models¶
| Model | Size | Speed | Accuracy | Use Case |
|---|---|---|---|---|
distilbert |
66M | Fast | High | Default choice - good balance |
tinybert |
4.4M | Very Fast | Medium | Resource-constrained environments |
roberta-base |
125M | Medium | Very High | When accuracy is critical |
deberta-v3 |
184M | Slow | State-of-the-art | Maximum accuracy, slower |
bert-multilingual |
179M | Slow | High | Multilingual text classification |
Model Selection Guidelines¶
from novelentitymatcher import Matcher
# Default: DistilBERT (recommended)
matcher = Matcher(entities=entities, mode="bert")
# For maximum accuracy
matcher = Matcher(entities=entities, mode="bert", model="deberta-v3")
# For resource-constrained environments
matcher = Matcher(entities=entities, mode="bert", model="tinybert")
# For multilingual text
matcher = Matcher(entities=entities, mode="bert", model="bert-multilingual")
Basic Usage¶
Direct BERTClassifier Usage¶
from novelentitymatcher.core.bert_classifier import BERTClassifier
# Define labels
labels = ["DE", "FR", "US"]
# Initialize classifier
clf = BERTClassifier(
labels=labels,
model_name="distilbert-base-uncased",
num_epochs=3,
batch_size=16,
)
# Prepare training data
training_data = [
{"text": "Germany", "label": "DE"},
{"text": "Deutschland", "label": "DE"},
{"text": "France", "label": "FR"},
{"text": "USA", "label": "US"},
]
# Train
clf.train(training_data, num_epochs=3)
# Predict
prediction = clf.predict("Deutschland") # "DE"
proba = clf.predict_proba("Deutschland") # [0.02, 0.01, 0.97]
# Save/Load
clf.save("/path/to/model")
loaded_clf = BERTClassifier.load("/path/to/model")
Using BERT with Matcher¶
from novelentitymatcher import Matcher
entities = [
{"id": "DE", "name": "Germany", "aliases": ["Deutschland"]},
{"id": "US", "name": "United States", "aliases": ["USA"]},
]
training_data = [
{"text": "Germany", "label": "DE"},
{"text": "Deutschland", "label": "DE"},
{"text": "USA", "label": "US"},
]
# Use BERT mode explicitly
matcher = Matcher(entities=entities, mode="bert", model="distilbert")
matcher.fit(training_data, num_epochs=3)
result = matcher.match("America")
# Returns: {"id": "US", "score": 0.95}
Auto-Mode with BERT¶
# Auto-detect may choose BERT for data-rich scenarios
matcher = Matcher(entities=entities, mode="auto")
matcher.fit(training_data) # Automatically selects "bert" if data-rich
# Check which mode was selected
info = matcher.get_training_info()
print(f"Detected mode: {info['detected_mode']}") # May print "bert"
Advanced Configuration¶
Custom Training Parameters¶
clf = BERTClassifier(
labels=labels,
model_name="distilbert-base-uncased",
num_epochs=5, # More epochs for better accuracy
batch_size=32, # Larger batch if memory allows
learning_rate=1e-5, # Lower LR for fine-tuning
max_length=256, # Longer sequences
use_fp16=True, # Mixed precision (GPU only)
)
Handling Long Sequences¶
# For longer text sequences
clf = BERTClassifier(
labels=labels,
max_length=512, # BERT max is typically 512 tokens
)
# Text longer than max_length will be truncated
training_data = [
{"text": "Very long text...", "label": "LABEL"},
]
clf.train(training_data)
GPU Acceleration¶
BERT models benefit significantly from GPU acceleration:
# The classifier automatically uses GPU if available
# No configuration needed - PyTorch handles device detection
# To check if GPU is being used:
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device: {clf.model.device if clf.model else 'Not trained'}")
Performance Benchmarks¶
Training Time (approximate)¶
| Model | 100 samples | 1000 samples | 10000 samples |
|---|---|---|---|
| tinybert | 30s | 2min | 15min |
| distilbert | 1min | 5min | 45min |
| roberta-base | 2min | 10min | 90min |
| deberta-v3 | 3min | 15min | 2hr |
Times on NVIDIA V100 GPU. CPU training will be 3-5x slower.
Inference Speed¶
| Model | Samples/sec (GPU) | Samples/sec (CPU) |
|---|---|---|
| tinybert | 500 | 50 |
| distilbert | 300 | 30 |
| roberta-base | 200 | 20 |
| deberta-v3 | 150 | 15 |
Accuracy Comparison¶
On standard text classification benchmarks:
| Task | SetFit | BERT | Improvement |
|---|---|---|---|
| Sentiment Analysis | 87% | 91% | +4% |
| Topic Classification | 82% | 85% | +3% |
| Intent Detection | 89% | 93% | +4% |
| Entity Recognition | 85% | 88% | +3% |
Best Practices¶
1. Data Preparation¶
# Ensure balanced dataset
training_data = [
{"text": "...", "label": "A"}, # 100 examples
{"text": "...", "label": "B"}, # 100 examples
{"text": "...", "label": "C"}, # 100 examples
]
# Minimum recommendations:
# - At least 8 examples per class for BERT
# - Ideally 50+ examples per class for best results
# - Balanced classes (similar number of examples)
2. Hyperparameter Tuning¶
# Start with defaults, then tune
clf = BERTClassifier(
labels=labels,
num_epochs=3, # Start here
batch_size=16, # Increase if GPU memory allows
learning_rate=2e-5, # Default, rarely needs changing
)
# If underfitting: increase num_epochs
# If overfitting: decrease num_epochs or increase batch_size
3. Validation¶
# Always use a validation set
from sklearn.model_selection import train_test_split
train_data, val_data = train_test_split(
training_data,
test_size=0.2,
stratify=[item["label"] for item in training_data],
)
clf.train(train_data)
# Evaluate on validation set
correct = 0
for item in val_data:
prediction = clf.predict(item["text"])
if prediction == item["label"]:
correct += 1
accuracy = correct / len(val_data)
print(f"Validation accuracy: {accuracy:.2%}")
4. Error Analysis¶
# Analyze misclassifications
errors = []
for item in val_data:
prediction = clf.predict(item["text"])
if prediction != item["label"]:
proba = clf.predict_proba(item["text"])
errors.append({
"text": item["text"],
"true_label": item["label"],
"predicted": prediction,
"confidence": max(proba),
})
# Sort by confidence to find systematic errors
errors.sort(key=lambda x: x["confidence"], reverse=True)
Troubleshooting¶
Out of Memory Errors¶
# Solution 1: Reduce batch size
clf = BERTClassifier(labels=labels, batch_size=8)
# Solution 2: Use smaller model
clf = BERTClassifier(labels=labels, model_name="tinybert")
# Solution 3: Reduce max_length
clf = BERTClassifier(labels=labels, max_length=64)
Slow Training¶
# Solution 1: Use GPU (automatic if available)
# Verify GPU is being used:
import torch
print(torch.cuda.is_available())
# Solution 2: Use smaller model
clf = BERTClassifier(labels=labels, model_name="distilbert")
# Solution 3: Reduce epochs
clf.train(training_data, num_epochs=2)
Poor Accuracy¶
# Solution 1: Check data quality
# - Ensure sufficient examples per class (8+ minimum)
# - Verify labels are correct
# - Check for class imbalance
# Solution 2: Increase training
clf.train(training_data, num_epochs=5)
# Solution 3: Try larger model
clf = BERTClassifier(labels=labels, model_name="roberta-base")
Import Errors¶
# Error: "transformers is required"
# Solution: Install dependencies
uv add transformers torch
# Or, if you are not using uv
pip install transformers torch
API Reference¶
BERTClassifier¶
class BERTClassifier:
def __init__(
self,
labels: List[str],
model_name: str = "distilbert-base-uncased",
num_epochs: int = 3,
batch_size: int = 16,
learning_rate: float = 2e-5,
max_length: int = 128,
use_fp16: bool = True,
):
"""Initialize BERT classifier.
Args:
labels: List of class labels
model_name: HuggingFace model name
num_epochs: Number of training epochs
batch_size: Training batch size
learning_rate: Learning rate for optimizer
max_length: Maximum sequence length
use_fp16: Use mixed precision training (GPU only)
"""
def train(
self,
training_data: List[dict],
num_epochs: Optional[int] = None,
batch_size: Optional[int] = None,
show_progress: bool = True,
):
"""Train the classifier.
Args:
training_data: List of {"text": str, "label": str} dicts
num_epochs: Override default epochs
batch_size: Override default batch size
show_progress: Show progress bar
"""
def predict(self, texts: Union[str, List[str]]) -> Union[str, List[str]]:
"""Predict labels for input text(s)."""
def predict_proba(self, text: str) -> np.ndarray:
"""Get prediction probabilities for all labels."""
def save(self, path: str):
"""Save trained model to disk."""
@classmethod
def load(cls, path: str) -> "BERTClassifier":
"""Load trained model from disk."""
Migration from SetFit¶
BERTClassifier has an identical interface to SetFitClassifier:
# Before (SetFit)
from novelentitymatcher.core.classifier import SetFitClassifier
clf = SetFitClassifier(labels=labels)
clf.train(training_data)
# After (BERT)
from novelentitymatcher.core.bert_classifier import BERTClassifier
clf = BERTClassifier(labels=labels)
clf.train(training_data) # Same interface!
The only differences are: - Constructor accepts additional parameters (learning_rate, max_length, use_fp16) - Training may take longer but accuracy is often better - Model files are larger on disk