# -*- coding: utf-8 -*-
# @Time : 2023/2/9 10:01
# @Author : ctt
# @File : FewMultiLabelDataLoader
# @Project : platform_zzsn
import os
from paddlenlp.datasets import load_dataset


class FewMultiLabelDataLoader:
    """Loader for multi-label classification datasets stored as local text files."""

    def __init__(self, config_path):
        """
        Args:
            config_path (str):
                Path to the configuration file. It is only stored on the
                instance as ``self.config_path``.
        """
        # This class has no explicit base class, so forwarding ``config_path``
        # to ``object.__init__`` would raise TypeError on every instantiation.
        super().__init__()
        self.config_path = config_path

    @staticmethod
    def _read_examples(data_file, label_list):
        """
        Yield one example dict per non-blank line of ``data_file``.

        A line without a tab yields ``{"text_a": text}``. A line of the form
        ``text<TAB>label1,label2`` yields ``{"text_a": text, "labels": vec}``
        where ``vec`` holds one float per entry of ``label_list`` (in its
        iteration order): 1.0 if that entry is among the line's labels,
        0.0 otherwise.

        The parameter names must stay ``data_file`` and ``label_list`` because
        they are the keyword arguments passed through ``load_dataset``.

        Raises:
            ValueError: If a line contains more than one tab separator.
        """
        with open(data_file, "r", encoding="utf-8") as fp:
            for line_no, line in enumerate(fp, start=1):
                stripped = line.strip()
                if not stripped:
                    # Blank lines (e.g. a trailing newline at end of file)
                    # carry no example; skip them instead of yielding "".
                    continue
                fields = stripped.split("\t")
                if len(fields) == 1:
                    yield {"text_a": fields[0]}
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        "%s, line %d: expected 'text<TAB>labels' but found %d "
                        "tab-separated fields" % (data_file, line_no, len(fields))
                    )
                text, raw_labels = fields
                # Build the set once per line so each membership test is O(1).
                present = set(raw_labels.strip().split(","))
                labels = [float(1) if name in present else float(0) for name in label_list]
                yield {"text_a": text, "labels": labels}

    @staticmethod
    def load_local_dataset(data_path, splits, label_list):
        """
        Load dataset for multi-label classification from files, where
        there is one example per line. Text and label are separated
        by '\\t', and multiple labels are delimited by ','.

        Args:
            data_path (str):
                Path to the dataset directory containing the split files
                train.txt, dev.txt and/or test.txt.
            splits (list):
                Which file(s) to load, such as ['train', 'dev', 'test'].
            label_list (dict):
                The dictionary that maps labels to indices. Its iteration
                order defines the positions of the multi-hot label vector.

        Returns:
            list: One dataset per entry of ``splits``, in the same order, each
            being whatever ``load_dataset`` returns for ``lazy=False``.

        Raises:
            KeyError: If a split name is not 'train', 'dev' or 'test'.
        """
        split_map = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
        datasets = []
        for split in splits:
            data_file = os.path.join(data_path, split_map[split])
            datasets.append(
                load_dataset(
                    FewMultiLabelDataLoader._read_examples,
                    data_file=data_file,
                    label_list=label_list,
                    lazy=False,
                )
            )
        return datasets
