Hmm… apparently it’s possible using a custom batch sampler, but I’m not entirely sure that’s true.
Just in case: @tomaarsen
Custom batches in sentence-transformers for MultipleNegativesRankingLoss
Short answer
Use a normal datasets.Dataset, but do not rely on the default batching behavior of SentenceTransformerTrainer.
Instead:
- Flatten your pre-built batches into a single datasets.Dataset.
- Keep the rows ordered so that rows 0..255 are your first curated batch, rows 256..511 are your second curated batch, and so on.
- Pass a custom batch sampler through SentenceTransformerTrainingArguments(batch_sampler=...).
- Make the custom sampler yield exactly the row-index groups you want.
The key point is:
train_dataset = storage format
batch_sampler = batching policy
loss = uses the resulting batch as the contrastive pool
So the train_dataset=??? slot should hold a flattened datasets.Dataset, and the batch-preserving logic should live in args.batch_sampler.
Why this is the right abstraction
Your understanding of MultipleNegativesRankingLoss is correct.
MultipleNegativesRankingLoss is an in-batch contrastive loss. For each anchor in a batch, the matching positive should be closer than the other candidate positives or documents in the same batch.
For example, suppose one minibatch contains:
[
    ("Blue Street, 1, New York", "Blue Street 1 - New York"),
    ("Blue Street, 11, New York", "Blue Street 11 - New York"),
]
Then, for the first row, the model is trained to make:
"Blue Street, 1, New York"
closer to:
"Blue Street 1 - New York"
than to:
"Blue Street 11 - New York"
That is exactly the behavior you want.
With address matching, the hard part is not only learning that two variants of the same address are close. The hard part is learning that extremely similar-looking addresses may still be different real-world entities:
Blue Street 1, New York == Blue St. 1, NYC
Blue Street 1, New York != Blue Street 11, New York
Blue Street 1, New York != Blue Street 1, Newark
Blue Street 1 Apt 2 != Blue Street 1 Apt 20
So your curated batch structure is not incidental. It is part of the supervision.
Why CosineSimilarityLoss is usually weaker here
CosineSimilarityLoss is pairwise. It sees one pair and a target score.
That works well when your labels are naturally pairwise:
pair A similarity = 0.95
pair B similarity = 0.10
pair C similarity = 0.60
But your real task is closer to ranking:
Given:
"Blue Street, 1, New York"
Rank this highest:
"Blue Street 1 - New York"
Rank these lower:
"Blue Street 11 - New York"
"Blue Street 1 - Newark"
"Blue Avenue 1 - New York"
"Blue Street 1 Apt 2 - New York"
That is why MultipleNegativesRankingLoss is a better fit. It turns each minibatch into a local retrieval problem.
In this sense, your task is not just sentence similarity. It is closer to:
postal-address entity resolution
+
dense retrieval
+
hard-negative metric learning
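That “local retrieval” framing can be made concrete: MNRL scores every anchor against every positive in the batch and applies a softmax cross-entropy in which the diagonal entries are the correct answers. A minimal pure-Python sketch of the idea (illustrative only, not the library’s implementation):

```python
import math

def mnrl_loss(sim: list[list[float]]) -> float:
    """In-batch softmax cross-entropy over a (batch x batch) matrix of
    scaled anchor-positive similarities. Row i's correct answer is column i."""
    total = 0.0
    for i, row in enumerate(sim):
        log_denom = math.log(sum(math.exp(s) for s in row))
        total += log_denom - row[i]  # -log softmax of the diagonal entry
    return total / len(sim)

# Toy 2x2 batch: each anchor already prefers its own positive.
sim = [
    [0.9, 0.2],  # anchor "Blue Street, 1, New York" vs both positives
    [0.3, 0.8],  # anchor "Blue Street, 11, New York" vs both positives
]
print(round(mnrl_loss(sim), 4))
```

The loss shrinks as the diagonal dominates its row, which is exactly the “rank your own positive highest” objective described above.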
Why not pass a DataLoader?
The newer SentenceTransformerTrainer API expects a dataset, not a user-supplied PyTorch DataLoader.
That does not mean you cannot control batches.
The control point is SentenceTransformerTrainingArguments.batch_sampler, documented in the samplers reference. The sampler docs explain that a custom batch sampler can be supplied by subclassing DefaultBatchSampler or by passing a function that returns a DefaultBatchSampler instance.
So the correct structure is:
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,  # contains batch_sampler
    train_dataset=train_dataset,
    loss=loss_fn,
)
not:
trainer = SentenceTransformerTrainer(
    model=model,
    train_dataloader=my_dataloader,  # not the intended API
)
Step 1: flatten your pre-built batches
Assume your current data looks conceptually like this:
batches = [
    [
        (batch1_anchor1, batch1_positive1),
        (batch1_anchor2, batch1_positive2),
        ...
    ],
    [
        (batch2_anchor1, batch2_positive1),
        (batch2_anchor2, batch2_positive2),
        ...
    ],
]
Flatten it while preserving order:
from datasets import Dataset

BATCH_SIZE = 256

flat_anchors: list[str] = []
flat_positives: list[str] = []

for batch_idx, batch in enumerate(batches):
    if len(batch) != BATCH_SIZE:
        raise ValueError(
            f"Batch {batch_idx} has {len(batch)} pairs, expected {BATCH_SIZE}."
        )
    for anchor, positive in batch:
        flat_anchors.append(anchor)
        flat_positives.append(positive)

train_dataset = Dataset.from_dict(
    {
        "anchor": flat_anchors,
        "positive": flat_positives,
    }
)

# Keep column order explicit.
train_dataset = train_dataset.select_columns(["anchor", "positive"])
Now the dataset rows have this structure:
rows 0..255 = curated batch 0
rows 256..511 = curated batch 1
rows 512..767 = curated batch 2
...
The dataset itself is flat, but your precomputed batch structure is preserved by row position.
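A quick piece of index arithmetic makes the invariant explicit: row k always belongs to curated batch k // BATCH_SIZE. A toy sketch with a batch size of 2:

```python
# Toy check of the row-to-batch mapping: row k of the flat dataset is
# pair (k % BATCH_SIZE) of curated batch (k // BATCH_SIZE).
BATCH_SIZE = 2  # toy size for illustration; the real setup uses 256

batches = [
    [("a0", "p0"), ("a1", "p1")],  # curated batch 0
    [("a2", "p2"), ("a3", "p3")],  # curated batch 1
]
flat = [pair for batch in batches for pair in batch]

for k, pair in enumerate(flat):
    assert pair == batches[k // BATCH_SIZE][k % BATCH_SIZE]
```

If this mapping ever breaks (for example, after a stray `Dataset.shuffle()`), the curated batch structure is silently destroyed, so it is worth asserting during preprocessing.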
Important: do not pass metadata columns directly to the trainer
Sentence Transformers training datasets are column-order-sensitive. In the training overview, non-label columns are treated as model inputs.
So this is safe:
train_dataset = train_dataset.select_columns(["anchor", "positive"])
This is risky if passed directly to the trainer:
Dataset.from_dict(
    {
        "anchor": anchors,
        "positive": positives,
        "batch_id": batch_ids,
        "canonical_address_id": canonical_ids,
    }
)
because batch_id and canonical_address_id are metadata, not text inputs.
Keep metadata during preprocessing and validation, but remove it before training:
train_dataset = full_dataset.select_columns(["anchor", "positive"])
Step 2: define a custom batch sampler
This sampler yields contiguous blocks of indices:
[0, 1, 2, ..., 255]
[256, 257, 258, ..., 511]
[512, 513, 514, ..., 767]
...
It can shuffle the order of whole batches between epochs, but it never mixes examples across your curated batches.
from collections.abc import Iterator

import torch
from datasets import Dataset
from sentence_transformers.sampler import DefaultBatchSampler


class ExactPreBatchedSampler(DefaultBatchSampler):
    """
    Preserves precomputed contiguous batches.

    Assumption:
        Rows 0..255 are curated batch 0
        Rows 256..511 are curated batch 1
        Rows 512..767 are curated batch 2
        ...

    The sampler may shuffle the order of whole batches, but it never mixes
    rows from different precomputed batches.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] | None = None,
        generator: torch.Generator | None = None,
        seed: int = 0,
        shuffle_batches: bool = True,
    ) -> None:
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            valid_label_columns=valid_label_columns,
            generator=generator,
            seed=seed,
        )
        self.dataset = dataset
        self.shuffle_batches = shuffle_batches
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}.")
        if len(self.dataset) < self.batch_size:
            raise ValueError(
                f"Dataset has {len(self.dataset)} rows, "
                f"but batch_size={self.batch_size}."
            )

    def __iter__(self) -> Iterator[list[int]]:
        # DefaultBatchSampler provides epoch handling via SetEpochMixin.
        if self.generator is not None and self.seed is not None:
            self.generator.manual_seed(self.seed + self.epoch)

        n_full_batches = len(self.dataset) // self.batch_size
        remainder_start = n_full_batches * self.batch_size

        batch_ids = torch.arange(n_full_batches)
        if self.shuffle_batches:
            batch_ids = batch_ids[
                torch.randperm(n_full_batches, generator=self.generator)
            ]

        for batch_id in batch_ids.tolist():
            start = batch_id * self.batch_size
            end = start + self.batch_size
            yield list(range(start, end))

        if not self.drop_last and remainder_start < len(self.dataset):
            yield list(range(remainder_start, len(self.dataset)))

    def __len__(self) -> int:
        n_full_batches = len(self.dataset) // self.batch_size
        has_remainder = len(self.dataset) % self.batch_size != 0
        return n_full_batches + int(has_remainder and not self.drop_last)
Step 3: pass the sampler through SentenceTransformerTrainingArguments
Use a small factory function. This is convenient because the trainer constructs the sampler internally and supplies arguments such as dataset, batch_size, drop_last, generator, and seed.
def exact_prebatched_sampler_factory(
    dataset: Dataset,
    batch_size: int,
    drop_last: bool,
    valid_label_columns: list[str] | None = None,
    generator: torch.Generator | None = None,
    seed: int = 0,
):
    if batch_size != BATCH_SIZE:
        raise ValueError(
            f"Expected batch_size={BATCH_SIZE}, got {batch_size}. "
            "Use per_device_train_batch_size=256."
        )
    return ExactPreBatchedSampler(
        dataset=dataset,
        batch_size=batch_size,
        drop_last=drop_last,
        valid_label_columns=valid_label_columns,
        generator=generator,
        seed=seed,
        shuffle_batches=True,
    )
Then configure the trainer:
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)

BATCH_SIZE = 256

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# If your installed sentence-transformers version does not expose
# directions/partition_mode, fall back to the long-standing stable form:
# loss_fn = losses.MultipleNegativesRankingLoss(model)
loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=(
        "query_to_doc",
        "query_to_query",
        "doc_to_query",
        "doc_to_doc",
    ),
    partition_mode="joint",
)

args = SentenceTransformerTrainingArguments(
    output_dir="models/address-mpnet-mnrl",
    # Must match your curated batch size.
    per_device_train_batch_size=BATCH_SIZE,
    # Usually safest if all curated batches are exactly size 256.
    dataloader_drop_last=True,
    # Critical part: preserve your precomputed batches.
    batch_sampler=exact_prebatched_sampler_factory,
    # Usual training settings. Tune these for your dataset.
    num_train_epochs=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    # Use bf16 if your hardware supports it. Otherwise use fp16 or fp32.
    bf16=True,
    fp16=False,
    logging_steps=50,
    save_steps=500,
    save_total_limit=2,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss_fn,
)

trainer.train()
That is the core solution.
Complete minimal example
from collections.abc import Iterator

import torch
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.sampler import DefaultBatchSampler

BATCH_SIZE = 256


class ExactPreBatchedSampler(DefaultBatchSampler):
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] | None = None,
        generator: torch.Generator | None = None,
        seed: int = 0,
        shuffle_batches: bool = True,
    ) -> None:
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            valid_label_columns=valid_label_columns,
            generator=generator,
            seed=seed,
        )
        self.dataset = dataset
        self.shuffle_batches = shuffle_batches

    def __iter__(self) -> Iterator[list[int]]:
        if self.generator is not None and self.seed is not None:
            self.generator.manual_seed(self.seed + self.epoch)

        n_full_batches = len(self.dataset) // self.batch_size
        remainder_start = n_full_batches * self.batch_size

        batch_ids = torch.arange(n_full_batches)
        if self.shuffle_batches:
            batch_ids = batch_ids[
                torch.randperm(n_full_batches, generator=self.generator)
            ]

        for batch_id in batch_ids.tolist():
            start = batch_id * self.batch_size
            end = start + self.batch_size
            yield list(range(start, end))

        if not self.drop_last and remainder_start < len(self.dataset):
            yield list(range(remainder_start, len(self.dataset)))

    def __len__(self) -> int:
        n_full_batches = len(self.dataset) // self.batch_size
        has_remainder = len(self.dataset) % self.batch_size != 0
        return n_full_batches + int(has_remainder and not self.drop_last)


def exact_prebatched_sampler_factory(
    dataset: Dataset,
    batch_size: int,
    drop_last: bool,
    valid_label_columns: list[str] | None = None,
    generator: torch.Generator | None = None,
    seed: int = 0,
):
    if batch_size != BATCH_SIZE:
        raise ValueError(
            f"Expected batch_size={BATCH_SIZE}, got {batch_size}. "
            "Use per_device_train_batch_size=256."
        )
    return ExactPreBatchedSampler(
        dataset=dataset,
        batch_size=batch_size,
        drop_last=drop_last,
        valid_label_columns=valid_label_columns,
        generator=generator,
        seed=seed,
        shuffle_batches=True,
    )


flat_anchors: list[str] = []
flat_positives: list[str] = []

for batch_idx, batch in enumerate(batches):
    if len(batch) != BATCH_SIZE:
        raise ValueError(
            f"Batch {batch_idx} has {len(batch)} pairs, expected {BATCH_SIZE}."
        )
    for anchor, positive in batch:
        flat_anchors.append(anchor)
        flat_positives.append(positive)

train_dataset = Dataset.from_dict(
    {
        "anchor": flat_anchors,
        "positive": flat_positives,
    }
).select_columns(["anchor", "positive"])

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=(
        "query_to_doc",
        "query_to_query",
        "doc_to_query",
        "doc_to_doc",
    ),
    partition_mode="joint",
)

args = SentenceTransformerTrainingArguments(
    output_dir="models/address-mpnet-mnrl",
    per_device_train_batch_size=BATCH_SIZE,
    dataloader_drop_last=True,
    batch_sampler=exact_prebatched_sampler_factory,
    num_train_epochs=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    fp16=False,
    logging_steps=50,
    save_steps=500,
    save_total_limit=2,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss_fn,
)

trainer.train()
Verify the sampler before training
Before launching a long training job, inspect a few batches.
sampler = exact_prebatched_sampler_factory(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    drop_last=True,
    generator=torch.Generator().manual_seed(42),
    seed=42,
)

for batch_number, indices in zip(range(5), sampler):
    print(
        "batch_number:",
        batch_number,
        "first_index:",
        indices[0],
        "last_index:",
        indices[-1],
        "size:",
        len(indices),
    )
    print("first anchor:", train_dataset[indices[0]]["anchor"])
    print("first positive:", train_dataset[indices[0]]["positive"])
    print()
Expected shape:
batch_number: 0 first_index: 512 last_index: 767 size: 256
batch_number: 1 first_index: 0 last_index: 255 size: 256
batch_number: 2 first_index: 256 last_index: 511 size: 256
The order may differ because whole batches are shuffled, but each yielded batch should still be a contiguous block of 256 rows.
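To turn that visual inspection into an automatic check, assert the invariant on every `indices` list the sampler yields (a small helper; any within-batch shuffle or cross-batch mixing fails it):

```python
def assert_contiguous_batch(indices: list[int], batch_size: int) -> None:
    """Fail loudly if a yielded batch is not one unbroken block of rows
    starting at a multiple of batch_size."""
    start = indices[0]
    assert start % batch_size == 0, f"batch starts mid-block at row {start}"
    assert indices == list(range(start, start + batch_size)), "batch is not contiguous"

# A contiguous block like rows 512..767 passes the check.
assert_contiguous_batch(list(range(512, 768)), 256)
```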
Address-specific warning: false negatives
The main risk with MultipleNegativesRankingLoss is false negatives.
In MNRL, other positives in the same batch are treated as negatives for the current anchor. That is useful only if those other positives are truly different addresses.
This is good:
anchor:
Blue Street, 1, New York
positive:
Blue Street 1 - New York
in-batch negative:
Blue Street 11 - New York
This is dangerous:
anchor:
Blue Street, 1, New York
positive:
Blue Street 1 - New York
in-batch negative from another row:
1 Blue St., NYC
because 1 Blue St., NYC may be the same real address.
So your batch builder should enforce a rule like:
No two different rows in the same MNRL batch may refer to the same canonical address.
Exact string deduplication is not enough. Prefer deduplication by one or more of:
canonical address ID
delivery point ID
authoritative normalized address
geocoder result ID
parcel/building/unit ID
high-confidence rooftop coordinate
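A tiny example shows why exact-string deduplication falls short while canonical-ID deduplication works (the `canonical_address_id` field name mirrors the metadata example later in this answer; adapt it to your schema):

```python
rows = [
    {"text": "Blue Street, 1, New York", "canonical_address_id": "addr_001"},
    {"text": "1 Blue St., NYC", "canonical_address_id": "addr_001"},
]

# Exact-string dedup keeps both rows, because the strings differ...
assert len({row["text"] for row in rows}) == 2
# ...but they are the same real-world entity, which canonical-ID dedup catches.
assert len({row["canonical_address_id"] for row in rows}) == 1
```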
Be careful with all four directions
You used:
directions=(
    "query_to_doc",
    "query_to_query",
    "doc_to_query",
    "doc_to_doc",
)
If your installed MultipleNegativesRankingLoss version supports those options, they can provide a stronger signal.
However, it also makes batch cleanliness stricter.
With only query_to_doc, the main requirement is:
anchor_i should not match positive_j for i != j
With query_to_query, you also need:
anchor_i should not be equivalent to anchor_j
With doc_to_doc, you also need:
positive_i should not be equivalent to positive_j
For all four directions, validate:
anchor_i is not equivalent to anchor_j
positive_i is not equivalent to positive_j
anchor_i is not equivalent to positive_j for i != j
positive_i is not equivalent to anchor_j for i != j
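If every row carries one canonical ID shared by its anchor and positive, all four conditions collapse into a single rule: no two rows in the batch may share an ID. A sketch of that check (it assumes a `canonical_address_id` field and the one-ID-per-row invariant, which you must enforce upstream):

```python
def validate_four_direction_batch(rows: list[dict]) -> None:
    """All four MNRL directions are safe iff no two rows in the batch refer
    to the same canonical address (given that the anchor and positive of
    each row share one canonical_address_id)."""
    ids = [row["canonical_address_id"] for row in rows]
    duplicates = sorted({i for i in ids if ids.count(i) > 1})
    if duplicates:
        raise ValueError(f"Duplicate canonical IDs in batch: {duplicates}")

validate_four_direction_batch(
    [{"canonical_address_id": "addr_001"}, {"canonical_address_id": "addr_002"}]
)
```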
If your canonical-address checks are not strong yet, consider starting with a simpler loss configuration:
loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=("query_to_doc", "doc_to_query"),
    partition_mode="per_direction",
)
Then compare against the all-four-direction version.
Consider explicit hard negatives
Your current format is:
(anchor, positive)
That is valid for MNRL.
But if you already know specific hard negatives, you can also use:
(anchor, positive, negative_1, negative_2, negative_3)
MultipleNegativesRankingLoss supports pairs, triplets, and n-tuples:
train_dataset = Dataset.from_dict(
    {
        "anchor": [
            "Blue Street, 1, New York",
            "Blue Street, 11, New York",
        ],
        "positive": [
            "Blue Street 1 - New York",
            "Blue Street 11 - New York",
        ],
        "negative_1": [
            "Blue Street 11 - New York",
            "Blue Street 1 - New York",
        ],
        "negative_2": [
            "Blue Street 1 - Newark",
            "Blue Street 11 - Newark",
        ],
    }
).select_columns(["anchor", "positive", "negative_1", "negative_2"])
This gives the loss both:
explicit hard negatives attached to each row
+
in-batch hard negatives created by your curated batch
For a first implementation, I would start with (anchor, positive) and curated batches. After that works, add explicit hard-negative columns and compare.
Do not use gradient accumulation as a substitute for batch size
This is a common mistake.
For MNRL, the important thing is the contrastive batch size: the number of examples visible to the loss at the same time.
These are not equivalent:
per_device_train_batch_size = 32
gradient_accumulation_steps = 8
and:
per_device_train_batch_size = 256
The first setup may update the optimizer after 256 examples, but each MNRL softmax only sees 32 examples at once.
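The arithmetic is worth spelling out: the in-batch negative pool each anchor sees is per_device_train_batch_size - 1, and gradient accumulation never enlarges it.

```python
def mnrl_negative_pool(per_device_batch_size: int, grad_accum_steps: int) -> int:
    """Negatives visible to one anchor inside a single MNRL softmax.
    Gradient accumulation changes optimizer-step frequency, not this pool."""
    del grad_accum_steps  # deliberately unused: it cannot help
    return per_device_batch_size - 1

assert mnrl_negative_pool(32, 8) == 31    # not 255
assert mnrl_negative_pool(256, 1) == 255
```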
If batch size 256 does not fit in GPU memory, use CachedMultipleNegativesRankingLoss:
loss_fn = losses.CachedMultipleNegativesRankingLoss(
    model,
    mini_batch_size=32,
    directions=(
        "query_to_doc",
        "query_to_query",
        "doc_to_query",
        "doc_to_doc",
    ),
    partition_mode="joint",
)
Keep:
per_device_train_batch_size = 256
The distinction is:
per_device_train_batch_size = contrastive batch size
mini_batch_size = internal memory chunk size
That matters because your curated group of 256 addresses is the semantic training unit.
Multi-GPU caution
MultipleNegativesRankingLoss has a gather_across_devices option. It can increase the effective negative pool across devices, but it also changes the effective contrastive batch.
For your case, exact batch composition matters. I would first validate everything on one GPU:
single GPU
per_device_train_batch_size = 256
gather_across_devices = False
Then move to distributed training only after you have verified what examples are actually visible to each loss computation.
Practical address-data advice
Good positives
Use formatting variants of the same real address:
Blue Street, 1, New York
Blue Street 1 - New York
1 Blue St., NYC
Blue Street 1, New York, NY
Apt 2, 1 Blue Street, New York
1 Blue St Apartment 2, NYC
Good hard negatives
Use addresses that differ in identity-critical components:
same street + same city + different house number
same street + same house number + different city
same building + different apartment/unit
same street + different postal code
same house number + similar street name
same city + changed street suffix
changed directional: North Main St vs South Main St
Examples:
Blue Street 1, New York
Blue Street 11, New York
Blue Street 1, New York
Blue Street 1, Newark
Blue Street 1 Apt 2, New York
Blue Street 1 Apt 20, New York
North Main Street 10
South Main Street 10
Good batch design
Build each batch as a “confusion neighborhood”:
Batch theme:
same normalized street + same city
Rows:
Blue Street 1, New York ↔ 1 Blue St, NYC
Blue Street 11, New York ↔ 11 Blue St, NYC
Blue Street 1 Apt 2, New York ↔ Apt 2, 1 Blue St, NYC
Blue Street 1 Apt 20, New York ↔ Apt 20, 1 Blue St, NYC
Blue Avenue 1, New York ↔ 1 Blue Ave, NYC
Blue Street 1, Newark ↔ 1 Blue St, Newark
This is much better than random batching, because random negatives are often too easy.
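One simple way to approximate such confusion neighborhoods is to sort pairs by a confusion key (for example, normalized street plus city) before chunking, so near-collisions land in the same batch. A toy sketch with a deliberately crude city key (the key function is a placeholder; a real pipeline would use parsed address components plus canonical-ID deduplication):

```python
def build_confusion_batches(rows, key_fn, batch_size):
    """Sort rows by a confusion key, then chunk into fixed-size batches so
    that similar-looking addresses become each other's in-batch negatives.
    Incomplete trailing chunks are dropped."""
    ordered = sorted(rows, key=key_fn)
    return [
        ordered[i : i + batch_size]
        for i in range(0, len(ordered) - batch_size + 1, batch_size)
    ]

rows = [
    {"anchor": "Blue Street 1, Newark", "positive": "1 Blue St, Newark"},
    {"anchor": "Blue Street 1, New York", "positive": "1 Blue St, NYC"},
    {"anchor": "Blue Street 11, New York", "positive": "11 Blue St, NYC"},
    {"anchor": "Blue Avenue 1, New York", "positive": "1 Blue Ave, NYC"},
]
city_key = lambda row: row["anchor"].rsplit(",", 1)[-1].strip()  # crude toy key
batches = build_confusion_batches(rows, city_key, batch_size=2)
```

With this key, the two “Blue Street …, New York” rows end up in the same batch, which is the hard-negative pairing the batch design above aims for.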
Recommended preprocessing validation
Keep metadata before training:
training_rows = [
    {
        "batch_id": 0,
        "anchor": "Blue Street, 1, New York",
        "positive": "Blue Street 1 - New York",
        "canonical_address_id": "addr_001",
    },
    {
        "batch_id": 0,
        "anchor": "Blue Street, 11, New York",
        "positive": "Blue Street 11 - New York",
        "canonical_address_id": "addr_002",
    },
]
Validate each batch:
def validate_precomputed_batches(batches_with_ids: list[list[dict]]) -> None:
    for batch_idx, batch in enumerate(batches_with_ids):
        if len(batch) != BATCH_SIZE:
            raise ValueError(
                f"Batch {batch_idx} has {len(batch)} rows, expected {BATCH_SIZE}."
            )
        canonical_ids = [row["canonical_address_id"] for row in batch]
        if len(canonical_ids) != len(set(canonical_ids)):
            raise ValueError(
                f"Batch {batch_idx} contains duplicate canonical address IDs. "
                "This creates false negatives for MNRL."
            )
        for row_idx, row in enumerate(batch):
            if not row["anchor"] or not row["positive"]:
                raise ValueError(
                    f"Batch {batch_idx}, row {row_idx} has empty text."
                )
Then build the trainer dataset with only text columns:
train_dataset = Dataset.from_dict(
    {
        "anchor": [row["anchor"] for batch in batches_with_ids for row in batch],
        "positive": [row["positive"] for batch in batches_with_ids for row in batch],
    }
).select_columns(["anchor", "positive"])
This gives you metadata safety during preprocessing and a clean text-only dataset during training.
Evaluation: do not rely only on average cosine similarity
For address embeddings, generic similarity evaluation is not enough.
Use at least these four evaluation types.
1. Same-address invariance
Pairs that should be close:
Blue Street 1, New York
1 Blue St., NYC
Measure:
positive cosine distribution
2. Hard-negative separation
Pairs that should not be too close:
Blue Street 1, New York
Blue Street 11, New York
Slice by component type:
house number changed
city changed
unit changed
postal code changed
street suffix changed
directional changed
3. Triplet accuracy
Triplets:
anchor:
Blue Street 1, New York
positive:
1 Blue St., NYC
negative:
Blue Street 11, New York
Measure:
score(anchor, positive) > score(anchor, negative)
Also measure the margin:
score(anchor, positive) - score(anchor, negative)
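Both numbers can come from one small helper; the scores below are made-up toy values for illustration:

```python
def triplet_stats(score_pairs: list[tuple[float, float]]) -> tuple[float, float]:
    """score_pairs[i] = (score(anchor, positive), score(anchor, negative)).
    Returns (triplet accuracy, mean margin)."""
    accuracy = sum(pos > neg for pos, neg in score_pairs) / len(score_pairs)
    mean_margin = sum(pos - neg for pos, neg in score_pairs) / len(score_pairs)
    return accuracy, mean_margin

# Toy scores: the second triplet is a (slight) failure case.
accuracy, mean_margin = triplet_stats([(0.92, 0.75), (0.88, 0.91)])
```

Tracking the margin as well as the accuracy matters: two models with identical triplet accuracy can differ greatly in how confidently they separate near-duplicate addresses.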
4. Retrieval
Corpus:
all canonical addresses
Query:
raw or noisy address variant
Measure:
Recall@1
Recall@5
Recall@10
MRR
nDCG
For a production address resolver, retrieval metrics are usually the most realistic. The main question is whether the correct canonical address appears in the top candidates.
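Recall@k and MRR are straightforward to compute once you have, for each query, a ranked list of candidate canonical IDs; a minimal sketch (the sentence-transformers InformationRetrievalEvaluator can report these metrics for you as well):

```python
def recall_at_k(ranked: list[list[str]], gold: list[str], k: int) -> float:
    """ranked[q] = candidate IDs for query q, best first; gold[q] = true ID."""
    return sum(g in r[:k] for r, g in zip(ranked, gold)) / len(gold)

def mean_reciprocal_rank(ranked: list[list[str]], gold: list[str]) -> float:
    return sum(
        1.0 / (r.index(g) + 1) if g in r else 0.0 for r, g in zip(ranked, gold)
    ) / len(gold)

# Toy run: two queries, two candidates each.
ranked = [["addr_001", "addr_002"], ["addr_003", "addr_002"]]
gold = ["addr_001", "addr_002"]
print(recall_at_k(ranked, gold, 1))        # 0.5: only query 0 hits at rank 1
print(mean_reciprocal_rank(ranked, gold))  # (1/1 + 1/2) / 2 = 0.75
```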
Production design recommendation
For high-precision address matching, I would not rely only on one embedding model.
Use a two-stage architecture:
raw address
↓
bi-encoder embedding model
↓
top 10 / top 50 canonical candidates
↓
cross-encoder or structured verifier
↓
same / different / uncertain
The bi-encoder is fast and good for candidate retrieval. A cross-encoder or verifier can compare two addresses jointly and pay close attention to exact differences:
1 vs 11
NYC vs Newark
Apt 2 vs Apt 20
Street vs Avenue
North Main vs South Main
For address matching, this second stage is often what prevents high-cost false positives.
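The control flow of the two-stage design can be sketched independently of any model; here both scoring functions are cheap stand-ins (a real system would plug in SentenceTransformer embedding similarity for stage 1 and a CrossEncoder for stage 2, with the accept/reject thresholds tuned on held-out data):

```python
import re

def two_stage_match(query, corpus, retrieve_score, verify_score,
                    top_k=10, accept=0.9, reject=0.3):
    """Stage 1: rank the corpus with a cheap score and keep top_k candidates.
    Stage 2: re-score the candidates with a slower joint verifier and map
    the best score to same / different / uncertain."""
    candidates = sorted(
        corpus, key=lambda doc: retrieve_score(query, doc), reverse=True
    )[:top_k]
    best = max(candidates, key=lambda doc: verify_score(query, doc))
    score = verify_score(query, best)
    if score >= accept:
        verdict = "same"
    elif score <= reject:
        verdict = "different"
    else:
        verdict = "uncertain"
    return best, verdict

# Toy stand-in scorer: Jaccard word overlap after stripping punctuation.
def overlap(a: str, b: str) -> float:
    ta = set(re.findall(r"\w+", a.lower()))
    tb = set(re.findall(r"\w+", b.lower()))
    return len(ta & tb) / len(ta | tb)

corpus = [
    "Blue Street 1 - New York",
    "Blue Street 11 - New York",
    "Blue Street 1 - Newark",
]
best, verdict = two_stage_match(
    "Blue Street, 1, New York", corpus, overlap, overlap,
    top_k=2, accept=0.9, reject=0.3,
)
```

Keeping an explicit "uncertain" outcome is the design choice that lets a production system route ambiguous matches to a rules engine or human review instead of forcing a binary decision.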
Practical experiment order
Experiment 1: curated batches with simple MNRL
loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=("query_to_doc",),
)
This is the simplest baseline.
Experiment 2: symmetric MNRL
loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=("query_to_doc", "doc_to_query"),
    partition_mode="per_direction",
)
This is a good next step because address equivalence is symmetric.
Experiment 3: all four directions
loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=(
        "query_to_doc",
        "query_to_query",
        "doc_to_query",
        "doc_to_doc",
    ),
    partition_mode="joint",
)
Use this only if your batch validation is strong and evaluation improves.
Experiment 4: explicit hard negatives
Use columns like:
anchor
positive
negative_1
negative_2
negative_3
Still preserve your curated batch structure.
Experiment 5: CachedMNRL
If memory is limiting, switch from MNRL to CachedMNRL while preserving per_device_train_batch_size=256.
Common mistakes to avoid
Mistake 1: putting batch_id into the training dataset
Keep metadata outside the trainer dataset unless you customize the collator/loss path.
Mistake 2: using random batches
Random negatives are mostly too easy. Your task needs hard negatives.
Mistake 3: trusting exact-string deduplication
Different strings can still be the same address.
Mistake 4: using gradient accumulation to simulate batch size 256
It does not create an MNRL negative pool of 256.
Mistake 5: training all four directions without strong false-negative checks
query_to_query and doc_to_doc make false negatives more damaging.
Final answer
Create a normal flattened datasets.Dataset with only the text columns needed by the loss. Preserve your handcrafted batch structure by row order. Then pass a custom DefaultBatchSampler or sampler factory through SentenceTransformerTrainingArguments(batch_sampler=...) so the trainer’s internal dataloader yields exactly your curated groups of 256 pairs.
In short:
Flatten your batches.
Keep rows ordered by curated batch.
Use a custom batch sampler.
Set per_device_train_batch_size=256.
Validate that no batch contains two equivalent canonical addresses.
Use CachedMNRL if memory is too tight.
Evaluate with retrieval and hard-negative address tests.