
Binary Cross-Entropy Loss

BinLoss

Bases: Module

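Binary cross-entropy loss between hard and soft attention. Only the positive term is computed: the loss is the negative mean log-probability that the soft attention assigns to the positions where the hard attention equals 1.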

Attributes: none.

Methods: forward, which computes the binary cross-entropy loss for hard and soft attention.

Source code in training/loss/bin_loss.py
import torch
from torch.nn import Module


class BinLoss(Module):
    r"""Binary cross-entropy loss for hard and soft attention.

    Attributes
        None

    Methods
        forward: Computes the binary cross-entropy loss for hard and soft attention.

    """

    def __init__(self):
        super().__init__()

    def forward(
        self, hard_attention: torch.Tensor, soft_attention: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computes the binary cross-entropy loss for hard and soft attention.

        Args:
            hard_attention (torch.Tensor): A binary tensor indicating the hard attention.
            soft_attention (torch.Tensor): A tensor containing the soft attention probabilities.

        Returns:
            torch.Tensor: The binary cross-entropy loss.

        """
        log_sum = torch.log(
            torch.clamp(soft_attention[hard_attention == 1], min=1e-12),
        ).sum()
        return -log_sum / hard_attention.sum()
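A minimal usage sketch. Because this page does not show the package's import path, the sketch restates forward as a standalone function (the name bin_loss is hypothetical) and applies it to a toy 2x3 alignment; the real attention shapes depend on the caller and are not specified here. It also checks the result against torch.nn.functional.binary_cross_entropy evaluated only on the entries where the hard attention is 1, which is the sense in which this loss is a binary cross-entropy.

```python
import torch
import torch.nn.functional as F


def bin_loss(hard_attention: torch.Tensor, soft_attention: torch.Tensor) -> torch.Tensor:
    """Hypothetical standalone copy of BinLoss.forward, for illustration only."""
    # Keep only the soft probabilities at positions the hard alignment selects.
    selected = soft_attention[hard_attention == 1]
    # Clamp before the log so a zero probability gives a large finite penalty.
    log_sum = torch.log(torch.clamp(selected, min=1e-12)).sum()
    # Average over the number of selected positions.
    return -log_sum / hard_attention.sum()


# Toy alignment with an arbitrary 2x3 shape.
soft = torch.tensor([[0.9, 0.2, 0.1],
                     [0.1, 0.8, 0.9]])
hard = torch.tensor([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 1.0]])

loss = bin_loss(hard, soft)

# PyTorch's BCE restricted to the positive (hard == 1) entries, with target 1.
mask = hard == 1
reference = F.binary_cross_entropy(
    soft[mask], torch.ones_like(soft[mask]), reduction="mean"
)
print(loss.item(), reference.item())
```

Both values are the mean of -log(0.9), -log(0.8) and -log(0.9), about 0.1446; the 0.2 and 0.1 entries do not contribute because the hard attention does not select them.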

forward(hard_attention, soft_attention)

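Computes the binary cross-entropy loss for hard and soft attention: the negative log of the soft-attention probability at each position where hard_attention == 1, averaged over those positions. Probabilities are clamped to a minimum of 1e-12 before the logarithm, so a zero probability yields a large finite penalty instead of infinity.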
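Written out, with h and s denoting the flattened hard_attention and soft_attention tensors indexed by k, and epsilon = 1e-12 the clamp floor (these symbols are introduced here for illustration), forward computes:

```latex
\mathcal{L}_{\text{bin}}
  = -\frac{1}{\sum_{k} h_k} \sum_{k \,:\, h_k = 1} \log \max\left(s_k, \varepsilon\right),
  \qquad \varepsilon = 10^{-12}
```

For comparison, the standard binary cross-entropy with target h, -(1/n) sum_k [h_k log s_k + (1 - h_k) log(1 - s_k)], also penalizes probability mass on unselected positions and averages over all n elements; this loss keeps only the positive term and averages over the selected positions.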

Parameters:

Name            Type    Description                                                      Default
hard_attention  Tensor  Binary (0/1) hard-attention mask, same shape as soft_attention.  required
soft_attention  Tensor  A tensor containing the soft attention probabilities.            required

Returns:

Type    Description
Tensor  A scalar (0-dim) tensor holding the binary cross-entropy loss.
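Two edge cases follow directly from the code above. The sketch below restates forward as a hypothetical standalone function bin_loss (not part of the documented module) and shows that a selected position with zero probability is capped at -log(1e-12), about 27.63, while a hard attention with no selected positions divides an empty sum (0) by 0 and returns NaN.

```python
import torch


def bin_loss(hard_attention: torch.Tensor, soft_attention: torch.Tensor) -> torch.Tensor:
    # Hypothetical restatement of BinLoss.forward as a plain function.
    log_sum = torch.log(
        torch.clamp(soft_attention[hard_attention == 1], min=1e-12),
    ).sum()
    return -log_sum / hard_attention.sum()


soft = torch.tensor([[0.0, 1.0]])

# 1) Only the first position is selected and its probability is 0:
#    the clamp turns log(0) into log(1e-12), so the loss is about 27.63, not inf.
zero_prob = bin_loss(torch.tensor([[1.0, 0.0]]), soft)

# 2) Nothing is selected: the empty sum is 0 and the denominator is 0, giving NaN.
empty = bin_loss(torch.tensor([[0.0, 0.0]]), soft)

print(zero_prob.item(), torch.isnan(empty).item())
```

This page does not say whether callers guarantee at least one selected position; if they might not, the NaN case is worth guarding against before the loss is backpropagated.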
