TarNet_loss

Description

The TarNet_loss function calculates the loss for the TarNet model.

Arguments

  • y_true (torch.Tensor): The true outcome values. For categorical outcomes, these are class indices.

  • t_true (torch.Tensor): The true treatment indicators (0 or 1).

  • y0_pred (torch.Tensor): The predicted outcomes for the control group.

  • y1_pred (torch.Tensor): The predicted outcomes for the treated group.

Returns

  • loss (torch.Tensor): A scalar tensor representing the combined loss.

Example Usage

from TNutil import TarNet_loss

loss = TarNet_loss(y_true, t_true, y0_pred, y1_pred)
print("TarNet Loss:", loss.item())