TarNet

Description

The TarNet class wraps a treatment-effect estimation model that learns a shared representation (the deconfounder) and passes it to separate, treatment-specific outcome models. It combines the underlying TarNetBase model with data loading, training, validation, prediction, and evaluation, and it supports saving and loading model checkpoints.
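The following minimal NumPy sketch illustrates the architecture. It runs a TarNet-style forward pass: one shared representation layer followed by two outcome heads, one per treatment arm. It is not the TarNetBase implementation. The layer sizes, the ReLU activation, the random initialization, and the helper names `init_layer` and `tarnet_forward` are assumptions chosen for illustration.

```python
import numpy as np


def init_layer(rng, n_in, n_out):
    """Hypothetical helper: random weights and zero bias for one dense layer."""
    return rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out)


def tarnet_forward(x, shared, head0, head1, return_probability=False):
    """Sketch of a TarNet-style forward pass (not the TarNetBase code).

    x:            (n_samples, n_features) input representations
    shared:       (W, b) for the shared representation layer
    head0, head1: (W, b) for the control and treated outcome heads
    """
    W, b = shared
    phi = np.maximum(x @ W + b, 0.0)  # shared representation (deconfounder), ReLU
    y0 = phi @ head0[0] + head0[1]    # outcome head for T=0
    y1 = phi @ head1[0] + head1[1]    # outcome head for T=1
    if return_probability:
        # Squash outputs into (0, 1) with a sigmoid, as for a binary outcome
        y0 = 1.0 / (1.0 + np.exp(-y0))
        y1 = 1.0 / (1.0 + np.exp(-y1))
    return y0, y1, phi


rng = np.random.default_rng(0)
n_features, n_hidden = 16, 8
shared = init_layer(rng, n_features, n_hidden)
head0 = init_layer(rng, n_hidden, 1)
head1 = init_layer(rng, n_hidden, 1)

x = rng.normal(size=(5, n_features))
y0, y1, phi = tarnet_forward(x, shared, head0, head1)
print(y0.shape, y1.shape, phi.shape)  # (5, 1) (5, 1) (5, 8)
```

Both heads read the same representation `phi`, so every sample gets a prediction under each treatment. This is the property that the predict method's two outputs rely on.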

Parameters

  • epochs (int, optional): Number of training epochs (default: 200).

  • batch_size (int, optional): Batch size for training (default: 32).

  • learning_rate (float, optional): Learning rate for the optimizer (default: 2e-5).

  • architecture_y (list, optional): Layer sizes for each treatment-specific outcome model (default: [1]).

  • architecture_z (list, optional): Layer sizes for the shared representation (deconfounder) network (default: [1024]).

  • dropout (float, optional): Dropout rate (default: 0.3).

  • step_size (int, optional): Step size for the learning rate scheduler (default: None).

  • bn (bool, optional): Whether to use batch normalization (default: False).

  • patience (int, optional): Number of consecutive epochs without improvement in validation loss before training stops early (default: 5).

  • min_delta (float, optional): Minimum decrease in validation loss that counts as an improvement for early stopping (default: 0.01).

  • model_dir (str, optional): Directory to save model checkpoints (default: None).

  • return_probablity (bool, optional): If True, the model outputs probabilities (default: False).

  • verbose (bool, optional): If True, prints additional information during training (default: True).
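The parameters patience, min_delta, and step_size usually work together in a training loop. The pure-Python sketch below shows one common interpretation of these rules. A validation loss counts as an improvement only if it drops by more than min_delta. Training stops after patience epochs without an improvement. The learning rate decays every step_size epochs. The `EarlyStopper` class, the `step_lr` function, the decay factor `gamma=0.1`, and treating `step_size=None` as "no scheduling" are all assumptions, and the TarNet class may implement these rules differently.

```python
class EarlyStopper:
    """Hypothetical early-stopping helper driven by patience and min_delta."""

    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def should_stop(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            # Improvement is large enough: record it and reset the counter
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def step_lr(base_lr, epoch, step_size=None, gamma=0.1):
    """Step decay: multiply the learning rate by gamma every step_size epochs.

    step_size=None is treated here as 'no scheduling' (an assumption).
    """
    if step_size is None:
        return base_lr
    return base_lr * gamma ** (epoch // step_size)


# Simulated validation losses that plateau after epoch 2
val_losses = [1.0, 0.8, 0.6, 0.595, 0.594, 0.593, 0.592, 0.591, 0.59]
stopper = EarlyStopper(patience=5, min_delta=0.01)
stopped_epoch = None
for epoch, loss in enumerate(val_losses):
    lr = step_lr(2e-5, epoch, step_size=4)  # would be passed to the optimizer here
    if stopper.should_stop(loss):
        stopped_epoch = epoch
        break
print(stopped_epoch, stopper.best_loss)  # 7 0.6
```

In this run the losses after epoch 2 keep falling, but never by more than min_delta. As a result, five small "improvements" in a row still trigger the stop.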

Example Usage

from TarNet import TarNet

model = TarNet(
    epochs=100,
    batch_size=32,
    learning_rate=1e-4,
    architecture_y=[200, 1],
    architecture_z=[2048],
    dropout=0.2,
    bn=True,
    model_dir="./model_checkpoint"
)
# R: internal representations, Y: outcomes, T: binary treatment indicators (0/1)
model.fit(R, Y, T, valid_perc=0.2, plot_loss=True)
y0_preds, y1_preds, frs = model.predict(R)
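The example assumes that R, Y, and T already exist. To try it end to end without real data, you could generate synthetic inputs as shown below. The sample and feature counts, the (n_samples, 1) layout of Y and T, the float32 dtype, and the data-generating process are all assumptions, so check what TarNet.fit actually expects before relying on them.

```python
import numpy as np

rng = np.random.default_rng(42)
n_samples, n_features = 1000, 64

# Hypothetical internal representations, one row per sample
R = rng.normal(size=(n_samples, n_features)).astype(np.float32)

# Binary treatment whose probability depends on R (i.e. confounded assignment)
propensity = 1.0 / (1.0 + np.exp(-R[:, 0]))
T = rng.binomial(1, propensity).astype(np.float32).reshape(-1, 1)

# Outcome with a constant synthetic treatment effect of 2.0 plus Gaussian noise
noise = rng.normal(scale=0.5, size=n_samples)
Y = (R[:, 1] + 2.0 * T[:, 0] + noise).astype(np.float32).reshape(-1, 1)

print(R.shape, Y.shape, T.shape)  # (1000, 64) (1000, 1) (1000, 1)
```

Because the true effect is known by construction, synthetic data like this is a convenient sanity check for the predictions returned by predict.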

Methods

fit

Purpose and Description:

Trains the TarNet model on internal representations (R), outcomes (Y), and treatment indicators (T). It splits the data into training and validation sets according to valid_perc. It then trains with early stopping (controlled by patience and min_delta) and optional learning-rate scheduling (step_size). If plot_loss is True, it plots the training and validation loss curves.
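In the TarNet formulation, each training sample has only one observed (factual) outcome. As a result, only the head that matches the sample's treatment receives a loss. The minimal NumPy sketch below implements a factual mean-squared-error loss. This document does not say whether this class uses MSE, a weighted loss, or binary cross-entropy (for example when return_probablity is True), so treat the loss choice and the `factual_mse` name as assumptions.

```python
import numpy as np


def factual_mse(y0_pred, y1_pred, y, t):
    """MSE on observed (factual) outcomes only.

    Treated samples (t=1) are scored with the y1 head,
    control samples (t=0) with the y0 head.
    """
    y_pred = t * y1_pred + (1 - t) * y0_pred  # select the head matching the observed treatment
    return float(np.mean((y_pred - y) ** 2))


y = np.array([1.0, 3.0, 2.0])
t = np.array([0.0, 1.0, 1.0])
y0_pred = np.array([2.0, 0.0, 0.0])  # only sample 0 is factual for this head
y1_pred = np.array([9.0, 4.0, 4.0])  # samples 1 and 2 are factual for this head
print(factual_mse(y0_pred, y1_pred, y, t))  # 2.0
```

The counterfactual predictions (9.0 and the two zeros) do not affect the loss. Those outputs are never directly supervised; they are inferred through the shared representation.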

Arguments:
  • R (np.ndarray or torch.Tensor): Internal representations for all samples.

  • Y (np.ndarray or torch.Tensor): Outcome values.

  • T (np.ndarray or torch.Tensor): Binary treatment indicators (0 = control, 1 = treated).

  • valid_perc (float, optional): Fraction of the data used for validation.

  • plot_loss (bool, optional): If True, plots the loss curves (default: True).
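The valid_perc argument controls how much of the data is held out for validation. A shuffled split like the one below is a common way to do this. The random seed, the rounding of the validation size, and the `train_valid_split` helper are assumptions, not necessarily how fit performs the split.

```python
import numpy as np


def train_valid_split(R, Y, T, valid_perc=0.2, seed=0):
    """Shuffle sample indices and hold out a valid_perc fraction for validation."""
    n = len(R)
    idx = np.random.default_rng(seed).permutation(n)
    n_valid = int(round(n * valid_perc))
    valid_idx, train_idx = idx[:n_valid], idx[n_valid:]
    # Use the same indices for R, Y and T so that rows stay aligned
    train = (R[train_idx], Y[train_idx], T[train_idx])
    valid = (R[valid_idx], Y[valid_idx], T[valid_idx])
    return train, valid


R = np.arange(20).reshape(10, 2)  # row i is [2i, 2i + 1]
Y = np.arange(10)                 # Y[i] = i
T = np.array([0, 1] * 5)          # T[i] = i % 2
(train_R, train_Y, train_T), (valid_R, valid_Y, valid_T) = train_valid_split(
    R, Y, T, valid_perc=0.2
)
print(len(train_R), len(valid_R))  # 8 2
```

Splitting indices once and applying them to all three arrays keeps each sample's representation, outcome, and treatment together.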

Example:

model = TarNet(epochs=50)
model.fit(R, Y, T, valid_perc=0.2, plot_loss=True)

predict

Purpose and Description:

Runs the model over the internal representation data in batches. For every sample, it returns the predicted outcome under control (T=0), the predicted outcome under treatment (T=1), and the latent representation (deconfounder) extracted by the model.
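Batching keeps memory use bounded when r is large. The sketch below shows the general pattern: run a model over fixed-size batches, then concatenate its three outputs. The names `predict_in_batches` and `toy_model`, the `model_fn` callable interface, and the default batch size are hypothetical stand-ins for the class's internals.

```python
import numpy as np


def predict_in_batches(model_fn, r, batch_size=32):
    """Apply model_fn to r in batches and concatenate its (y0, y1, phi) outputs."""
    y0_parts, y1_parts, phi_parts = [], [], []
    for start in range(0, len(r), batch_size):
        batch = r[start:start + batch_size]
        y0, y1, phi = model_fn(batch)
        y0_parts.append(y0)
        y1_parts.append(y1)
        phi_parts.append(phi)
    return (np.concatenate(y0_parts), np.concatenate(y1_parts),
            np.concatenate(phi_parts))


def toy_model(batch):
    """Stand-in model: y0 = row sum, y1 = row sum + 1, phi = the input itself."""
    s = batch.sum(axis=1, keepdims=True)
    return s, s + 1.0, batch


r = np.ones((70, 3))
y0, y1, phi = predict_in_batches(toy_model, r, batch_size=32)
print(y0.shape, y1.shape, phi.shape)  # (70, 1) (70, 1) (70, 3)
```

The last batch here holds only 6 rows, and concatenation handles the uneven size. A PyTorch implementation would typically also call `model.eval()` and wrap the loop in `torch.no_grad()`, so that dropout is disabled and no gradients are tracked.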

Arguments:
  • r (np.ndarray or torch.Tensor): Internal representation data for prediction.

Returns:
  • y0_preds (torch.Tensor): Predicted outcome under control (T=0) for each sample.

  • y1_preds (torch.Tensor): Predicted outcome under treatment (T=1) for each sample.

  • frs (torch.Tensor): Deconfounder (shared latent representation) extracted by the model for each sample.

Example:

y0, y1, fr = model.predict(R)
print("Control predictions:", y0)
print("Treated predictions:", y1)
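Because predict returns both potential outcomes for every sample, subtracting them gives per-sample (individual) treatment effect estimates. Averaging those gives an average treatment effect (ATE) estimate. The sketch below assumes the predictions have already been converted to NumPy arrays, for example with `.detach().cpu().numpy()` on the returned tensors. `estimate_effects` is a hypothetical helper, and the numbers are toy values.

```python
import numpy as np


def estimate_effects(y0, y1):
    """Per-sample effect estimates (y1 - y0) and their mean (ATE estimate)."""
    ite = np.asarray(y1, dtype=float).ravel() - np.asarray(y0, dtype=float).ravel()
    return ite, float(ite.mean())


y0 = np.array([[1.0], [2.0], [3.0], [4.0]])  # toy control predictions
y1 = np.array([[2.0], [2.5], [5.0], [4.0]])  # toy treated predictions
ite, ate = estimate_effects(y0, y1)
print(ite, ate)  # per-sample effects 1.0, 0.5, 2.0, 0.0 and ATE 0.875
```

When the model outputs probabilities (return_probablity=True), the same subtraction yields a difference in predicted probabilities rather than in raw outcome values.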