from __future__ import annotations
''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2025 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.
    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
from abc import ABC, abstractmethod
from operator import itemgetter
from typing import Any, Callable, Dict, Iterable, Optional, Union, Tuple, List
import multiprocessing as mp
from multiprocessing.dummy import Pool as ThreadPool
from functools import partial
import time
import numpy as np
from causallearn.search.ConstraintBased.PC import pc as _pc
from causallearn.search.ConstraintBased.FCI import fci as _fci
from causallearn.search.ConstraintBased.CDNOD import cdnod as _cdnod
from causallearn.search.ScoreBased.GES import ges as _ges
from causallearn.search.ScoreBased.ExactSearch import bic_exact_search
from causallearn.search.FCMBased import lingam as _lingam
from causallearn.search.FCMBased.ANM.ANM import ANM
from causallearn.utils.cit import fisherz, chisq, gsq, mv_fisherz, kci, CIT
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
# Progress callback signature shared by every discovery function below:
# (current step, total steps, current graph, whether the graph holds real values) -> Any
_CALLBACK_TYPE = Callable[[int, int, np.ndarray, bool], Any]
# name -> causal-learn conditional independence test
_CIT_METHODS = dict(
    fisherz=fisherz,
    chisq=chisq,
    gsq=gsq,
    mv_fisherz=mv_fisherz,
    kci=kci,
)
""" causal discovery methods """
# constraint-based
[docs]
def pc(
    data: np.ndarray,
    indep: str = 'fisherz',
    thresh: float = 0.05,
    bg_rules: Optional[BackgroundKnowledge] = None,
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """Peter-Clark (PC) constraint-based causal discovery via causal-learn.

    :param data: ndarray [num_samples, num_variables]
    :param indep: name of the conditional independence test
    :param thresh: significance level of the independence tests
    :param bg_rules: optional background knowledge (e.g. forbidden edges)
    :param callback: progress callback; called once before the search (step 0 of 1,
        all-zero graph) and once after it (step 1 of 1, resulting graph)
    :param kwargs: accepted for interface compatibility, ignored
    :return: (causal-learn style graph matrix, False) -- the entries are edge marks,
        not real-valued scores
    """
    num_vars = data.shape[-1]
    if callback:
        callback(0, 1, np.zeros((num_vars, num_vars)), False)
    # positional args presumably map to causal-learn's (alpha, indep_test, stable,
    # uc_rule, uc_priority) -- confirm against the installed causal-learn version
    result = _pc(data, thresh, indep, True, 0, -1, background_knowledge=bg_rules)
    endpoint_matrix = result.G.graph
    if callback:
        callback(1, 1, endpoint_matrix, False)
    return endpoint_matrix, False
[docs]
def fci(
    data: np.ndarray,
    indep: str = 'fisherz',
    thresh: float = 0.05,
    bg_rules: Union[BackgroundKnowledge, None] = None,
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """Fast Causal Inference (FCI) via causal-learn.

    :param data: ndarray [num_samples, num_variables]
    :param indep: name of the conditional independence test
    :param thresh: significance level of the independence tests
    :param bg_rules: optional background knowledge (e.g. forbidden edges)
    :param callback: progress callback; called once before the search (step 0 of 1,
        all-zero graph) and once after it (step 1 of 1, resulting graph)
    :param kwargs: accepted for interface compatibility, ignored
    :return: (causal-learn style graph matrix, False) -- edge marks, not real scores
    """
    num_vars = data.shape[-1]
    if callback:
        callback(0, 1, np.zeros((num_vars, num_vars)), False)
    # the second return value of causal-learn's fci is not needed here
    general_graph = _fci(data, indep, thresh, verbose=False, background_knowledge=bg_rules)[0]
    endpoint_matrix = general_graph.graph
    if callback:
        callback(1, 1, endpoint_matrix, False)
    return endpoint_matrix, False
[docs]
def inter_cit(
    data: np.ndarray,
    indep: str = "fisherz",
    inter_classes: Iterable[Iterable[Iterable[int]]] = (),
    in_parallel: bool = True,
    parallel_limit: int = 5,
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """ use cit to discover the relations of variables
        inter different classes (indicated by indices)

    For every ``[input_indices, output_indices]`` pair in ``inter_classes`` and every
    input ``i`` / output ``o`` of that pair, ``cit(i, o)`` is evaluated conditioned on
    all *other* inputs of the same pair, and ``1 - cit(...)`` (presumably 1 - p-value)
    is written to ``p_values[i, o]``. All other entries stay 0.

    :param data: ndarray [num_samples, num_variables]
    :param indep: name of the conditional independence test passed to causal-learn ``CIT``
    :param inter_classes: pairs ``[input_indices, output_indices]``; the index
        collections must support ``len`` and slicing (e.g. lists)
    :param in_parallel: run the tests on a thread pool
    :param parallel_limit: number of worker threads when ``in_parallel`` is True
    :param callback: progress callback; called once before any test and after each test
    :param kwargs: forwarded to ``CIT``
    :return: (score matrix [num_variables, num_variables], True) -- real-valued scores
    """
    n = data.shape[1]
    p_values = np.zeros((n, n))
    cit = CIT(data, method=indep, **kwargs)
    # completed-task counter; its built-in lock guards the increment under threads
    completed = mp.Value("d", 0)
    task_params = []
    for inter_pair in inter_classes:
        assert len(inter_pair) == 2, "Can only test relation between two classes"
        input_indices, output_indices = inter_pair[0], inter_pair[1]
        for idx, i in enumerate(input_indices):
            for o in output_indices:
                task_params.append({
                    "cit": cit,
                    "mat": p_values,
                    "in_idx": i,
                    "out_idx": o,
                    # condition on every other input of the same class
                    "c_indices": input_indices[:idx] + input_indices[idx + 1:],
                    "counter": completed,
                    "callback": callback,
                })
    # the total is only known once every class pair has been enumerated
    total = len(task_params)
    for tp in task_params:
        tp["tot"] = total
    # start
    if callback:
        callback(int(completed.value), total, p_values, True)

    def do_cit_once(params):
        """ run one conditional independence test and report progress """
        cit_alg, mat, x, y, cond, cb, cnt, tot = itemgetter(
            "cit", "mat", "in_idx", "out_idx", "c_indices", "callback", "counter", "tot")(params)
        mat[x, y] = 1 - cit_alg(x, y, condition_set=cond)
        # increment and read under the same lock so the reported step is consistent
        with cnt.get_lock():
            cnt.value += 1
            done = int(cnt.value)
        if cb:
            cb(done, tot, mat, True)

    if in_parallel:
        # the context manager tears the pool down even if a test raises
        with ThreadPool(parallel_limit) as pool:
            pool.map(do_cit_once, task_params)
    else:
        for tp in task_params:
            do_cit_once(tp)
    return p_values, True
# FCM-based
[docs]
def lingam(
    data: np.ndarray,
    ver: str = 'ica',
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """ LiNGAM-family causal discovery via causal-learn.

    :param data: ndarray [num_samples, num_variables]
    :param ver: 'direct' (DirectLiNGAM), 'var' (VARLiNGAM), 'rcd' (RCD); any other
        value, including the default 'ica', uses ICALiNGAM
    :param callback: progress callback; called before fitting (step 0 of 1, all-zero
        graph) and after it (step 1 of 1, resulting graph)
    :param kwargs: accepted for interface compatibility, ignored
    :return: (absolute value of the fitted adjacency matrix, transposed, True);
        for 'var' the first of ``adjacency_matrices_`` is used
    """
    if ver == 'direct':
        model = _lingam.DirectLiNGAM()
    elif ver == 'var':
        model = _lingam.VARLiNGAM()
    elif ver == 'rcd':
        model = _lingam.RCD()
    else:
        model = _lingam.ICALiNGAM()
    num_vars = data.shape[-1]
    # start -- the shape must be a tuple: np.zeros(n, n) treats the 2nd n as a dtype
    if callback:
        callback(0, 1, np.zeros((num_vars, num_vars)), True)
    model.fit(data)
    if ver == 'var':
        adj_matrix = model.adjacency_matrices_[0]
    else:
        adj_matrix = model.adjacency_matrix_
    # transposed, presumably so rows index causes like the other functions -- confirm
    graph = np.abs(adj_matrix).T
    # end
    if callback:
        callback(1, 1, graph, True)
    return graph, True
[docs]
def anm(
    data: np.ndarray,
    kernelX: str = "Gaussian",
    kernelY: str = "Gaussian",
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """ Additive noise model (ANM) pairwise orientation via causal-learn.

    Each unordered pair ``i < j`` is tested once with
    ``ANM.cause_or_effect(x_i, x_j)``; its two returned p-values (forward, backward)
    are stored in ``graph[i, j]`` and ``graph[j, i]``. The diagonal stays 0 because a
    variable is never tested against itself.

    :param data: ndarray [num_samples, num_variables]
    :param kernelX: kernel for the cause data
    :param kernelY: kernel for the effect data
    :param callback: progress callback; called before the first pair and after each pair
    :param kwargs: accepted for interface compatibility, ignored
    :return: (p-value matrix [num_variables, num_variables], True)
    """
    from itertools import combinations

    model = ANM(kernelX, kernelY)
    num_vars = data.shape[-1]
    graph = np.zeros((num_vars, num_vars))
    # one test per unordered pair: testing (j, i) again would only overwrite (i, j)
    pairs = list(combinations(range(num_vars), 2))
    total = len(pairs)
    # start
    if callback:
        callback(0, total, graph, True)
    for step, (i, j) in enumerate(pairs, start=1):
        p_value_f, p_value_b = model.cause_or_effect(data[:, i:i+1], data[:, j:j+1])
        graph[i, j], graph[j, i] = p_value_f, p_value_b
        if callback:
            callback(step, total, graph, True)
    return graph, True
# score-based
# score functions accepted by ``ges`` (each maps to causal-learn's ``local_score_<name>``)
_AVAILABLE_SCORE_FNS = [
    "BIC", "BDeu", "CV_general", "marginal_general", "CV_multi", "marginal_multi",
]


def ges(
    data: np.ndarray,
    score_func: str = 'BIC',
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """ Greedy Equivalence Search (GES) via causal-learn.

    :param data: ndarray [num_samples, num_variables]
    :param score_func: one of ``_AVAILABLE_SCORE_FNS``
    :param callback: progress callback; called before the search (step 0 of 1,
        all-zero graph) and after it (step 1 of 1, resulting graph)
    :param kwargs: accepted for interface compatibility, ignored
    :return: (causal-learn style graph matrix, False)
    :raises AssertionError: if ``score_func`` is not supported
    """
    assert score_func in _AVAILABLE_SCORE_FNS, (
        "Do not support score function '{}' (available score functions: '{}')".format(
            score_func, "', '".join(_AVAILABLE_SCORE_FNS)))
    num_vars = data.shape[-1]
    # start
    if callback:
        callback(0, 1, np.zeros((num_vars, num_vars)), False)
    record = _ges(data, f"local_score_{score_func}")
    graph = record['G'].graph
    # end
    if callback:
        callback(1, 1, graph, False)
    return graph, False
_AVAILABLE_SEARCH_METHODS = [
    "astar", "dp",
]
[docs]
def exact_search(
    data: np.ndarray,
    method: str = 'astar',
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    assert method in _AVAILABLE_SEARCH_METHODS, \
        
"Do not support search method '{}' (available search methods: '{}')".format(
            method, "', '".join(_AVAILABLE_SEARCH_METHODS))
    # start
    if callback:
        callback(0, 1, np.zeros((data.shape[-1], data.shape[-1])), False)
    dag, _ = bic_exact_search(data, search_method=method)
    # end
    if callback:
        callback(1, 1, dag, False)
    return dag, False 
[docs]
class Graph:
    """ Causal graph class

    Holds either a causal-learn style endpoint matrix (``is_real=False``) or a matrix
    of real-valued edge scores (``is_real=True``) and exposes it as an adjacency
    matrix in which a non-zero ``adj[i, j]`` marks an edge ``i -> j``.
    """
    def __init__(
        self,
        graph: np.ndarray,
        is_real: bool = False,
        thresh_info: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        :param graph: ndarray, square matrix [num_variables, num_variables]; causal-learn
            style endpoint marks (-1 start of an edge, 1 end of an edge) when ``is_real``
            is False, real-valued edge scores otherwise
        :param is_real: whether the every element in graph is a real number
        :param thresh_info: information about the threshold of the graph, only used when the
            elements of the graph is real numbers
        :raises AssertionError: if ``graph`` is not square
        """
        assert graph.shape[0] == graph.shape[1], "Graph required to be a square matrix"
        self._graph = graph.copy()
        self._is_real = is_real
        self._thresh_info = thresh_info
        # must come last: _format_graph (possibly overridden) reads the attributes above
        self._adj_mat = self._format_graph(graph)

    def _format_graph(self, mat: np.ndarray) -> np.ndarray:
        """ format binary causal-learn style graph as adjacency matrix

        ``i -> j`` (mat[i, j] = -1, mat[j, i] = 1) becomes adj[i, j] = 1, adj[j, i] = 0;
        ``i - j`` (both -1) and ``i <-> j`` (both 1) become adj[i, j] = adj[j, i] = 1;
        the diagonal and every other mark become 0. Real-valued graphs are only copied.
        """
        mat = mat.copy()
        if self._is_real:
            return mat
        np.fill_diagonal(mat, 0)
        # i -> j
        single_direct = (mat == -1) & (mat.T == 1)
        # i - j or i <-> j (both masks are symmetric by construction)
        bi_direct = ((mat == -1) & (mat.T == -1)) | ((mat == 1) & (mat.T == 1))
        # other area
        mat[~(single_direct | bi_direct)] = 0
        # i -> j (m[i, j]=1, m[j, i]=0)
        mat[single_direct] = 1
        mat[single_direct.T] = 0
        # i - j or i <-> j (m[i, j] = m[j, i] = 1); symmetric, so one assignment suffices
        mat[bi_direct] = 1
        return mat

    @property
    def graph(self):
        """ raw graph (a copy of the matrix passed to the constructor) """
        return self._graph

    @property
    def thresh_info(self):
        """ information about threshold """
        return self._thresh_info

    def get_adj_matrix(self):
        """ return the formatted adjacency matrix (binary, or real-valued scores) """
        return self._adj_mat

    def get_binary_adj_matrix(self, thresh=None):
        """ return binary adjacency matrix (with threshold specified)

        For real-valued graphs, entries strictly greater than ``thresh`` (default 0.)
        become 1 and the rest 0. Binary graphs return the internal array (not a copy).
        """
        if not self._is_real:
            return self._adj_mat
        if thresh is None:
            thresh = 0.
        return (self._adj_mat > thresh).astype(int)

    def get_binary_adj_matrix_by_sparsity(self, sparsity=None):
        """ return binary adjacency matrix (with sparsity specified)

        For real-valued graphs the threshold is the ``floor(N * sparsity)``-th smallest
        of the N entries (clamped to the largest entry); only entries strictly above it
        are kept. Without ``sparsity`` (or for an empty graph) the threshold is 0.
        Binary graphs return the internal array.
        """
        if not self._is_real:
            return self._adj_mat
        thresh = 0.
        if sparsity is not None:
            assert 0 <= sparsity and sparsity <= 1
            sorted_values = np.sort(self._adj_mat.reshape(-1))
            if sorted_values.size > 0:
                last_ele = min(int(np.floor(sorted_values.size * sparsity)),
                               sorted_values.size - 1)
                thresh = sorted_values[last_ele]
        return self.get_binary_adj_matrix(thresh)
 
[docs]
class TransitionGraph(Graph):
    """ RL transition graph class

    Variables are ordered as [states (state_dim), actions (action_dim), next states].
    After formatting, only state -> action, state/action -> next state and
    next state -> next state entries can be non-zero.
    """
    def __init__(
        self,
        graph: np.ndarray,
        state_dim: int,
        action_dim: int,
        is_real: bool = False,
        thresh_info: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        :param graph: ndarray, causal-learn style adjacency matrix,
            shape [state_dim + action_dim + state_dim, state_dim + action_dim + state_dim]
        :param state_dim: int, the dimension of state variables
        :param action_dim: int, the dimension of action variables
        :param is_real: whether the every element in graph is a real number
        :param thresh_info: information about the threshold of the graph, only used when the
            elements of the graph is real numbers
        """
        # set before super().__init__, which calls self._format_graph that needs them
        self._state_dim = state_dim
        self._action_dim = action_dim
        super().__init__(graph, is_real, thresh_info)

    def _format_graph(self, mat: np.ndarray) -> np.ndarray:
        """ format graph [S+A+S, S+A+S] as transition graph (same square shape)

        For causal-learn style input, a -1 mark (edge starts here) in the
        state/action -> next state and state -> action regions becomes 1, and pairs of
        next states are oriented from their endpoint marks; values other than -1 in
        those regions are kept unchanged. For every input, edges into states and edges
        from actions / next states into actions are zeroed.
        """
        mat = mat.copy()
        s_end = self._state_dim
        sa_end = self._state_dim + self._action_dim
        if not self._is_real:
            # state, action -> next_state (basic slices are views, so mat is updated)
            inter_mat = mat[:sa_end, sa_end:]
            inter_mat[inter_mat == -1] = 1
            # state -> action
            inter_mat = mat[:s_end, s_end:sa_end]
            inter_mat[inter_mat == -1] = 1
            # next_state -> next_state
            next_state_dim = mat.shape[0] - sa_end
            for i in range(next_state_dim):
                start_idx = sa_end + i
                # no loop
                mat[start_idx, start_idx] = 0
                for j in range(i + 1, next_state_dim):
                    end_idx = sa_end + j
                    # scalars are copies, so later writes do not affect these reads
                    fwd, bwd = mat[start_idx, end_idx], mat[end_idx, start_idx]
                    if fwd == -1 and bwd == 1:
                        # start -> end
                        mat[start_idx, end_idx], mat[end_idx, start_idx] = 1, 0
                    elif fwd == 1 and bwd == -1:
                        # end -> start
                        mat[start_idx, end_idx], mat[end_idx, start_idx] = 0, 1
                    elif fwd == -1 and bwd == -1:
                        # start - end
                        mat[start_idx, end_idx] = mat[end_idx, start_idx] = 1
                    # start <-> end (both 1) is already symmetric and kept as-is
                    # TODO: handle unobserved confounder
        # no ... -> state
        mat[:, :s_end] = 0.
        # no action | next state -> action
        mat[s_end:, s_end:sa_end] = 0.
        return mat
[docs]
class DiscoveryModule(ABC):
    """ Base class for causal discovery modules

    Subclasses implement ``fit``, which is expected to store the discovered graph in
    ``self._graph`` and return the module itself.
    """
    def __init__(self, **kwargs) -> None:
        # extra keyword arguments are accepted and ignored
        # discovered graph; None until a subclass builds one
        self._graph: Union[Graph, None] = None

    @abstractmethod
    def fit(self, data: Any, **kwargs) -> DiscoveryModule:
        """ fit the module to ``data`` and return the module itself """

    @property
    def graph(self) -> Union[Graph, None]:
        """ the discovered graph (None before fitting) """
        return self._graph
[docs]
class ClassicalDiscovery(DiscoveryModule):
    """ Classical causal discovery algorithms

    Wraps the discovery functions of this module behind ``fit``, which accepts either
    transition data (a dict of state / action / next-state arrays) or a general
    ``[num_samples, num_variables]`` matrix, and stores the result in ``self.graph``
    as a ``TransitionGraph`` or ``Graph`` respectively.
    """
    CLASSICAL_ALGOS = {
        "pc": pc,
        "fci": fci,
        "lingam": lingam,
        "anm": anm,
        "ges": ges,
        "exact_search": exact_search,
    }
    # designed only for transition graph
    CLASSICAL_ALGOS_TRANSITION = {
        "inter_cit": inter_cit,
    }
    # value range ("min"/"max") and, presumably, a default threshold ("common") of the
    # real-valued graphs produced by these algorithms
    CLASSICAL_ALGOS_THRESH_INFO = {
        "anm": {
            "min": 0., "max": 1., "common": 0.5,
        },
        "inter_cit": {
            "min": 0., "max": 1., "common": 0.8,
        },
        "lingam": {
            "min": 0., "max": float("inf"), "common": 0.01,
        }
    }

    def __init__(
        self,
        alg: str = "inter_cit",
        alg_args: Dict[str, Any] = {"indep": "kci", "in_parallel": False},
        state_keys: Optional[List[str]] = ["obs"],
        action_keys: Optional[List[str]] = ["action"],
        next_state_keys: Optional[List[str]] = ["next_obs"],
        limit: Optional[int] = 100,
        use_residual: bool = True,
        **kwargs
    ) -> None:
        """
        :param alg: str, algorithm name, options include
            'pc': Peter-Clark algorithm,
            'fci': Fast Causal Inference,
            'lingam': Linear Non-Gaussian Model,
            'anm': Additive Nonlinear Model,
            'ges': Greedy Equivalence Search,
            'exact_search': Exact Search,
            'inter_cit': conditional independence tests between variable classes
                (transition data only),
        :param alg_args: additional arguments for the algorithm used (copied; None means
            no extra arguments),
            pc:
                indep: str, conditional independence test used, including
                    'fisherz': Fisher's Z conditional independence test,
                    'chisq': Chi-squared conditional independence test,
                    'gsq': G-squared conditional independence test,
                    'kci': Kernel-based conditional independence test,
                    'mv_fisherz': Missing-value Fisher's Z conditional independence test,
                thresh: float, level of significance for conditional independence test
            fci: the same as pc
            inter_cit:
                indep: the same as 'indep' in pc,
                in_parallel: whether run the algorithm in parallel,
                parallel_limit: limit of the number of workers
            lingam:
                ver: the version of linear non-Gaussian model to use, including
                    'ica': ICA-based LiNGAM,
                    'direct': DirectLiNGAM,
                    'var': VAR-LiNGAM,
                    'rcd': RCD (repetitive causal discovery)
            anm:
                kernelX: the kernel function for cause data, including
                    'Gaussian': Gaussian kernel,
                    'Polynomial': Polynomial kernel,
                    'Linear': Linear kernel,
                kernelY: the kernel function for effect data, options are the same as kernelX
            ges:
                score_func: score function used to score the graph, including
                    'BIC': BIC score,
                    'BDeu': BDeu score,
                    'CV_general': Generalized score with cross validation
                        for data with single-dimensional variables
                    'marginal_general': Generalized score with marginal likelihood
                        for data with single-dimensional variables
                    'CV_multi': Generalized score with cross validation
                        for data with multi-dimensional variables
                    'marginal_multi': Generalized score with marginal likelihood
                        for data with multi-dimensional variables
            exact_search:
                method: the search method used, including:
                    'dp': dynamic programming (DP),
                    'astar': A* search,
        :param state_keys: list[str] | None, specifying the keys of
            states in the input data dictionary (None indicating not using transition data)
        :param action_keys: list[str] | None, specifying the keys of
            actions in the input data dictionary (None indicating not using transition data)
        :param next_state_keys: list[str] | None, specifying the keys of
            next states in the input data dictionary (None indicating not using transition data)
        :param limit: int | None, limit for the number of data samples used
        :param use_residual: bool, whether use residual (next state - state) as next states
            (only for transition data; requires equal state / next-state dimensions)
        The list/dict arguments are copied, so the shared default objects are never
        aliased by an instance.
        """
        super().__init__(**kwargs)
        available = (list(ClassicalDiscovery.CLASSICAL_ALGOS.keys())
                     + list(ClassicalDiscovery.CLASSICAL_ALGOS_TRANSITION.keys()))
        assert alg in available, \
            "Do not support algorithm {} (available: '{}')".format(alg, "', '".join(available))
        key_groups = (state_keys, action_keys, next_state_keys)
        assert (all(keys is not None for keys in key_groups)
                or all(keys is None for keys in key_groups)), \
            "state, action, next states keys should all be None or not None"
        self._support_transition = state_keys is not None
        if self._support_transition:
            assert len(state_keys) > 0 and len(action_keys) > 0 and len(next_state_keys) > 0, \
                "state, action, next states keys can not be empty"
        self._alg_name = alg
        if alg in ClassicalDiscovery.CLASSICAL_ALGOS:
            self._alg = ClassicalDiscovery.CLASSICAL_ALGOS[alg]
        else:
            assert self._support_transition, f"{alg} only work for transition data"
            self._alg = ClassicalDiscovery.CLASSICAL_ALGOS_TRANSITION[alg]
        self._alg_args = dict(alg_args) if alg_args is not None else {}
        self._state_keys = None if state_keys is None else list(state_keys)
        self._action_keys = None if action_keys is None else list(action_keys)
        self._next_state_keys = None if next_state_keys is None else list(next_state_keys)
        self._limit = limit
        self._use_residual = use_residual

    def _extract_data(
        self, data_dict: Dict[str, np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """ gather states, actions and next states from the data dictionary; several
        keys of one kind are concatenated along the last axis """
        states = itemgetter(*self._state_keys)(data_dict)
        actions = itemgetter(*self._action_keys)(data_dict)
        next_states = itemgetter(*self._next_state_keys)(data_dict)
        # itemgetter returns a tuple only when it was given more than one key
        if isinstance(states, tuple):
            states = np.concatenate(states, axis=-1)
        if isinstance(actions, tuple):
            actions = np.concatenate(actions, axis=-1)
        if isinstance(next_states, tuple):
            next_states = np.concatenate(next_states, axis=-1)
        return states, actions, next_states

    def _build_graph(
        self,
        graph_mat: np.ndarray,
        is_real: bool,
        state_dim: Optional[int] = None,
        action_dim: Optional[int] = None,
    ):
        """ build graph with graph adjacency matrix and whether the elements are real numbers
        :param graph_mat: ndarray, adjacency matrix of the graph
        :param is_real: whether the elements in the graph is real numbers
        :param state_dim: int or None, if not None, indicating the graph is a transition graph,
            data is shaped as [state_dim, action_dim, next_state_dim (state_dim)]
        :param action_dim: int or None, similar with state_dim
        """
        # information about the threshold of the edge existence (None if not tabulated)
        thresh_info = ClassicalDiscovery.CLASSICAL_ALGOS_THRESH_INFO.get(self._alg_name)
        if thresh_info is not None:
            # copy so that mutating graph.thresh_info cannot corrupt the class table
            thresh_info = dict(thresh_info)
        if state_dim is not None and action_dim is not None:
            # build a transition graph
            self._graph = TransitionGraph(graph_mat, state_dim, action_dim,
                                          is_real, thresh_info)
        else:
            # build a regular graph
            self._graph = Graph(graph_mat, is_real, thresh_info)

    @staticmethod
    def _node_pattern(first: int, last: int) -> str:
        """ regex matching node names ``X<first>`` .. ``X<last>`` (1-based, inclusive);
        assumes causal-learn's default node naming X1, X2, ... -- confirm if custom
        node names are ever passed """
        return "^(%s)$" % "|".join("X%d" % i for i in range(first, last + 1))

    def _fit_transition(
        self,
        data: np.ndarray,
        state_dim: int,
        action_dim: int
    ) -> ClassicalDiscovery:
        """ fit the discovery module to transition data
        :param data: ndarray, transition data, arranged by states, actions, next_states
        :param state_dim: int, dimension of states
        :param action_dim: int, dimension of actions
        :return: the module itself
        """
        sa_dim = state_dim + action_dim
        if self._alg_name in ClassicalDiscovery.CLASSICAL_ALGOS:
            # build general rl background knowledge rules
            obs_act_pattern = self._node_pattern(1, sa_dim)
            next_obs_pattern = self._node_pattern(sa_dim + 1, data.shape[-1])
            obs_pattern = self._node_pattern(1, state_dim)
            act_pattern = self._node_pattern(state_dim + 1, sa_dim)
            bg_rules = BackgroundKnowledge()
            # forbid next state -> state | action
            bg_rules.add_forbidden_by_pattern(next_obs_pattern, obs_act_pattern)
            # forbid action -> state
            bg_rules.add_forbidden_by_pattern(act_pattern, obs_pattern)
            # returns a graph matrix and whether its elements are real numbers
            graph, is_real = self._alg(data, bg_rules=bg_rules, **self._alg_args)
        else:
            # build general rl transition variable classes
            inter_classes = [
                # state -> action
                [list(range(state_dim)), list(range(state_dim, sa_dim))],
                # state, action -> next state
                [list(range(sa_dim)), list(range(sa_dim, data.shape[1]))],
            ]
            # returns a graph matrix and whether its elements are real numbers
            graph, is_real = self._alg(data, inter_classes=inter_classes,
                                       **self._alg_args)
        # build transition graph
        self._build_graph(graph, is_real, state_dim, action_dim)
        return self

    def _fit_all(self, data: np.ndarray) -> ClassicalDiscovery:
        """ fit the discovery module to general data
        :param data: [num_samples x num_features]
        :return: the module itself
        """
        graph, is_real = self._alg(data, bg_rules=None, **self._alg_args)
        # build graph
        self._build_graph(graph, is_real)
        return self

    def fit(
        self,
        data: Union[Dict[str, np.ndarray], np.ndarray],
        fit_transition: bool = True,
    ) -> ClassicalDiscovery:
        """ fit the discovery module to transition data or general data
        :param data: dict[str, ndarray] | ndarray,
            transition data dictionary or general data matrix
        :param fit_transition: whether ``data`` is a transition data dictionary
        :return: the module itself
        The caller's arrays are never modified.
        """
        if fit_transition:
            assert isinstance(data, dict), "need transition data format"
            assert self._support_transition, \
                "fitting transition data needs specifying keys"
            states, actions, next_states = self._extract_data(data)
            state_dim, action_dim = states.shape[-1], actions.shape[-1]
            assert states.shape[0] == actions.shape[0] == next_states.shape[0], \
                "Transition data shape mismatch"
            # prepare data
            if self._use_residual:
                assert next_states.shape[1] == states.shape[1]
                # out-of-place on purpose: with a single next-state key, next_states is
                # the caller's own array and ``-=`` would overwrite their data
                next_states = next_states - states
            data = np.concatenate([states, actions, next_states], axis=-1)
        else:
            assert isinstance(data, np.ndarray), "need general data format"
            assert self._alg_name not in ClassicalDiscovery.CLASSICAL_ALGOS_TRANSITION, \
                f"{self._alg_name} only work for transition data"
        # limited data samples: random subset without replacement
        # (when limit equals the sample count this only shuffles the rows)
        limit = self._limit if self._limit is not None else data.shape[0]
        if limit <= data.shape[0]:
            indices = np.random.choice(data.shape[0], size=limit, replace=False)
            data = data[indices]
        # start fitting
        if fit_transition:
            return self._fit_transition(data, state_dim, action_dim)
        return self._fit_all(data)
 
[docs]
class AsyncClassicalDiscovery(ClassicalDiscovery):
    """ Classical causal discovery algorithms (support asynchronous ver.)

    Extends ``ClassicalDiscovery`` with progress tracking: ``fit`` stores
    ``self._callback`` in the algorithm arguments under the key ``"callback"``.
    Each time the algorithm invokes it, the module records the current / total
    step, the elapsed time and an estimate of the remaining time. It then
    rebuilds the graph through ``_build_graph`` from the intermediate result and
    forwards the call to an optional custom callback.
    """
    def __init__(
        self,
        alg: str = "inter_cit",
        alg_args: Dict[str, Any] = {"indep": "kci", "in_parallel": False},
        state_keys: Optional[List[str]] = ["obs"],
        action_keys: Optional[List[str]] = ["action"],
        next_state_keys: Optional[List[str]] = ["next_obs"],
        limit: Optional[int] = 100,
        use_residual: bool = True,
        callback: Optional[_CALLBACK_TYPE] = None,
        **kwargs
    ) -> None:
        """
        :param alg: name of the discovery algorithm
        :param alg_args: keyword arguments for the algorithm; a private copy is
            passed on, so neither the caller's dict nor this shared default is
            ever modified by ``fit``
        :param state_keys: keys of the state arrays in the transition dict
        :param action_keys: keys of the action arrays in the transition dict
        :param next_state_keys: keys of the next-state arrays in the transition dict
        :param limit: maximum number of samples used for fitting (None: all)
        :param use_residual: fit on ``next_state - state`` instead of ``next_state``
        :param callback: optional custom progress callback (see ``_CALLBACK_TYPE``)
        """
        # `fit` writes an instance-bound "callback" entry into the stored
        # argument dict; copy so that entry can never end up in the shared
        # default above or in a dict the caller still holds
        if alg_args is not None:
            alg_args = dict(alg_args)
        super().__init__(
            alg, alg_args, state_keys, action_keys,
            next_state_keys, limit, use_residual, **kwargs)
        self._custom_callback = callback
        # init progress management (all readable before the first fit)
        self.cur_step = None
        self.tot_step = None
        self.start_time = None
        self.elapsed_time = None
        self.remaining_time = None
        self.is_running = False

    def _before_start(self):
        """ recording done before discovery starts """
        self.cur_step = 0
        self.tot_step = 0
        self.start_time = time.time()
        self.elapsed_time = 0.0
        self.remaining_time = float("inf")
        self.is_running = True

    def _after_end(self):
        """ recording done after discovery ends """
        self.is_running = False

    def _callback(
        self,
        cur_step: int,
        tot_step: int,
        graph: np.ndarray,
        is_real: bool,
        state_dim: Optional[int] = None,
        action_dim: Optional[int] = None,
    ):
        """ progress hook handed to the algorithm by ``fit``

        :param cur_step: number of finished steps
        :param tot_step: total number of steps
        :param graph: current (intermediate) raw graph
        :param is_real: whether the graph elements are real-valued
        :param state_dim: state dimension (None for general data)
        :param action_dim: action dimension (None for general data)
        """
        # record progress
        self.cur_step = cur_step
        self.tot_step = tot_step
        self.elapsed_time = time.time() - self.start_time
        if self.cur_step != 0:
            # linear extrapolation from the average time per finished step
            self.remaining_time = self.elapsed_time * (tot_step / cur_step - 1)
        # record graph
        self._build_graph(graph, is_real, state_dim, action_dim)
        # custom callback
        if self._custom_callback:
            self._custom_callback(cur_step, tot_step, graph, is_real)

    def set_callback(self, callback: _CALLBACK_TYPE):
        """ set custom callback function """
        self._custom_callback = callback

    def fit(
        self,
        data: Union[Dict[str, np.ndarray], np.ndarray],
        fit_transition: bool = True
    ) -> ClassicalDiscovery:
        """ fit the discovery module while reporting progress

        :param data: dict[str, ndarray] | ndarray,
            transition data dictionary or general data matrix (rows are samples)
        :param fit_transition: whether ``data`` is transition data
        :return: the module itself
        :raises AssertionError: if the data format does not match
            ``fit_transition``, no transition keys were specified, or the
            transition arrays disagree in their number of samples
        """
        if fit_transition:
            assert isinstance(data, dict), "need transition data format"
            assert self._support_transition, \
                "fitting transition data needs specifying keys"
            states, actions, next_states = self._extract_data(data)
            state_dim, action_dim = states.shape[-1], actions.shape[-1]
            assert states.shape[0] == actions.shape[0] == next_states.shape[0], \
                "Transition data shape mismatch"
            # prepare data
            if self._use_residual:
                assert next_states.shape[1] == states.shape[1], \
                    "residual fitting needs states and next states of equal dimension"
                # out of place: arrays that may alias the caller's data stay
                # untouched (and integer next states do not hit a casting error)
                next_states = next_states - states
            data = np.concatenate([states, actions, next_states], axis=-1)
        else:
            assert isinstance(data, np.ndarray), "need general data format"
            assert self._alg_name not in AsyncClassicalDiscovery.CLASSICAL_ALGOS_TRANSITION, \
                f"{self._alg_name} only work for transition data"
            state_dim = action_dim = None
        # limited data samples: random row subset without replacement
        limit = self._limit if self._limit is not None else data.shape[0]
        if limit <= data.shape[0]:
            indices = np.random.choice(data.shape[0], size=limit, replace=False)
            data = data[indices]
        # add callback into the algorithm arguments
        self._alg_args["callback"] = partial(
            self._callback, state_dim=state_dim, action_dim=action_dim)
        # start fitting; `is_running` is reset even if the algorithm raises
        self._before_start()
        try:
            if fit_transition:
                return self._fit_transition(data, state_dim, action_dim)
            return self._fit_all(data)
        finally:
            # finished
            self._after_end()
 
if __name__ == "__main__":
    np.random.seed(0)

    def _report(module) -> None:
        """ print the graph learned by a fitted discovery module in several forms """
        graph = module.graph
        # threshold information (maybe None)
        print("graph threshold information\n", graph.thresh_info)
        # for debug or private usage
        print("algorithm output graph\n", graph.graph)
        # raw adjacent matrix
        print("transition graph\n", graph.get_adj_matrix())
        # get binary matrix by indicating threshold (recommended)
        print(
            "binary transition graph by threshold\n",
            graph.get_binary_adj_matrix(graph.thresh_info["common"])
        )
        # get binary matrix by indicating sparsity (need prior knowledge)
        print(
            "binary transition graph by sparsity\n",
            graph.get_binary_adj_matrix_by_sparsity(0.75)
        )

    # ---- basic case for transition data ----
    states = np.random.normal(0, 1, size=10000).reshape(-1, 2)
    actions = np.tanh(states[:, 0:1])
    delta_states = 2*states + actions + np.random.uniform(-0.1, 0.1, size=10000).reshape(-1, 2)
    next_states = delta_states + states
    # default keys
    data_dict = {"obs": states, "action": actions, "next_obs": next_states}
    # default keys are not None, the module supports transition data
    discover_module = ClassicalDiscovery()
    discover_module.fit(data_dict, fit_transition=True)
    # 0.95 is a proper threshold for default method "inter_cit" (bounded in [0,1])
    _report(discover_module)

    # ---- multi-keys ----
    states2 = np.random.uniform(-1, 1, size=5000).reshape(-1, 1)
    delta_states2 = states2 - actions \
        + np.random.uniform(-0.1, 0.1, size=5000).reshape(-1, 1)
    next_states2 = delta_states2 + states2
    data_dict2 = {
        "obs1": states, "obs2": states2,
        "action": actions,
        "next_obs1": next_states, "next_obs2": next_states2
    }
    # keys are not None, the module supports transition data
    discover_module = ClassicalDiscovery(
        state_keys=["obs1", "obs2"],
        action_keys=["action"],
        next_state_keys=["next_obs1", "next_obs2"]
    )
    discover_module.fit(data_dict2, fit_transition=True)
    # 0.95 is a proper threshold for default method "inter_cit" (bounded in [0,1])
    _report(discover_module)

    # ---- basic case for general data (default algorithm does not support general data) ----
    discover_module = ClassicalDiscovery(
        alg="lingam",
        alg_args={"ver": "direct"},
        limit=None,
    )
    s = np.random.uniform(-1, 1, 5000).reshape(-1, 1)
    a = 1.1 * s + np.random.uniform(-0.1, 0.1, size=5000).reshape(-1, 1)
    s_ = s + a + np.random.uniform(-0.01, 0.01, size=5000).reshape(-1, 1)
    data = np.concatenate((s, a, s_), axis=-1)
    discover_module.fit(data, fit_transition=False)
    # 0.01 is a proper threshold for method "lingam" (>=0, not upper bounded)
    _report(discover_module)

    # ---- use more data to acquire a more accurate result ----
    discover_module = ClassicalDiscovery(
        # run in parallel
        alg_args={"indep": "kci", "in_parallel": True, "parallel_limit": 5},
        state_keys=["obs1", "obs2"],
        action_keys=["action"],
        next_state_keys=["next_obs1", "next_obs2"],
        limit=1000,
    )
    discover_module.fit(data_dict2, fit_transition=True)
    # 0.95 is a proper threshold for default method "inter_cit" (bounded in [0,1])
    _report(discover_module)

    # ---- use callback to inform progress ----
    async_discover_module = AsyncClassicalDiscovery(
        state_keys=["obs1", "obs2"],
        action_keys=["action"],
        next_state_keys=["next_obs1", "next_obs2"],
        limit=3000,
        # callback can be set when constructing
        callback=None,
    )

    # callback type should be `_CALLBACK_TYPE`, see related comments
    def callback(cur_step: int, tot_step: int, cur_raw_graph: np.ndarray, is_real: bool):
        print(f"current step {cur_step}, total step {tot_step}, current raw graph {cur_raw_graph},"
              f" whether the elements of the graph is real number {is_real}")
        # can also get progress information from the module
        print(f"current step {async_discover_module.cur_step},"
              f" total step {async_discover_module.tot_step}")
        print(f"start time {async_discover_module.start_time}")
        print(f"time elapsed {async_discover_module.elapsed_time}")
        print(f"is running {async_discover_module.is_running}")
        print(f"estimated remaining time {async_discover_module.remaining_time}")
        cur_graph = async_discover_module.graph
        binary_graph = cur_graph.get_binary_adj_matrix(cur_graph.thresh_info["common"])
        print(f"current graph {binary_graph}")
        print("=" * 20)

    # callback can also be set after constructed
    async_discover_module.set_callback(callback)
    # start fitting
    async_discover_module.fit(data_dict2, fit_transition=True)
    print(f"is running {async_discover_module.is_running}")
    # 0.95 is a proper threshold for default method "inter_cit" (bounded in [0,1])
    _report(async_discover_module)