
LibriTTS dataset vocoder

LibriTTSDatasetVocoder

Bases: Dataset

Loads preprocessed UnivNet vocoder data.

Source code in training/datasets/libritts_dataset_vocoder.py
class LibriTTSDatasetVocoder(Dataset):
    r"""Loading preprocessed univnet model data."""

    def __init__(
        self,
        root: str,
        batch_size: int,
        download: bool = True,
        lang: str = "en",
    ):
        r"""A PyTorch dataset for loading preprocessed univnet data.

        Args:
            root (str): Path to the directory where the dataset is found or downloaded.
            batch_size (int): Batch size for the dataset.
            download (bool, optional): Whether to download the dataset if it is not found. Defaults to True.
            lang (str, optional): Language code used to select the preprocessing configuration. Defaults to "en".
        """
        self.dataset = datasets.LIBRITTS(root=root, download=download)
        self.batch_size = batch_size

        lang_map = get_lang_map(lang)
        self.preprocess_libtts = PreprocessLibriTTS(
            PreprocessingConfigUnivNet(lang_map.processing_lang_type),
        )

    def __len__(self) -> int:
        r"""Returns the number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r"""Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Dict[str, Any]: A dictionary containing the sample data.
        """
        # Retrieve the dataset row
        data = self.dataset[idx]

        data = self.preprocess_libtts.univnet(data)

        if data is None:
            # print("Skipping due to preprocessing error")
            rand_idx = np.random.randint(0, self.__len__())
            return self.__getitem__(rand_idx)

        mel, audio, speaker_id = data

        return {
            "mel": mel,
            "audio": audio,
            "speaker_id": speaker_id,
        }

    def collate_fn(self, data: List) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List): A list of data samples.

        Returns:
            List: The collated batch as [mels, mel_lens, audios, speaker_ids] tensors.
        """
        data_size = len(data)

        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(4)]
        (
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ) = empty_lists

        # Extract fields from data dictionary and populate the lists
        for idx in idxs:
            data_entry = data[idx]

            mels.append(data_entry["mel"])
            mel_lens.append(data_entry["mel"].shape[1])
            audios.append(data_entry["audio"])
            speaker_ids.append(data_entry["speaker_id"])

        mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
        mel_lens = torch.tensor(mel_lens, dtype=torch.int64)
        audios = torch.tensor(pad_1D(audios), dtype=torch.float32)
        speaker_ids = torch.tensor(speaker_ids, dtype=torch.int64)

        return [
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ]
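
A minimal usage sketch (assumes the LibriTTS corpus can be placed under ./data and that the class is importable from training.datasets.libritts_dataset_vocoder, matching the source path above; adjust paths to your project layout):

from torch.utils.data import DataLoader

from training.datasets.libritts_dataset_vocoder import LibriTTSDatasetVocoder

dataset = LibriTTSDatasetVocoder(root="./data", batch_size=16, download=True)

# collate_fn pads the variable-length mels and audios into fixed-size tensors
loader = DataLoader(
    dataset,
    batch_size=dataset.batch_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
)

mels, mel_lens, audios, speaker_ids = next(iter(loader))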

__getitem__(idx)

Returns a sample from the dataset at the given index.

Parameters:

Name  Type  Description                      Default
idx   int   Index of the sample to return.   required

Returns:

Type            Description
Dict[str, Any]  A dictionary containing the sample data.

Source code in training/datasets/libritts_dataset_vocoder.py
def __getitem__(self, idx: int) -> Dict[str, Any]:
    r"""Returns a sample from the dataset at the given index.

    Args:
        idx (int): Index of the sample to return.

    Returns:
        Dict[str, Any]: A dictionary containing the sample data.
    """
    # Retrieve the dataset row
    data = self.dataset[idx]

    data = self.preprocess_libtts.univnet(data)

    if data is None:
        # print("Skipping due to preprocessing error")
        rand_idx = np.random.randint(0, self.__len__())
        return self.__getitem__(rand_idx)

    mel, audio, speaker_id = data

    return {
        "mel": mel,
        "audio": audio,
        "speaker_id": speaker_id,
    }
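
A single item can be fetched directly (sketch; uses the dataset constructed earlier):

sample = dataset[0]

mel = sample["mel"]                # mel spectrogram; frames are on axis 1 (see collate_fn)
audio = sample["audio"]            # raw waveform
speaker_id = sample["speaker_id"]  # integer speaker label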

__init__(root, batch_size, download=True, lang='en')

A PyTorch dataset for loading preprocessed UnivNet data.

Parameters:

Name        Type  Description                                                        Default
root        str   Path to the directory where the dataset is found or downloaded.   required
batch_size  int   Batch size for the dataset.                                        required
download    bool  Whether to download the dataset if it is not found.                True
lang        str   Language code used to select the preprocessing configuration.      'en'

Source code in training/datasets/libritts_dataset_vocoder.py
def __init__(
    self,
    root: str,
    batch_size: int,
    download: bool = True,
    lang: str = "en",
):
    r"""A PyTorch dataset for loading preprocessed univnet data.

    Args:
        root (str): Path to the directory where the dataset is found or downloaded.
        batch_size (int): Batch size for the dataset.
        download (bool, optional): Whether to download the dataset if it is not found. Defaults to True.
        lang (str, optional): Language code used to select the preprocessing configuration. Defaults to "en".
    """
    self.dataset = datasets.LIBRITTS(root=root, download=download)
    self.batch_size = batch_size

    lang_map = get_lang_map(lang)
    self.preprocess_libtts = PreprocessLibriTTS(
        PreprocessingConfigUnivNet(lang_map.processing_lang_type),
    )
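
Constructor sketch (download=False assumes the corpus is already present under root):

dataset = LibriTTSDatasetVocoder(
    root="./data",
    batch_size=16,
    download=False,  # skip the download; the LibriTTS data is assumed to be in ./data
    lang="en",       # forwarded to get_lang_map to pick the preprocessing configuration
)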

__len__()

Returns the number of samples in the dataset.

Returns:

Type  Description
int   Number of samples in the dataset.

Source code in training/datasets/libritts_dataset_vocoder.py
def __len__(self) -> int:
    r"""Returns the number of samples in the dataset.

    Returns:
        int: Number of samples in the dataset.
    """
    return len(self.dataset)

collate_fn(data)

Collates a batch of data samples.

Parameters:

Name  Type  Description               Default
data  List  A list of data samples.   required

Returns:

Type  Description
List  The collated batch as [mels, mel_lens, audios, speaker_ids] tensors.

Source code in training/datasets/libritts_dataset_vocoder.py
def collate_fn(self, data: List) -> List:
    r"""Collates a batch of data samples.

    Args:
        data (List): A list of data samples.

    Returns:
        List: The collated batch as [mels, mel_lens, audios, speaker_ids] tensors.
    """
    data_size = len(data)

    idxs = list(range(data_size))

    # Initialize empty lists to store extracted values
    empty_lists: List[List] = [[] for _ in range(4)]
    (
        mels,
        mel_lens,
        audios,
        speaker_ids,
    ) = empty_lists

    # Extract fields from data dictionary and populate the lists
    for idx in idxs:
        data_entry = data[idx]

        mels.append(data_entry["mel"])
        mel_lens.append(data_entry["mel"].shape[1])
        audios.append(data_entry["audio"])
        speaker_ids.append(data_entry["speaker_id"])

    mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
    mel_lens = torch.tensor(mel_lens, dtype=torch.int64)
    audios = torch.tensor(pad_1D(audios), dtype=torch.float32)
    speaker_ids = torch.tensor(speaker_ids, dtype=torch.int64)

    return [
        mels,
        mel_lens,
        audios,
        speaker_ids,
    ]
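
collate_fn can also be called directly on a list of samples, e.g. to inspect the padding (sketch; the sample indices are arbitrary, and the exact shapes depend on how pad_1D/pad_2D stack and pad their inputs):

batch = [dataset[i] for i in range(4)]
mels, mel_lens, audios, speaker_ids = dataset.collate_fn(batch)

# mels:        roughly (4, n_mels, max_frames), padded along the frame axis
# mel_lens:    (4,) original frame counts before padding
# audios:      roughly (4, max_samples), padded waveforms
# speaker_ids: (4,)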