SpectralNormClassifier
Description
The SpectralNormClassifier class implements a feed-forward neural network for multi-class classification with spectral normalization. Setting num_classes to 2 gives a binary classifier. Spectral normalization is applied to each linear layer: its weight matrix is rescaled by an estimate of its largest singular value, which constrains the Lipschitz constant of the network and improves training stability. The architecture is a multi-layer perceptron (MLP) whose hidden layers can optionally include batch normalization and dropout.
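To illustrate what spectral normalization does to one weight matrix, here is a minimal NumPy sketch of the power-iteration scheme. It estimates the largest singular value σ(W) in n_power_iterations steps and returns W / σ(W), so the normalized linear map has spectral norm of about 1 (that is, it is approximately 1-Lipschitz in the Euclidean norm). The helper `spectral_normalize` is hypothetical. The sketch mirrors the scheme used by PyTorch's `torch.nn.utils.spectral_norm`, but the source does not say exactly which routine SpectralNormClassifier uses internally.

```python
import numpy as np

def spectral_normalize(W, n_power_iterations=1, u=None, eps=1e-12, seed=0):
    """Hypothetical helper: estimate the spectral norm of W by power iteration.

    Returns (W / sigma, u, sigma). `u` approximates the leading left singular
    vector; spectral-norm layers keep it between training steps, so the
    returned `u` can be passed back in on the next call.
    """
    if n_power_iterations < 1:
        raise ValueError("n_power_iterations must be at least 1")
    if u is None:
        u = np.random.default_rng(seed).normal(size=W.shape[0])
        u /= np.linalg.norm(u) + eps
    for _ in range(n_power_iterations):
        v = W.T @ u                      # update the right singular-vector estimate
        v /= np.linalg.norm(v) + eps
        u = W @ v                        # update the left singular-vector estimate
        u /= np.linalg.norm(u) + eps
    sigma = u @ W @ v                    # estimate of the largest singular value
    return W / sigma, u, sigma

# A 5 x 3 matrix whose largest singular value is 4
W = np.vstack([np.diag([4.0, 2.0, 1.0]), np.zeros((2, 3))])
W_sn, u, sigma = spectral_normalize(W, n_power_iterations=30)
```

With a single iteration (the default n_power_iterations=1) the estimate can undershoot σ(W). In training this is usually acceptable because `u` is carried over between steps and keeps improving.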
Parameters
input_dim (int): Number of input features (the second dimension of X).
hidden_sizes (list of int, optional): Sizes of the hidden layers. Defaults to [128, 64].
num_classes (int, optional): Number of output classes. Defaults to 2 (binary classification).
n_power_iterations (int, optional): Number of power iterations used to estimate the spectral norm of each layer's weight matrix. Defaults to 1.
dropout (float, optional): Dropout probability applied in each hidden layer. If 0.0, no dropout is applied. Defaults to 0.0.
batch_norm (bool, optional): Whether to add a batch normalization layer after each hidden linear layer. Defaults to False.
lr (float, optional): Learning rate for the Adam optimizer. Defaults to 2e-6.
nepoch (int, optional): Maximum number of training epochs. Defaults to 20.
batch_size (int, optional): Mini-batch size used during training. Defaults to 32.
patience (int, optional): Number of consecutive epochs without sufficient improvement in validation loss before training stops early. Defaults to 5.
min_delta (float, optional): Minimum decrease in validation loss that counts as an improvement and resets the patience counter. Defaults to 1e-4.
use_scheduler (bool, optional): Whether to use a learning-rate scheduler; the scheduler is chosen with scheduler_type. Defaults to False.
scheduler_type (str, optional): Scheduler type, either "StepLR" or "ReduceLROnPlateau". Defaults to "ReduceLROnPlateau".
step_size (int, optional): Step size for the scheduler; its interpretation depends on scheduler_type. Defaults to 5.
gamma (float, optional): Multiplicative factor by which the scheduler decays the learning rate. Defaults to 0.5.
valid_perc (float, optional): Fraction of the data held out as a validation set. Defaults to 0.2.
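To make lr, gamma and step_size concrete, the following sketch computes the per-epoch learning rate under StepLR-style decay, where the rate is multiplied by gamma once every step_size epochs. This rests on two assumptions. First, for scheduler_type="StepLR", step_size means the number of epochs between decays, which is what PyTorch's `torch.optim.lr_scheduler.StepLR` uses it for. Second, the scheduler is stepped once per epoch. How step_size is used with "ReduceLROnPlateau" is not specified here, so that case is not modeled. The function `step_lr_schedule` is hypothetical.

```python
def step_lr_schedule(lr=2e-6, gamma=0.5, step_size=5, nepoch=20):
    """Hypothetical helper: learning rate used at each epoch under StepLR-style decay.

    lr_epoch = lr * gamma ** (epoch // step_size)
    """
    return [lr * gamma ** (epoch // step_size) for epoch in range(nepoch)]

# With the documented defaults (and use_scheduler=True, scheduler_type="StepLR"):
# epochs 0-4 use 2e-6, 5-9 use 1e-6, 10-14 use 5e-7, 15-19 use 2.5e-7
schedule = step_lr_schedule()
```

With the default use_scheduler=False, the learning rate stays at lr for every epoch.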
Example Usage
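import numpy as np
from gpi_pack.TNutil import SpectralNormClassifier
# Placeholder data for illustration only: replace with your own deconfounder and treatment
rng = np.random.default_rng(0)
deconfounder = rng.normal(size=(500, 20)).astype(np.float32)  # shape [n_samples, input_dim]
T = rng.integers(0, 2, size=500)  # integer class labels in {0, 1}
# Initialize the classifier
model = SpectralNormClassifier(input_dim=20, hidden_sizes=[64, 32], num_classes=2)
# Train the classifier
model.fit(deconfounder, T)
# Predict class probabilities
probs = model.predict_proba(deconfounder)
print("Predicted probabilities:", probs)
# Predict hard classes
predictions = model.predict(deconfounder)
print("Predicted classes:", predictions)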
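The example assumes T already holds integer labels 0, ..., num_classes-1, which is the form fit expects. If your labels are strings or arbitrary integers, one way to convert them is `numpy.unique` with `return_inverse=True`. The sketch below shows this. The helper `encode_labels` and the sample labels are hypothetical.

```python
import numpy as np

def encode_labels(y):
    """Hypothetical helper: map arbitrary labels to integers 0..K-1.

    Returns the integer labels and the sorted original classes, so that
    classes[k] is the original label encoded as k.
    """
    classes, y_int = np.unique(np.asarray(y), return_inverse=True)
    return y_int.reshape(-1), classes

T_raw = np.array(["control", "treated", "control", "treated"])  # illustrative labels
T, classes = encode_labels(T_raw)   # T is [0, 1, 0, 1]
num_classes = len(classes)          # value to pass as num_classes
```

Keeping `classes` lets you map predicted integer labels back to the original names with `classes[predictions]`.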
Methods
fit
- Purpose and Description:
Trains the SpectralNormClassifier on the given dataset. It automatically splits the data into training and validation sets (holding out a valid_perc fraction for validation), trains for up to nepoch epochs, stops early when the validation loss fails to improve by at least min_delta for patience consecutive epochs, and optionally applies a learning-rate scheduler (see use_scheduler).
- Arguments:
X (np.ndarray): Input data of shape [n_samples, input_dim].
y (np.ndarray): Target class labels of shape [n_samples], with integer values in [0, num_classes-1].
Example:
# Training the classifier
model.fit(deconfounder_train, T_train)
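The fit description combines two mechanisms: a random train/validation split controlled by valid_perc, and early stopping controlled by patience and min_delta. Below is a minimal NumPy sketch of both. `train_valid_split` and `early_stopping_epoch` are hypothetical helpers, not functions of gpi_pack. The real split inside fit may differ, for example in seeding or stratification, and so may the exact comparison it uses for "improvement".

```python
import numpy as np

def train_valid_split(X, y, valid_perc=0.2, seed=0):
    """Hypothetical helper: randomly hold out a valid_perc fraction of rows for validation."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_valid = int(round(valid_perc * len(X)))
    valid_idx, train_idx = idx[:n_valid], idx[n_valid:]
    return X[train_idx], y[train_idx], X[valid_idx], y[valid_idx]

def early_stopping_epoch(val_losses, patience=5, min_delta=1e-4):
    """Hypothetical helper: return (stop_epoch, best_epoch) for per-epoch validation losses.

    An epoch counts as an improvement only if its loss is below the best loss so far
    by more than min_delta; training stops once `patience` epochs in a row fail to improve.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    # All epochs used (nepoch reached) without triggering early stopping
    return len(val_losses) - 1, best_epoch

# Epoch 3 improves by only 5e-5 < min_delta, so it does not reset the counter
losses = [1.0, 0.8, 0.7, 0.69995, 0.71, 0.72, 0.73, 0.74, 0.5]
stop_epoch, best_epoch = early_stopping_epoch(losses, patience=5, min_delta=1e-4)
```

Here training stops at epoch 7 with the best validation loss at epoch 2, so the later drop to 0.5 is never seen.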
predict_proba
- Purpose and Description:
Computes the class probabilities for each sample by performing a forward pass and then applying softmax to the output logits.
- Arguments:
X (np.ndarray): Input data of shape [n_samples, input_dim].
- Returns:
probs (np.ndarray): Array of shape [n_samples, num_classes] containing the predicted class probabilities.
Example:
probs = model.predict_proba(deconfounder_test)
print("Predicted probabilities:", probs)
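The softmax step that turns the network's output logits (shape [n_samples, num_classes]) into probabilities can be sketched in NumPy as follows. This is an illustration of the operation described above, not the class's own code. The helper name `softmax` is hypothetical. Subtracting each row's maximum before exponentiating leaves the result unchanged but prevents overflow for large logits.

```python
import numpy as np

def softmax(logits):
    """Hypothetical helper: row-wise softmax over [n_samples, num_classes] logits."""
    logits = np.asarray(logits, dtype=float)
    # Shift each row by its maximum: mathematically a no-op, but keeps exp() from overflowing
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

probs = softmax([[0.0, 0.0], [1000.0, 0.0], [1.0, 2.0]])
# Each row sums to 1; a naive exp(1000) would overflow to inf
```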
predict
- Purpose and Description:
Provides hard class predictions by selecting the class with the highest predicted probability for each sample.
- Arguments:
X (np.ndarray): Input data of shape [n_samples, input_dim].
- Returns:
predictions (np.ndarray): Array of shape [n_samples] with predicted class labels (ranging from 0 to num_classes-1).
Example:
predictions = model.predict(deconfounder_test)
print("Predicted classes:", predictions)
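Hard predictions follow from the probability matrix by taking the row-wise argmax, as in the NumPy sketch below. `predict_from_proba` and the probability values are hypothetical. Note that `numpy.argmax` returns the first maximal index, so an exact tie goes to the lower class label. For binary classification (num_classes=2), column 1 of the probability matrix gives the probability of class 1.

```python
import numpy as np

def predict_from_proba(probs):
    """Hypothetical helper: index of the most probable class per row (ties go to the lowest index)."""
    return np.asarray(probs).argmax(axis=1)

probs = np.array([[0.2, 0.8],
                  [0.6, 0.4],
                  [0.5, 0.5]])       # illustrative predict_proba output
labels = predict_from_proba(probs)   # array([1, 0, 0])
p_class1 = probs[:, 1]               # probability of class 1 when num_classes == 2
```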