load_hiddens

Description

The load_hiddens function loads hidden-state tensors from disk. Given a directory and a list of file identifiers (without the .pt extension), it reads each tensor file, optionally prepending a prefix to the file name, stacks the tensors, and returns them as a single tensor of hidden representations.
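The behaviour described above can be sketched roughly as follows. This is a hypothetical re-implementation for illustration, not the gpi_pack source: the name load_hiddens_sketch, the file-name pattern prefix + identifier + ".pt", and the use of torch.stack along a new first dimension are all assumptions. The demo writes three dummy tensors of an arbitrary size to a temporary directory and loads them back.

```python
import os
import tempfile

import torch


def load_hiddens_sketch(directory, hidden_list, prefix=None, device="cpu"):
    """Hypothetical stand-in for load_hiddens (assumed behaviour, not the real code)."""
    tensors = []
    for item in hidden_list:
        # Assumed naming scheme: optional prefix, then the identifier, then ".pt".
        file_name = f"{prefix or ''}{item}.pt"
        path = os.path.join(directory, file_name)
        # map_location places the loaded tensor on the requested device.
        tensors.append(torch.load(path, map_location=device))
    # torch.stack requires every tensor to have the same shape.
    return torch.stack(tensors)


# Round trip with dummy data: three vectors of length 4 (an arbitrary size).
with tempfile.TemporaryDirectory() as tmp_dir:
    for i in (1, 2, 3):
        dummy = torch.full((4,), float(i))
        torch.save(dummy, os.path.join(tmp_dir, f"hidden_last_{i}.pt"))
    stacked = load_hiddens_sketch(tmp_dir, [1, 2, 3], prefix="hidden_last_")

print(stacked.shape)
```

Under these assumptions, N files that each hold a tensor of shape S yield a result of shape (N, *S); the actual function may arrange dimensions differently.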

Arguments

  • directory (str): The directory where the hidden state files are stored.

  • hidden_list (list): List of file identifiers to load, given without the .pt extension (and without the prefix, if one is supplied), for example [1, 2, 3].

  • prefix (str, optional): A prefix prepended to each identifier to form the file name, for example "hidden_last_" (default: None).

  • device (torch.device, optional): The device on which to load the tensors (default: "cpu").
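As a small illustration of the device argument, the snippet below builds a torch.device that prefers a GPU when one is available. Only standard PyTorch calls are executed; the load_hiddens call itself is left as a comment because it needs gpi_pack and real files on disk, and the paths in it are placeholders.

```python
import torch

# Use a GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical call (requires gpi_pack and existing .pt files):
# hidden_states = load_hiddens(
#     directory="./hidden_states",
#     hidden_list=[1, 2, 3],
#     prefix="hidden_last_",
#     device=device,
# )

print(device.type)
```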

Returns

  • tensors (torch.Tensor): A tensor containing the stacked hidden representations.

Example Usage

from gpi_pack.TNutil import load_hiddens

# Assume hidden files are named like "hidden_last_1.pt", "hidden_last_2.pt", etc.
hidden_list = [1, 2, 3]
hidden_states = load_hiddens(directory="./hidden_states", hidden_list=hidden_list, prefix="hidden_last_")
print("Loaded hidden states shape:", hidden_states.shape)