Skip to content

BIOEncodingF1Score

npfl138.metrics.BIOEncodingF1Score

Bases: Module, Metric

Metric for evaluating F1 score of BIO-encoded spans.

The metric employs a simple heuristic to handle invalid sequences of BIO tags. Notably:

  • If an I tag is not preceded by a B or I tag, it is treated as a B tag.
  • If the type of an I tag does not match the type of the preceding tag, the type of this I tag is ignored (i.e., it is considered the same as the preceding tag's type).
Source code in npfl138/metrics/bio_encoding_f1_score.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class BIOEncodingF1Score(torch.nn.Module, Metric):
    """Metric for evaluating F1 score of BIO-encoded spans.

    A span is identified by its start offset, its (exclusive) end offset, and its
    type (the label without its first `B`/`I` character). Offsets count only the
    positions whose gold tag differs from `ignore_index`. A predicted span is a true
    positive only when an identical gold span exists.

    The metric employs a simple heuristic to handle invalid sequences of BIO tags.
    Notably:

    - If there is an `I` tag without preceding `B/I` tag, it is considered a `B` tag.
    - If the type of an `I` tag does not match the type of the preceding tag, the type
      of this `I` tag is ignored (i.e., considered the same as the preceding tag type).
    """
    def __init__(self, labels: list[str], ignore_index: int) -> None:
        """Construct a new BIOEncodingF1Score metric.

        Parameters:
          labels: The list of BIO-encoded labels, indexed by tag id.
          ignore_index: The gold index to ignore when computing the F1 score.
        """
        super().__init__()
        # persistent=False keeps the counters out of the module's state dict.
        self.register_buffer("tp", torch.tensor(0, dtype=torch.int64), persistent=False)
        self.register_buffer("fp", torch.tensor(0, dtype=torch.int64), persistent=False)
        self.register_buffer("fn", torch.tensor(0, dtype=torch.int64), persistent=False)
        self._labels = labels
        self._ignore_index = ignore_index

    def reset(self) -> Self:
        """Reset the metric to its initial state.

        Returns:
          self
        """
        self.tp.zero_()
        self.fp.zero_()
        self.fn.zero_()
        return self

    def update(self, pred: torch.Tensor, true: torch.Tensor) -> Self:
        """Update the metric with new predictions and targets.

        Parameters:
          pred: The predicted tag ids; must have the same shape as `true`.
          true: The gold tag ids. Positions equal to `ignore_index` are skipped in
            both the gold and the predicted sequences (they neither belong to a
            span nor advance the span offsets).

        Returns:
          self

        Raises:
          ValueError: If `pred` and `true` have different shapes.
        """
        if pred.shape != true.shape:
            raise ValueError(
                f"The shapes of pred {tuple(pred.shape)} and true {tuple(true.shape)} must match.")
        ignore = self._ignore_index
        # Append one ignored position to the last dimension of every sequence, so a
        # span still open at the end of one sequence is closed before the next one.
        # Converting to Python ints once avoids per-element 0-d tensor iteration.
        true = torch.nn.functional.pad(true, (0, 1), value=ignore).view(-1).tolist()
        pred = torch.nn.functional.pad(pred, (0, 1), value=ignore).view(-1).tolist()
        spans_pred, spans_true = set(), set()
        for spans, tags in ((spans_true, true), (spans_pred, pred)):
            span, start, offset = None, None, 0
            for tag, gold in zip(tags, true):
                # The gold mask decides which positions are ignored in BOTH passes,
                # so predicted span coordinates stay aligned with the gold ones even
                # when the model predicts real labels on gold-padded positions.
                label = "O" if gold == ignore or tag == ignore else self._labels[tag]
                # `is None` checks: an untyped label such as "B" has span type "".
                if span is not None and label.startswith(("O", "B")):
                    spans.add((start, offset, span))
                    span = None
                # An `I` with no open span starts one (treated as `B`); an `I` inside
                # an open span just extends it, whatever its own type.
                if span is None and label.startswith(("B", "I")):
                    span, start = label[1:], offset
                if gold != ignore:
                    offset += 1
        self.tp.add_(len(spans_pred & spans_true))
        self.fp.add_(len(spans_pred - spans_true))
        self.fn.add_(len(spans_true - spans_pred))
        return self

    def compute(self) -> torch.Tensor:
        """Compute the F1 score; returns 0 when no spans were seen at all."""
        return 2 * self.tp / torch.max(2 * self.tp + self.fp + self.fn, torch.ones_like(self.tp))

__init__

__init__(labels: list[str], ignore_index: int) -> None

Construct a new BIOEncodingF1Score metric.

Parameters:

  • labels (list[str]) –

    The list of BIO-encoded labels.

  • ignore_index (int) –

    The gold index to ignore when computing the F1 score.

Source code in npfl138/metrics/bio_encoding_f1_score.py
23
24
25
26
27
28
29
30
31
32
33
34
35
def __init__(self, labels: list[str], ignore_index: int) -> None:
    """Create the metric with all span counters set to zero.

    Parameters:
      labels: The BIO-encoded label names, indexed by tag id.
      ignore_index: The gold tag id whose positions are excluded from the F1 score.
    """
    super().__init__()
    # One int64 counter per outcome; persistent=False keeps them out of the state dict.
    for counter_name in ("tp", "fp", "fn"):
        self.register_buffer(counter_name, torch.tensor(0, dtype=torch.int64), persistent=False)
    self._labels = labels
    self._ignore_index = ignore_index

reset

reset() -> Self

Reset the metric to its initial state.

Returns:

  • self

Source code in npfl138/metrics/bio_encoding_f1_score.py
37
38
39
40
41
42
43
44
45
46
def reset(self) -> Self:
    """Reset the metric to its initial state.

    Returns:
      self
    """
    self.tp.zero_()
    self.fp.zero_()
    self.fn.zero_()
    return self

update

update(pred: Tensor, true: Tensor) -> Self

Update the metric with new predictions and targets.

Returns:

  • self

Source code in npfl138/metrics/bio_encoding_f1_score.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def update(self, pred: torch.Tensor, true: torch.Tensor) -> Self:
    """Update the metric with new predictions and targets.

    Returns:
      self
    """
    true = torch.nn.functional.pad(true, (0, 1), value=self._ignore_index).view(-1)
    pred = torch.nn.functional.pad(pred, (0, 1), value=self._ignore_index).view(-1)
    spans_pred, spans_true = set(), set()
    for spans, tags in [(spans_true, true), (spans_pred, pred)]:
        span, offset, start = None, 0, None
        for tag in tags:
            label = self._labels[tag] if tag != self._ignore_index else "O"
            if span and label.startswith(("O", "B")):
                spans.add((start, offset, span))
                span = None
            if not span and label.startswith(("B", "I")):
                span, start = label[1:], offset
            if tag != self._ignore_index:
                offset += 1
    self.tp.add_(len(spans_pred & spans_true))
    self.fp.add_(len(spans_pred - spans_true))
    self.fn.add_(len(spans_true - spans_pred))
    return self

compute

compute() -> Tensor

Compute the F1 score.

Source code in npfl138/metrics/bio_encoding_f1_score.py
73
74
75
def compute(self) -> torch.Tensor:
    """Compute the F1 score from the accumulated span counts.

    The denominator is clamped to at least 1, so the result is 0 when no spans
    were seen at all.
    """
    doubled_tp = 2 * self.tp
    denominator = torch.maximum(doubled_tp + self.fp + self.fn, torch.ones_like(self.tp))
    return doubled_tp / denominator