''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2025 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.
    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
import yaml
import numpy as np
from copy import deepcopy
from revive.utils.common_utils import load_data
from revive.utils.causal_discovery_utils import ClassicalDiscovery
class CausalGraph:
    """Run causal discovery on (obs, action, next_obs) transitions and derive a
    REVIVE decision flow graph (yaml config + npz data) from the result."""

    def __init__(self, data_file, yaml_file, seed=1024):
        """Finding causal graphs using causal discovery algorithms

        Args:
            data_file : *.npz file path or *.h5 file path
            yaml_file : *.yaml file path
            seed : random seed
        """
        with open(yaml_file, 'r', encoding='UTF-8') as f:
            self.raw_config = yaml.load(f, Loader=yaml.FullLoader)
        data = load_data(data_file)
        self.seed(seed)
        assert "obs" in data.keys(), "The data must contain an 'obs' array."
        assert "action" in data.keys(), "The data must contain an 'action' array."
        if "next_obs" not in data.keys():
            if "index" not in data.keys():
                raise KeyError("'index' is required to build transitions "
                               "when the data has no 'next_obs'.")
            # 'index' holds the exclusive end offset of each episode; each
            # episode starts where the previous one ended.
            end_indexes = data['index'].astype(int)
            start_indexes = np.concatenate([np.array([0]), end_indexes[:-1]])
            episodes = list(zip(start_indexes, end_indexes))
            # Within an episode, step t is paired with step t+1; the final step
            # of each episode has no successor and is therefore dropped.
            self.obs = np.concatenate(
                [data["obs"][start:end - 1] for start, end in episodes], axis=0)
            self.next_obs = np.concatenate(
                [data["obs"][start + 1:end] for start, end in episodes], axis=0)
            self.action = np.concatenate(
                [data["action"][start:end - 1] for start, end in episodes], axis=0)
        else:
            self.obs = data["obs"]
            self.next_obs = data["next_obs"]
            self.action = data["action"]
        self.data = data
        self.obs_dims = self.obs.shape[1]
        self.action_dims = self.action.shape[1]
        self.algo_cls = ClassicalDiscovery

    def seed(self, seed):
        """Seed NumPy's global random number generator."""
        np.random.seed(seed)

    def fit(self, sample_size=-1):
        """Fit using causal discovery algorithms

        Args:
            sample_size : Limit the number of samples used.
                          The more the number of samples,
                          the longer the training time it takes,
                          -1 means use all samples.
        Raises:
            ValueError: if sample_size is neither -1 nor >= 1.
        """
        if sample_size == -1:
            self.algo = self.algo_cls()
        elif sample_size >= 1:
            self.algo = self.algo_cls(limit=int(sample_size))
        else:
            raise ValueError(
                "The sample_size should be -1 (use all samples) or an integer "
                f"greater than or equal to 1. It should not be {sample_size}")
        data = {"obs": self.obs,
                "action": self.action,
                "next_obs": self.next_obs}
        # Hand the algorithm a copy so it cannot mutate the stored arrays.
        self.algo.fit(deepcopy(data), fit_transition=True)

    @property
    def causal_graph(self):
        """Graph held by the fitted algorithm (``self.algo.graph.graph``); requires fit()."""
        return self.algo.graph.graph

    @property
    def causal_graph_threshold(self):
        """Threshold information of the fitted graph (``thresh_info``); requires fit()."""
        return self.algo.graph.thresh_info

    @property
    def causal_matrix(self):
        """Weighted adjacency matrix of the fitted graph; requires fit()."""
        return self.algo.graph.get_adj_matrix()

    def causal_binary_matrix(self, threshold=None):
        """Convert the causality matrix to a
        two-dimensional connectivity diagram

        Args:
            threshold : Causal truncation threshold (default 0.5),
                        only greater than or equal to
                        this value is considered to have
                        a causal relationship,
                        using 1 means there is a
                        causal relationship,
                        0 means there is no
                        causal relationship.
        Return:
            causal_binary_matrix:  [[0, 0, 1, 1],
                                    [0, 0, 1, 1],
                                    [0, 0, 1, 1],
                                    [0, 0, 1, 1]]
        """
        if threshold is None:
            threshold = 0.5
        # NOTE(review): ``threshold_transform`` is not defined on this class as
        # written, so calling it unconditionally raised AttributeError. Use it
        # when one is provided (e.g. by a subclass); otherwise the threshold is
        # applied to the matrix values directly, as documented above.
        transform = getattr(self, "threshold_transform", None)
        if transform is not None:
            threshold = transform(threshold)
        return self.algo.graph.get_binary_adj_matrix(threshold)

    def _classify_obs_dims(self, causal_binary_matrix):
        """Group obs dimensions as action-related, transition-related or useless.

        Columns of ``causal_binary_matrix`` are indexed as [obs, action, next_obs]
        (per the offsets below). An obs dim is action-related when its row has a
        nonzero entry in any action column; otherwise it is transition-related
        when its row hits the next_obs column of any action-related dim, and
        useless if neither holds.
        """
        action_cols = slice(self.obs_dims, self.obs_dims + self.action_dims)
        action_related = [obs_dim for obs_dim in range(self.obs_dims)
                          if np.sum(causal_binary_matrix[obs_dim][action_cols]) > 0]
        next_obs_offset = self.obs_dims + self.action_dims
        next_action_related_cols = [obs_dim + next_obs_offset for obs_dim in action_related]
        action_related_set = set(action_related)
        transition_related, useless = [], []
        for obs_dim in range(self.obs_dims):
            if obs_dim in action_related_set:
                continue
            if np.sum(causal_binary_matrix[obs_dim][next_action_related_cols]) > 0:
                transition_related.append(obs_dim)
            else:
                useless.append(obs_dim)
        return {"action_realated": action_related,
                "transition_related": transition_related,
                "useless": useless}

    @staticmethod
    def _set_column_dim(column, dim):
        """Set the ``dim`` field of a single-entry column mapping from the yaml config."""
        next(iter(column.values()))["dim"] = dim

    def decision_graph(self, npz_file, yaml_file, threshold=None):
        """Generate decision flow graph for use by REVIVE SDK

        Args:
            yaml_file : The address where the newly
                        generated yaml file is saved.
            npz_file : The address where the newly
                       generated npz file is saved.
            threshold : Causal truncation threshold,
                        only greater than or equal to
                        this value is considered to have
                        a causal relationship,
                        using 1 means there is a causal
                        relationship, 0 means there is no
                        causal relationship.
        """
        causal_binary_matrix = self.causal_binary_matrix(threshold)
        obs_nodes = self._classify_obs_dims(causal_binary_matrix)
        action_related = obs_nodes["action_realated"]
        transition_related = obs_nodes["transition_related"]
        # TODO: Get the relation between
        # "action_realated" and "next_obs_related"

        raw_config = deepcopy(self.raw_config)
        # These column entries are references into raw_config, so editing them
        # edits the config dumped below. Their order is assumed to match the
        # column order of the obs array.
        obs_columns = [column for column in raw_config["metadata"]["columns"]
                       if list(column.values())[0]["dim"] == "obs"]
        data = deepcopy(self.data)
        obs = data.pop("obs")
        if transition_related:
            raw_config["metadata"]["graph"] = {'action': ['action_realated_obs'],
                                               'next_action_realated_obs': ['action_realated_obs',
                                                                            'transition_related_obs',
                                                                            'action'],
                                               'next_transition_related_obs': ['action_realated_obs',
                                                                               'transition_related_obs',
                                                                               'action']}
            for obs_dim in action_related:
                self._set_column_dim(obs_columns[obs_dim], "action_realated_obs")
            for obs_dim in transition_related:
                self._set_column_dim(obs_columns[obs_dim], "transition_related_obs")
            data["action_realated_obs"] = obs[:, action_related]
            data["transition_related_obs"] = obs[:, transition_related]
            # NOTE(review): a "next_obs" array present in the input is saved
            # unchanged here although the graph no longer has an obs node — confirm.
        else:
            # Action-related columns keep their original "obs" dim.
            data["obs"] = obs[:, action_related]
            if "next_obs" in data:
                # Keep next_obs column-aligned with the narrowed obs array.
                data["next_obs"] = data["next_obs"][:, action_related]
            raw_config["metadata"]["graph"] = {'action': ['obs'],
                                               'next_obs': ['obs', 'action']}
        # NOTE(review): columns tagged "useless_obs" get no matching npz array.
        for obs_dim in obs_nodes["useless"]:
            self._set_column_dim(obs_columns[obs_dim], "useless_obs")
        with open(yaml_file, 'w', encoding='UTF-8') as f:
            yaml.dump(raw_config, f)
        np.savez_compressed(npz_file, **data)