Custom batches in sentence-transformers for MultipleNegativesRankingLoss

I am using the sentence-transformers library to finetune a model. The goal is to generate embeddings for postal addresses so that embeddings of the same address written in different ways are close to each other.
However, addresses that differ only in a small part (e.g. the street number, or the name of the city) must have sufficiently different embeddings, which is not the case when I try to finetune the all-mpnet-base-v2 model using the CosineSimilarityLoss (or similar losses).

Therefore, I am trying to use the MultipleNegativesRankingLoss. As far as I understand, this loss function takes the whole minibatch into account, not just the individual pairs of sentences/addresses. It not only enforces that the sentences/addresses in a given pair have similar embeddings, but also treats sentences/addresses from different pairs of the same batch as negatives (which is exactly what I need).

So I prepared a training set that is already partitioned into batches of 256 pairs each, taking care to put pairs that must be treated as strong negatives in the same batch, even if they are quite similar.

batches: list[tuple[tuple[str, str], ...]] = [  # 256 (anchor, positive) pairs per batch
    (
        (batch1_anchor1,  batch1_positive1),  # ('Blue Street, 1, New York', 'Blue Street 1 - New York'), 
        (batch1_anchor2,  batch1_positive2),  # ('Blue Street, 11, New York', 'Blue Street 11 - New York'),
        (batch1_anchor3,  batch1_positive3),
        ...
    ),
    (
        (batch2_anchor1,  batch2_positive1),
        (batch2_anchor2,  batch2_positive2),
        (batch2_anchor3,  batch2_positive3),
        ...
    ),
    ...
]

My question is: how do I preserve this batch structure when loading the training data into the trainer?
The SentenceTransformerTrainer class only accepts a datasets.Dataset, I see no way to preserve my batches.

    loss_fn = MultipleNegativesRankingLoss(
        model, 
        directions=('query_to_doc', 'query_to_query', 'doc_to_query', 'doc_to_doc')
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=???,  # here I can pass a datasets.Dataset, not a torch.utils.data.DataLoader or equivalent
        loss=loss_fn,
    )

Hmm… apparently it’s possible using a custom batch sampler, but I’m not entirely sure that’s true.
Just in case, @tomaarsen



Short answer

Use a normal datasets.Dataset, but do not rely on the default batching behavior of SentenceTransformerTrainer.

Instead:

  1. Flatten your pre-built batches into a single datasets.Dataset.
  2. Keep the rows ordered so that rows 0..255 are your first curated batch, rows 256..511 are your second curated batch, and so on.
  3. Pass a custom batch sampler through SentenceTransformerTrainingArguments(batch_sampler=...).
  4. Make the custom sampler yield exactly the row-index groups you want.

The key point is:

train_dataset = storage format
batch_sampler = batching policy
loss = uses the resulting batch as the contrastive pool

So your train_dataset=??? should be a flattened datasets.Dataset, and the batch-preserving logic should live in args.batch_sampler.


Why this is the right abstraction

Your understanding of MultipleNegativesRankingLoss is correct.

MultipleNegativesRankingLoss is an in-batch contrastive loss. For each anchor in a batch, the matching positive should be closer than the other candidate positives or documents in the same batch.

For example, suppose one minibatch contains:

[
    ("Blue Street, 1, New York",  "Blue Street 1 - New York"),
    ("Blue Street, 11, New York", "Blue Street 11 - New York"),
]

Then, for the first row, the model is trained to make:

"Blue Street, 1, New York"

closer to:

"Blue Street 1 - New York"

than to:

"Blue Street 11 - New York"

That is exactly the behavior you want.

With address matching, the hard part is not only learning that two variants of the same address are close. The hard part is learning that extremely similar-looking addresses may still be different real-world entities:

Blue Street 1, New York       ==  Blue St. 1, NYC
Blue Street 1, New York       !=  Blue Street 11, New York
Blue Street 1, New York       !=  Blue Street 1, Newark
Blue Street 1 Apt 2           !=  Blue Street 1 Apt 20

So your curated batch structure is not incidental. It is part of the supervision.


Why CosineSimilarityLoss is usually weaker here

CosineSimilarityLoss is pairwise. It sees one pair and a target score.

That works well when your labels are naturally pairwise:

pair A similarity = 0.95
pair B similarity = 0.10
pair C similarity = 0.60
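
For contrast, this is the supervision format CosineSimilarityLoss consumes through the trainer: two text columns plus a float "score" column, one fixed target per row. A minimal sketch:

from datasets import Dataset

pairwise_dataset = Dataset.from_dict(
    {
        "sentence1": ["Blue Street, 1, New York", "Blue Street, 1, New York"],
        "sentence2": ["Blue Street 1 - New York", "Blue Street 11 - New York"],
        "score": [1.0, 0.1],  # one target per pair; no in-batch ranking signal
    }
)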

But your real task is closer to ranking:

Given:
    "Blue Street, 1, New York"

Rank this highest:
    "Blue Street 1 - New York"

Rank these lower:
    "Blue Street 11 - New York"
    "Blue Street 1 - Newark"
    "Blue Avenue 1 - New York"
    "Blue Street 1 Apt 2 - New York"

That is why MultipleNegativesRankingLoss is a better fit. It turns each minibatch into a local retrieval problem.
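
The ranking behavior comes from treating each minibatch as a classification problem over the batch's positives. A rough conceptual sketch, not the library's actual implementation (which also handles scaling options and extra negative columns):

import torch
import torch.nn.functional as F

def mnrl_sketch(anchor_emb: torch.Tensor, positive_emb: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    # (batch, batch) similarity matrix: anchor_i vs positive_j for all i, j.
    scores = scale * F.cosine_similarity(
        anchor_emb.unsqueeze(1), positive_emb.unsqueeze(0), dim=-1
    )
    # The diagonal entries are the true (anchor_i, positive_i) pairs.
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)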

In this sense, your task is not just sentence similarity. It is closer to:

postal-address entity resolution
+
dense retrieval
+
hard-negative metric learning

Why not pass a DataLoader?

The newer SentenceTransformerTrainer API expects a dataset, not a user-supplied PyTorch DataLoader.

That does not mean you cannot control batches.

The control point is SentenceTransformerTrainingArguments.batch_sampler, documented in the samplers reference. The sampler docs explain that a custom batch sampler can be supplied by subclassing DefaultBatchSampler or by passing a function that returns a DefaultBatchSampler instance.

So the correct structure is:

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,                 # contains batch_sampler
    train_dataset=train_dataset,
    loss=loss_fn,
)

not:

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataloader=my_dataloader,  # not the intended API
)

Step 1: flatten your pre-built batches

Assume your current data looks conceptually like this:

batches = [
    [
        (batch1_anchor1, batch1_positive1),
        (batch1_anchor2, batch1_positive2),
        ...
    ],
    [
        (batch2_anchor1, batch2_positive1),
        (batch2_anchor2, batch2_positive2),
        ...
    ],
]

Flatten it while preserving order:

from datasets import Dataset

BATCH_SIZE = 256

flat_anchors: list[str] = []
flat_positives: list[str] = []

for batch_idx, batch in enumerate(batches):
    if len(batch) != BATCH_SIZE:
        raise ValueError(
            f"Batch {batch_idx} has {len(batch)} pairs, expected {BATCH_SIZE}."
        )

    for anchor, positive in batch:
        flat_anchors.append(anchor)
        flat_positives.append(positive)

train_dataset = Dataset.from_dict(
    {
        "anchor": flat_anchors,
        "positive": flat_positives,
    }
)

# Keep column order explicit.
train_dataset = train_dataset.select_columns(["anchor", "positive"])

Now the dataset rows have this structure:

rows 0..255      = curated batch 0
rows 256..511    = curated batch 1
rows 512..767    = curated batch 2
...

The dataset itself is flat, but your precomputed batch structure is preserved by row position.
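
A quick sanity check at this point can catch flattening mistakes early (a minimal sketch; it assumes the batches variable from above is still in scope):

# Total rows must be an exact multiple of the curated batch size.
assert len(train_dataset) == len(batches) * BATCH_SIZE

# Row 0 must be the first anchor of curated batch 0.
assert train_dataset[0]["anchor"] == batches[0][0][0]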


Important: do not pass metadata columns directly to the trainer

Sentence Transformers training datasets are column-order-sensitive. In the training overview, non-label columns are treated as model inputs.

So this is safe:

train_dataset = train_dataset.select_columns(["anchor", "positive"])

This is risky if passed directly to the trainer:

Dataset.from_dict(
    {
        "anchor": anchors,
        "positive": positives,
        "batch_id": batch_ids,
        "canonical_address_id": canonical_ids,
    }
)

because batch_id and canonical_address_id are metadata, not text inputs.

Keep metadata during preprocessing and validation, but remove it before training:

train_dataset = full_dataset.select_columns(["anchor", "positive"])

Step 2: define a custom batch sampler

This sampler yields contiguous blocks of indices:

[0, 1, 2, ..., 255]
[256, 257, 258, ..., 511]
[512, 513, 514, ..., 767]
...

It can shuffle the order of whole batches between epochs, but it never mixes examples across your curated batches.

from collections.abc import Iterator

import torch
from datasets import Dataset
from sentence_transformers.sampler import DefaultBatchSampler


class ExactPreBatchedSampler(DefaultBatchSampler):
    """
    Preserves precomputed contiguous batches.

    Assumption:
        Rows 0..255      are curated batch 0
        Rows 256..511    are curated batch 1
        Rows 512..767    are curated batch 2
        ...

    The sampler may shuffle the order of whole batches, but it never mixes
    rows from different precomputed batches.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] | None = None,
        generator: torch.Generator | None = None,
        seed: int = 0,
        shuffle_batches: bool = True,
    ) -> None:
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            valid_label_columns=valid_label_columns,
            generator=generator,
            seed=seed,
        )
        self.dataset = dataset
        self.shuffle_batches = shuffle_batches

        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}.")

        if len(self.dataset) < self.batch_size:
            raise ValueError(
                f"Dataset has {len(self.dataset)} rows, "
                f"but batch_size={self.batch_size}."
            )

    def __iter__(self) -> Iterator[list[int]]:
        # DefaultBatchSampler provides epoch handling via SetEpochMixin.
        if self.generator is not None and self.seed is not None:
            self.generator.manual_seed(self.seed + self.epoch)

        n_full_batches = len(self.dataset) // self.batch_size
        remainder_start = n_full_batches * self.batch_size

        batch_ids = torch.arange(n_full_batches)

        if self.shuffle_batches:
            batch_ids = batch_ids[
                torch.randperm(n_full_batches, generator=self.generator)
            ]

        for batch_id in batch_ids.tolist():
            start = batch_id * self.batch_size
            end = start + self.batch_size
            yield list(range(start, end))

        if not self.drop_last and remainder_start < len(self.dataset):
            yield list(range(remainder_start, len(self.dataset)))

    def __len__(self) -> int:
        n_full_batches = len(self.dataset) // self.batch_size
        has_remainder = len(self.dataset) % self.batch_size != 0
        return n_full_batches + int(has_remainder and not self.drop_last)

Step 3: pass the sampler through SentenceTransformerTrainingArguments

Use a small factory function. This is convenient because the trainer constructs the sampler internally and supplies arguments such as dataset, batch_size, drop_last, generator, and seed.

def exact_prebatched_sampler_factory(
    dataset: Dataset,
    batch_size: int,
    drop_last: bool,
    valid_label_columns: list[str] | None = None,
    generator: torch.Generator | None = None,
    seed: int = 0,
):
    if batch_size != BATCH_SIZE:
        raise ValueError(
            f"Expected batch_size={BATCH_SIZE}, got {batch_size}. "
            "Use per_device_train_batch_size=256."
        )

    return ExactPreBatchedSampler(
        dataset=dataset,
        batch_size=batch_size,
        drop_last=drop_last,
        valid_label_columns=valid_label_columns,
        generator=generator,
        seed=seed,
        shuffle_batches=True,
    )

Then configure the trainer:

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)

BATCH_SIZE = 256

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=(
        "query_to_doc",
        "query_to_query",
        "doc_to_query",
        "doc_to_doc",
    ),
    partition_mode="joint",
)

args = SentenceTransformerTrainingArguments(
    output_dir="models/address-mpnet-mnrl",

    # Must match your curated batch size.
    per_device_train_batch_size=BATCH_SIZE,

    # Usually safest if all curated batches are exactly size 256.
    dataloader_drop_last=True,

    # Critical part: preserve your precomputed batches.
    batch_sampler=exact_prebatched_sampler_factory,

    # Usual training settings. Tune these for your dataset.
    num_train_epochs=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,

    # Use bf16 if your hardware supports it. Otherwise use fp16 or fp32.
    bf16=True,
    fp16=False,

    logging_steps=50,
    save_steps=500,
    save_total_limit=2,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss_fn,
)

trainer.train()

That is the core solution.


Complete minimal example

from collections.abc import Iterator

import torch
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.sampler import DefaultBatchSampler


BATCH_SIZE = 256


class ExactPreBatchedSampler(DefaultBatchSampler):
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] | None = None,
        generator: torch.Generator | None = None,
        seed: int = 0,
        shuffle_batches: bool = True,
    ) -> None:
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            valid_label_columns=valid_label_columns,
            generator=generator,
            seed=seed,
        )
        self.dataset = dataset
        self.shuffle_batches = shuffle_batches

    def __iter__(self) -> Iterator[list[int]]:
        if self.generator is not None and self.seed is not None:
            self.generator.manual_seed(self.seed + self.epoch)

        n_full_batches = len(self.dataset) // self.batch_size
        remainder_start = n_full_batches * self.batch_size

        batch_ids = torch.arange(n_full_batches)

        if self.shuffle_batches:
            batch_ids = batch_ids[
                torch.randperm(n_full_batches, generator=self.generator)
            ]

        for batch_id in batch_ids.tolist():
            start = batch_id * self.batch_size
            end = start + self.batch_size
            yield list(range(start, end))

        if not self.drop_last and remainder_start < len(self.dataset):
            yield list(range(remainder_start, len(self.dataset)))

    def __len__(self) -> int:
        n_full_batches = len(self.dataset) // self.batch_size
        has_remainder = len(self.dataset) % self.batch_size != 0
        return n_full_batches + int(has_remainder and not self.drop_last)


def exact_prebatched_sampler_factory(
    dataset: Dataset,
    batch_size: int,
    drop_last: bool,
    valid_label_columns: list[str] | None = None,
    generator: torch.Generator | None = None,
    seed: int = 0,
):
    if batch_size != BATCH_SIZE:
        raise ValueError(
            f"Expected batch_size={BATCH_SIZE}, got {batch_size}. "
            "Use per_device_train_batch_size=256."
        )

    return ExactPreBatchedSampler(
        dataset=dataset,
        batch_size=batch_size,
        drop_last=drop_last,
        valid_label_columns=valid_label_columns,
        generator=generator,
        seed=seed,
        shuffle_batches=True,
    )


flat_anchors: list[str] = []
flat_positives: list[str] = []

# `batches` is the pre-built list of 256-pair (anchor, positive) batches described earlier.
for batch_idx, batch in enumerate(batches):

    if len(batch) != BATCH_SIZE:
        raise ValueError(
            f"Batch {batch_idx} has {len(batch)} pairs, expected {BATCH_SIZE}."
        )

    for anchor, positive in batch:
        flat_anchors.append(anchor)
        flat_positives.append(positive)

train_dataset = Dataset.from_dict(
    {
        "anchor": flat_anchors,
        "positive": flat_positives,
    }
).select_columns(["anchor", "positive"])

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=(
        "query_to_doc",
        "query_to_query",
        "doc_to_query",
        "doc_to_doc",
    ),
    partition_mode="joint",
)

args = SentenceTransformerTrainingArguments(
    output_dir="models/address-mpnet-mnrl",
    per_device_train_batch_size=BATCH_SIZE,
    dataloader_drop_last=True,
    batch_sampler=exact_prebatched_sampler_factory,
    num_train_epochs=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    fp16=False,
    logging_steps=50,
    save_steps=500,
    save_total_limit=2,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss_fn,
)

trainer.train()

Verify the sampler before training

Before launching a long training job, inspect a few batches.

sampler = exact_prebatched_sampler_factory(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    drop_last=True,
    generator=torch.Generator().manual_seed(42),
    seed=42,
)

for batch_number, indices in zip(range(5), sampler):
    print(
        "batch_number:",
        batch_number,
        "first_index:",
        indices[0],
        "last_index:",
        indices[-1],
        "size:",
        len(indices),
    )

    print("first anchor:", train_dataset[indices[0]]["anchor"])
    print("first positive:", train_dataset[indices[0]]["positive"])
    print()

Expected shape:

batch_number: 0 first_index: 512 last_index: 767 size: 256
batch_number: 1 first_index: 0   last_index: 255 size: 256
batch_number: 2 first_index: 256 last_index: 511 size: 256

The order may differ because whole batches are shuffled, but each yielded batch should still be a contiguous block of 256 rows.
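
For a stronger guarantee, you can also check that one full pass over the sampler covers every row exactly once and that every batch is an aligned 256-row block. This sketch assumes the dataset length is an exact multiple of BATCH_SIZE, which the flattening step already enforces:

all_indices = [i for batch in sampler for i in batch]
assert sorted(all_indices) == list(range(len(train_dataset)))

assert all(
    batch[0] % BATCH_SIZE == 0 and len(batch) == BATCH_SIZE
    for batch in sampler
)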


Address-specific warning: false negatives

The main risk with MultipleNegativesRankingLoss is false negatives.

In MNRL, other positives in the same batch are treated as negatives for the current anchor. That is useful only if those other positives are truly different addresses.

This is good:

anchor:
    Blue Street, 1, New York

positive:
    Blue Street 1 - New York

in-batch negative:
    Blue Street 11 - New York

This is dangerous:

anchor:
    Blue Street, 1, New York

positive:
    Blue Street 1 - New York

in-batch negative from another row:
    1 Blue St., NYC

because 1 Blue St., NYC may be the same real address.

So your batch builder should enforce a rule like:

No two different rows in the same MNRL batch may refer to the same canonical address.

Exact string deduplication is not enough. Prefer deduplication by one or more of:

canonical address ID
delivery point ID
authoritative normalized address
geocoder result ID
parcel/building/unit ID
high-confidence rooftop coordinate

Be careful with all four directions

You used:

directions=(
    "query_to_doc",
    "query_to_query",
    "doc_to_query",
    "doc_to_doc",
)

That is supported by the current MultipleNegativesRankingLoss API and can provide a stronger signal.

However, it also makes batch cleanliness stricter.

With only query_to_doc, the main requirement is:

anchor_i should not match positive_j for i != j

With query_to_query, you also need:

anchor_i should not be equivalent to anchor_j

With doc_to_doc, you also need:

positive_i should not be equivalent to positive_j

For all four directions, validate:

anchor_i is not equivalent to anchor_j
positive_i is not equivalent to positive_j
anchor_i is not equivalent to positive_j for i != j
positive_i is not equivalent to anchor_j for i != j
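
Since both text columns of a row refer to the same canonical address, all four conditions collapse into a single per-batch rule: no two rows may share a canonical address. A minimal sketch of that check, assuming you keep a canonical_address_id per row during preprocessing (as in the validation section further below):

def assert_no_equivalent_rows(batch_rows: list[dict]) -> None:
    # One uniqueness check covers anchor/anchor, positive/positive,
    # and cross-column equivalences at once.
    canonical_ids = [row["canonical_address_id"] for row in batch_rows]
    if len(canonical_ids) != len(set(canonical_ids)):
        raise ValueError("Two rows in this batch refer to the same canonical address.")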

If your canonical-address checks are not strong yet, consider starting with a simpler loss configuration:

loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=("query_to_doc", "doc_to_query"),
    partition_mode="per_direction",
)

Then compare against the all-four-direction version.


Consider explicit hard negatives

Your current format is:

(anchor, positive)

That is valid for MNRL.

But if you already know specific hard negatives, you can also use:

(anchor, positive, negative_1, negative_2, negative_3)

MultipleNegativesRankingLoss supports pairs, triplets, and n-tuples:

train_dataset = Dataset.from_dict(
    {
        "anchor": [
            "Blue Street, 1, New York",
            "Blue Street, 11, New York",
        ],
        "positive": [
            "Blue Street 1 - New York",
            "Blue Street 11 - New York",
        ],
        "negative_1": [
            "Blue Street 11 - New York",
            "Blue Street 1 - New York",
        ],
        "negative_2": [
            "Blue Street 1 - Newark",
            "Blue Street 11 - Newark",
        ],
    }
).select_columns(["anchor", "positive", "negative_1", "negative_2"])

This gives the loss both:

explicit hard negatives attached to each row
+
in-batch hard negatives created by your curated batch

For a first implementation, I would start with (anchor, positive) and curated batches. After that works, add explicit hard-negative columns and compare.


Do not use gradient accumulation as a substitute for batch size

This is a common mistake.

For MNRL, the important thing is the contrastive batch size: the number of examples visible to the loss at the same time.

These are not equivalent:

per_device_train_batch_size = 32
gradient_accumulation_steps = 8

and:

per_device_train_batch_size = 256

The first setup may update the optimizer after 256 examples, but each MNRL softmax only sees 32 examples at once.

If batch size 256 does not fit in GPU memory, use CachedMultipleNegativesRankingLoss:

loss_fn = losses.CachedMultipleNegativesRankingLoss(
    model,
    mini_batch_size=32,
    directions=(
        "query_to_doc",
        "query_to_query",
        "doc_to_query",
        "doc_to_doc",
    ),
    partition_mode="joint",
)

Keep:

per_device_train_batch_size = 256

The distinction is:

per_device_train_batch_size = contrastive batch size
mini_batch_size             = internal memory chunk size

That matters because your curated group of 256 addresses is the semantic training unit.


Multi-GPU caution

MultipleNegativesRankingLoss has a gather_across_devices option. It can increase the effective negative pool across devices, but it also changes the effective contrastive batch.

For your case, exact batch composition matters. I would first validate everything on one GPU:

single GPU
per_device_train_batch_size = 256
gather_across_devices = False

Then move to distributed training only after you have verified what examples are actually visible to each loss computation.


Practical address-data advice

Good positives

Use formatting variants of the same real address:

Blue Street, 1, New York
Blue Street 1 - New York

1 Blue St., NYC
Blue Street 1, New York, NY

Apt 2, 1 Blue Street, New York
1 Blue St Apartment 2, NYC

Good hard negatives

Use addresses that differ in identity-critical components:

same street + same city + different house number
same street + same house number + different city
same building + different apartment/unit
same street + different postal code
same house number + similar street name
same city + changed street suffix
changed directional: North Main St vs South Main St

Examples:

Blue Street 1, New York
Blue Street 11, New York

Blue Street 1, New York
Blue Street 1, Newark

Blue Street 1 Apt 2, New York
Blue Street 1 Apt 20, New York

North Main Street 10
South Main Street 10

Good batch design

Build each batch as a “confusion neighborhood”:

Batch theme:
    same normalized street + same city

Rows:
    Blue Street 1, New York        ↔ 1 Blue St, NYC
    Blue Street 11, New York       ↔ 11 Blue St, NYC
    Blue Street 1 Apt 2, New York  ↔ Apt 2, 1 Blue St, NYC
    Blue Street 1 Apt 20, New York ↔ Apt 20, 1 Blue St, NYC
    Blue Avenue 1, New York        ↔ 1 Blue Ave, NYC
    Blue Street 1, Newark          ↔ 1 Blue St, Newark

This is much better than random batching, because random negatives are often too easy.
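
One way to build such batches, as a sketch: group candidate pairs by a normalization key and turn each group into a curated batch. Here candidate_pairs and normalize_street_city are hypothetical placeholders for your own data source and normalizer:

from collections import defaultdict

neighborhoods: dict[str, list[tuple[str, str]]] = defaultdict(list)

# Hypothetical: candidate_pairs yields (anchor, positive) pairs.
for anchor, positive in candidate_pairs:
    # Hypothetical helper: normalizes street name + city into a grouping key.
    key = normalize_street_city(anchor)
    neighborhoods[key].append((anchor, positive))

# Large neighborhoods become curated batches (after canonical-ID dedup);
# small ones can be merged with other hard-to-distinguish groups.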


Recommended preprocessing validation

Keep metadata before training:

training_rows = [
    {
        "batch_id": 0,
        "anchor": "Blue Street, 1, New York",
        "positive": "Blue Street 1 - New York",
        "canonical_address_id": "addr_001",
    },
    {
        "batch_id": 0,
        "anchor": "Blue Street, 11, New York",
        "positive": "Blue Street 11 - New York",
        "canonical_address_id": "addr_002",
    },
]

Validate each batch:

def validate_precomputed_batches(batches_with_ids: list[list[dict]]) -> None:
    for batch_idx, batch in enumerate(batches_with_ids):
        if len(batch) != BATCH_SIZE:
            raise ValueError(
                f"Batch {batch_idx} has {len(batch)} rows, expected {BATCH_SIZE}."
            )

        canonical_ids = [row["canonical_address_id"] for row in batch]

        if len(canonical_ids) != len(set(canonical_ids)):
            raise ValueError(
                f"Batch {batch_idx} contains duplicate canonical address IDs. "
                "This creates false negatives for MNRL."
            )

        for row_idx, row in enumerate(batch):
            if not row["anchor"] or not row["positive"]:
                raise ValueError(
                    f"Batch {batch_idx}, row {row_idx} has empty text."
                )

Then build the trainer dataset with only text columns:

train_dataset = Dataset.from_dict(
    {
        "anchor": [row["anchor"] for batch in batches_with_ids for row in batch],
        "positive": [row["positive"] for batch in batches_with_ids for row in batch],
    }
).select_columns(["anchor", "positive"])

This gives you metadata safety during preprocessing and a clean text-only dataset during training.


Evaluation: do not rely only on average cosine similarity

For address embeddings, generic similarity evaluation is not enough.

Use at least these four evaluation types.

1. Same-address invariance

Pairs that should be close:

Blue Street 1, New York
1 Blue St., NYC

Measure:

positive cosine distribution

2. Hard-negative separation

Pairs that should not be too close:

Blue Street 1, New York
Blue Street 11, New York

Slice by component type:

house number changed
city changed
unit changed
postal code changed
street suffix changed
directional changed

3. Triplet accuracy

Triplets:

anchor:
    Blue Street 1, New York

positive:
    1 Blue St., NYC

negative:
    Blue Street 11, New York

Measure:

score(anchor, positive) > score(anchor, negative)

Also measure the margin:

score(anchor, positive) - score(anchor, negative)

4. Retrieval

Corpus:

all canonical addresses

Query:

raw or noisy address variant

Measure:

Recall@1
Recall@5
Recall@10
MRR
nDCG

For a production address resolver, retrieval metrics are usually the most realistic. The main question is whether the correct canonical address appears in the top candidates.
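
sentence-transformers ships evaluators that map directly onto the triplet and retrieval checks. A minimal sketch (the address strings and IDs are illustrative):

from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    TripletEvaluator,
)

triplet_eval = TripletEvaluator(
    anchors=["Blue Street 1, New York"],
    positives=["1 Blue St., NYC"],
    negatives=["Blue Street 11, New York"],
    name="address-triplets",
)

ir_eval = InformationRetrievalEvaluator(
    queries={"q1": "1 Blue St., NYC"},
    corpus={
        "addr_001": "Blue Street 1, New York",
        "addr_002": "Blue Street 11, New York",
    },
    relevant_docs={"q1": {"addr_001"}},
    name="address-retrieval",
)

print(triplet_eval(model))
print(ir_eval(model))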


Production design recommendation

For high-precision address matching, I would not rely only on one embedding model.

Use a two-stage architecture:

raw address
    ↓
bi-encoder embedding model
    ↓
top 10 / top 50 canonical candidates
    ↓
cross-encoder or structured verifier
    ↓
same / different / uncertain

The bi-encoder is fast and good for candidate retrieval. A cross-encoder or verifier can compare two addresses jointly and pay close attention to exact differences:

1 vs 11
NYC vs Newark
Apt 2 vs Apt 20
Street vs Avenue
North Main vs South Main

For address matching, this second stage is often what prevents high-cost false positives.
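
The second stage can be prototyped with a CrossEncoder, which reads both addresses jointly. A sketch; the pretrained model name is a generic example, and in practice you would fine-tune one on labeled address pairs:

from sentence_transformers import CrossEncoder

verifier = CrossEncoder("cross-encoder/stsb-roberta-base")

scores = verifier.predict(
    [
        ("Blue Street 1, New York", "1 Blue St., NYC"),           # should score high
        ("Blue Street 1, New York", "Blue Street 11, New York"),  # should score low
    ]
)
print(scores)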


Practical experiment order

Experiment 1: curated batches with simple MNRL

loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=("query_to_doc",),
)

This is the simplest baseline.

Experiment 2: symmetric MNRL

loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=("query_to_doc", "doc_to_query"),
    partition_mode="per_direction",
)

This is a good next step because address equivalence is symmetric.

Experiment 3: all four directions

loss_fn = losses.MultipleNegativesRankingLoss(
    model,
    directions=(
        "query_to_doc",
        "query_to_query",
        "doc_to_query",
        "doc_to_doc",
    ),
    partition_mode="joint",
)

Use this only if your batch validation is strong and evaluation improves.

Experiment 4: explicit hard negatives

Use columns like:

anchor
positive
negative_1
negative_2
negative_3

Still preserve your curated batch structure.

Experiment 5: CachedMNRL

If memory is limiting, switch from MNRL to CachedMNRL while preserving per_device_train_batch_size=256.


Common mistakes to avoid

Mistake 1: putting batch_id into the training dataset

Keep metadata outside the trainer dataset unless you customize the collator/loss path.

Mistake 2: using random batches

Random negatives are mostly too easy. Your task needs hard negatives.

Mistake 3: trusting exact-string deduplication

Different strings can still be the same address.

Mistake 4: using gradient accumulation to simulate batch size 256

It does not create an MNRL negative pool of 256.

Mistake 5: training all four directions without strong false-negative checks

query_to_query and doc_to_doc make false negatives more damaging.


Final answer

Create a normal flattened datasets.Dataset with only the text columns needed by the loss. Preserve your handcrafted batch structure by row order. Then pass a custom DefaultBatchSampler or sampler factory through SentenceTransformerTrainingArguments(batch_sampler=...) so the trainer’s internal dataloader yields exactly your curated groups of 256 pairs.

In short:

Flatten your batches.
Keep rows ordered by curated batch.
Use a custom batch sampler.
Set per_device_train_batch_size=256.
Validate that no batch contains two equivalent canonical addresses.
Use CachedMNRL if memory is too tight.
Evaluate with retrieval and hard-negative address tests.

Thank you a lot for your answer, I’ll read it carefully and try to implement it.
Since it appears to be (at least partially) LLM-generated, I hope you don’t mind a couple of clarifying questions (I’m new to HF, and I’m not sure what the AI-generated content policy is here):

  1. Did you review the solution yourself? If so, are you familiar with the sentence-transformers library?

  2. Which LLM did you use to generate the answer? I’ve asked ChatGPT 5.4 several times about this issue, but it consistently suggests writing a custom training loop that bypasses SentenceTransformerTrainer and its additional features.

Thanks again!

Sorry if you’re the type of person who hates LLM responses…

I’m not sure what the AI-generated content policy is here

Me too actually…

Did you review the solution yourself?

I haven’t actually tested this myself this time. I just looked at past cases. When I’m unsure about an answer, I often run the code myself to verify it. But this time, not yet.

are you familiar with the sentence-transformers library?

I use it on a daily basis, but I’m not very familiar with fine-tuning SentenceTransformer models.

Which LLM did you use to generate the answer?

ChatGPT 5.5 (GPT-5.5 Thinking), plus a collection of simple knowledge Markdown files stored locally on my computer. Also, these aren’t answers I got in a single go.