"""
Methods to visualize and approximate big transition matrices arising from 
state-rich Markov chain models, e.g. in ecology and evolution.

Licence:
----------
authored 2014 by Cedric Midoux*, Valentin Bahier*, Katja Reichel* and Solenn Stoeckel*
*Institut National de la Recherche Agronomique (INRA)
This module is published under the GNU general public license: 
<http://www.gnu.org/licenses/gpl-2.0.html> .


Scientific citation:
--------------------
K. Reichel, V. Bahier, C. Midoux, N. Parisey, J.-P. Masson, S. Stoeckel (2015): 
"Interpretation and approximation tools for big, dense Markov chain 
transition matrices in population genetics"
tba.


Notes:
------
This module includes supplementary functions only and produces no output by 
itself. For a test run, use the example provided along with the source code.
The use of the * operator for importing this module is discouraged.


Version:
--------
This is version 2.0 [08/2015]. 
The "t_spec" option was added to the "networkplot" function, the "mp_path" 
option in the "networkplot" function was corrected, and the "testvectors" 
function was corrected.


Functions:
----------
loadmatrix : loads a matrix from a text file (.txt, .csv etc.)

histogrid : generates a 2D matrix histogram

histo3d : generates a 3D matrix histogram

networkplot : generates a network with various optional visual statistics 
    based on the matrix

percolation : auxiliary function for networkplot

eigenone : calculates the eigenvector to the dominant eigenvalue (usually 
    equal to one) of the matrix

appromatrix : approximates matrix by introducing zeros, optional output in 
    a sparse format

testvectors : performs a G-test on two vectors


Requires:
---------
Sys, Numpy, Scipy, NetworkX, Matplotlib

"""
#!python3

try: 
    import numpy as np
    import scipy.sparse as sp
    import networkx as nx
    import matplotlib.pyplot as plt

except ImportError:
    raise ImportError("This module requires numpy, scipy, networkx and matplotlib.")
    

def loadmatrix(filepath=None, loadnames=False, comments=None, verbose=False):
    """
    Ask for a filepath and import a matrix from there.

    Parameters:
    -----------
    filepath : string or path-like
        Optional, gives the path (incl. filename) of a whitespace-delimited 
        (e.g. tab-delimited) .txt-like matrix file; has to be entered in a 
        prompt if not provided in the function call.

    loadnames : boolean
        Optional, whether or not to read in statenames. They are taken from 
        the first column of the file (below the header cell), and the first 
        row and column are cut off the matrix. If no filepath is given and 
        loadnames is False, the prompt also asks whether to import names.

    comments : string
        Optional, set a character which turns the following text into a 
        comment; text after it will not be imported.

    verbose : boolean
        If True, print out status messages & intermediate results

    Returns:
    --------
    matrix : a numpy array with datatype float64
        If the file cannot be converted to float as a whole (e.g. because the 
        upper left cell holds "#" or other row/column names are present), the 
        first row and first column are cut off automatically.

    If loadnames is True, matrix and names are returned as a tuple:

    matrix : a 2d numpy array with datatype float64
        Without row and column names.

    statenames : a list
        The entries of the first column of the input file, excluding the 
        header cell (i.e. the row names).

    Raises:
    -------
    Whatever numpy.loadtxt raises for an unreadable file (e.g. OSError for 
    a missing file), after printing "Bad filepath!"; ValueError/IndexError 
    if the remaining cells are not numeric, after printing the raw data.
    """
    if filepath is None:
        if loadnames is False:
            filepath = input("Please provide a filepath: ")
            namein = input("Import statenames (y/N)? ")
            if namein in ["Y", "y", "T", "t"]:
                loadnames = True
        else:
            filepath = input("Please provide a filepath:")

    try:
        if verbose is True:
            # %-formatting (not "+") so pathlib.Path objects work too
            print("Loading matrix from %s ..." % (filepath,))
        raw = np.loadtxt(filepath, dtype=str, comments=comments)
    except Exception:
        print("Bad filepath!")
        raise

    if verbose is True:
        print("Upper left corner of the imported file:")
        # atleast_2d: single-row or single-cell files load as 1-D/0-D arrays
        print(np.atleast_2d(raw)[0:5, 0:5])

    try:
        if loadnames is True:
            statenames = list(raw[1:, 0])
            if verbose is True:
                print("Got %s statenames." % (len(statenames)))
            matrix = raw[1:, 1:].astype(float)
        else:
            try:
                matrix = raw.astype(float)
            except ValueError:
                # Non-numeric header row/column present: drop them and retry.
                matrix = raw[1:, 1:].astype(float)
    except (ValueError, IndexError):
        print("Bad data in matrix!")
        print(raw)
        raise

    if verbose is True:
        print("Matrix size: ", matrix.shape)
    if loadnames is True:
        return (matrix, statenames)
    return matrix


def histogrid(matrix, visualize=False, axes=None, scaling=None, 
    statenames=None, *args, **kwargs):
    """ 
    Create a color coded 2d histogram of a 2d numpy array. 

    Parameters:
    -----------
    matrix : a 2d numpy array (or scipy sparse matrix) holding the 
        transition matrix

    visualize: boolean
        If True, the resulting diagram will be shown directly.

    axes : an axes object to draw to; None means drawing to current axes.

    scaling : a scaling applied to all values in the matrix
    ==========    ===============    =========================================
    value         range              description
    ==========    ===============    =========================================
    None          (0,1)              no scaling, values are taken as they are
    "log10"       (-inf, 0)          log10 scaling on the values
    "-log10"      (0, +inf)          negative log10 scaling on the values
    "logit"       (-inf, +inf)       logit scaling (log10 of the odds p/(1-p))
    ==========    ===============    =========================================
    Any other value is treated like None.

    statenames : labels for the states
    ==========    ===============    =========================================
    value         range              description
    ==========    ===============    =========================================
    None          (0, n-1)           python array indexes
    "1N"          (1, n)             count index of each state 
    [...]         list of strings    empty list : print no statenames
    ==========    ===============    =========================================
    Anything else prints a warning and falls back to None.

    Plus any arguments the matplotlib imshow() function would take. 
    Defaults: interpolation="none", origin="upper", aspect="equal", 
    cmap="gray_r" (greyscale with increasing luminosity - see comments at 
    matplotlib.pyplot.colormaps() for more information about using color 
    in scientific illustrations).

    Returns:
    --------
    image : matplotlib.image.AxesImage
        The image drawn by imshow (usable e.g. for a colorbar).

    """
    if sp.issparse(matrix):
        matrix = matrix.toarray()

    # Only when drawing to pyplot's current axes is the image also made
    # pyplot's "current image" (so a bare plt.colorbar() keeps working).
    use_current_axes = axes is None
    if use_current_axes:
        axes = plt.gca()

    if isinstance(statenames, str) and statenames == "1N":
        statenames = ["%s" % (i) for i in range(1, matrix.shape[0] + 1)]
    elif statenames is not None and not isinstance(statenames, list):
        statenames = None
        print("Statenames incorrectly specified, using default instead.")

    if statenames is not None:
        if len(statenames) == 0:
            # documented as "print no statenames": remove the ticks entirely
            axes.set_xticks([])
            axes.set_yticks([])
        else:
            # x axis shows the columns, y axis the rows of the matrix
            axes.set_xticks(list(range(matrix.shape[1])))
            axes.set_yticks(list(range(matrix.shape[0])))
            axes.set_xticklabels(statenames)
            axes.set_yticklabels(statenames)

    if scaling == "log10":
        matrix = np.log10(matrix)
    elif scaling == "-log10":
        matrix = -np.log10(matrix)
    elif scaling == "logit":
        matrix = np.log10(matrix / (1 - matrix))

    kwargs.setdefault("interpolation", "none")
    kwargs.setdefault("origin", "upper")
    kwargs.setdefault("aspect", "equal")
    kwargs.setdefault("cmap", "gray_r")

    axes.xaxis.set_ticks_position('top')

    if use_current_axes:
        image = plt.imshow(matrix, *args, **kwargs)
    else:
        # draw onto the axes the caller asked for, not pyplot's current one
        image = axes.imshow(matrix, *args, **kwargs)

    if visualize is True:
        plt.show()

    return image

def histo3d(matrix, visualize=False, axes=None, scaling=None, statenames=None, spacing=0.1, angle=(20, 12), *args, **kwargs):
    """ 
    Create a 3d histogram of a 2d numpy array. 

    Parameters:
    -----------
    matrix : a 2d numpy array (or scipy sparse matrix) holding the 
        transition matrix; rectangular matrices are supported as well

    visualize: boolean
        If True, the resulting diagram will be shown directly.

    axes : a 3d axes object to draw to; None means drawing to the current 
        axes of the current figure if it is 3d, otherwise to a new 3d axes.

    scaling : a scaling applied to all values in the matrix
    ==========    ===============    =========================================
    value         range              description
    ==========    ===============    =========================================
    None          (0,1)              no scaling, values are taken as they are
    "log10"       (-inf, 0)          log10 scaling on the values
    "-log10"      (0, +inf)          negative log10 scaling on the values
    "logit"       (-inf, +inf)       logit scaling (log10 of the odds p/(1-p))
    ==========    ===============    =========================================

    statenames : labels for the states
    ==========    ===============    =========================================
    value         range              description
    ==========    ===============    =========================================
    None          (0, n-1)           python array indexes
    "1N"          (1, n)             count index of each state 
    [...]         list of strings    empty list : print no statenames
    ==========    ===============    =========================================
    Anything else prints a warning and falls back to None.

    spacing : float, half distance between bars

    angle : pair of numbers, viewing angle (elevation, azimuth) as passed to 
        Axes3D.view_init

    Plus any arguments which can be used on Axes3D bar3d objects.

    Returns:
    --------
    axes : matplotlib 3d axes object
        Matrix rows are displayed on the x-axis, columns on the y-axis.

    Warning:
    --------
    As per version 1.3.1, the 3D plotting function in matplotlib has a bug 
        which results in an incorrect display of the bars at certain viewing 
        angles (wrong plotting order). This also produces a runtime warning 
        ("invalid value encountered in divide for n in normals"). Please 
        ignore the warning and use the "angle" argument to find a correct 
        display angle.

    """
    try:
        # Imported for its side effect of registering the "3d" projection.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    except ImportError:
        raise ImportError("Matplotlib 3D module unavailable.")
    except RuntimeError:
        print("Matplotlib unable to open display.")
        raise

    if sp.issparse(matrix):
        matrix = matrix.toarray()
    matrix = np.asarray(matrix)

    if axes is None:
        # plt.gca(projection=...) is not accepted by current matplotlib:
        # reuse the current axes if it already is 3d, otherwise add one.
        figure = plt.gcf()
        if figure.axes and getattr(figure.gca(), "name", "") == "3d":
            axes = figure.gca()
        else:
            axes = figure.add_subplot(111, projection="3d")

    if isinstance(statenames, str) and statenames == "1N":
        statenames = ["%s" % (i) for i in range(1, matrix.shape[0] + 1)]
    elif statenames is not None and not isinstance(statenames, list):
        statenames = None
        print("Statenames incorrectly specified, using default instead.")

    nrows, ncols = matrix.shape

    if statenames is not None:
        if len(statenames) == 0:
            axes.set_xticks([])
            axes.set_yticks([])
        else:
            # ticks sit in the middle of each bar (bar i spans i+spacing .. i+1-spacing)
            axes.set_xticks([i + 0.5 for i in range(nrows)])
            axes.set_yticks([j + 0.5 for j in range(ncols)])
            axes.set_xticklabels(statenames)
            axes.set_yticklabels(statenames)

    if scaling == "log10":
        matrix = np.log10(matrix)
    elif scaling == "-log10":
        matrix = -np.log10(matrix)
    elif scaling == "logit":
        matrix = np.log10(matrix / (1 - matrix))

    width = 1 - 2 * spacing
    # one bar per cell: row index -> x position, column index -> y position
    xpos, ypos = np.meshgrid(np.arange(nrows) + spacing,
                             np.arange(ncols) + spacing, indexing="ij")
    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros_like(xpos)
    heights = matrix.flatten()

    axes.bar3d(xpos, ypos, zpos, width, width, heights, *args, **kwargs)
    axes.view_init(angle[0], angle[1])

    if visualize is True:
        plt.show()

    return axes



def _graph_from_matrix(matrix, transform=None, positive_only=False):
    """
    Build a networkx MultiDiGraph with one weighted edge i -> j per cell.

    Stands in for networkx's from_numpy_matrix, which was removed in 
    networkx 3.0. Cells equal to zero never become edges; with 
    positive_only=True negative cells are skipped as well. If transform is 
    given, it maps a cell value to the stored "weight" attribute (e.g. a 
    negative logarithm, so that shortest paths are most probable paths).
    """
    cells = matrix > 0 if positive_only else matrix != 0
    graph = nx.MultiDiGraph()
    graph.add_nodes_from(range(matrix.shape[0]))
    for i, j in zip(*np.nonzero(cells)):
        value = float(matrix[i, j])
        weight = value if transform is None else float(transform(value))
        graph.add_edge(int(i), int(j), weight=weight)
    return graph


def _draw_with_supported_kwargs(draw, graph, args, kwargs):
    """
    Call a networkx drawing function, dropping keywords it cannot accept.

    networkplot shares one keyword dictionary between the node, edge and 
    label drawing functions. A function taking **kwds receives everything; 
    a function with a closed signature only receives the keywords it 
    declares (otherwise it would raise TypeError).
    """
    import inspect
    try:
        params = inspect.signature(draw).parameters
    except (TypeError, ValueError):
        return draw(graph, *args, **kwargs)
    if not any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        kwargs = {key: value for key, value in kwargs.items() if key in params}
    return draw(graph, *args, **kwargs)


def networkplot(matrix, visualize=False, axes=None, mode="", special=(0, 1), openout=True, *args, **kwargs):
    """
    Customizable network display of a matrix.

    Parameters:
    -----------
    matrix : a 2d numpy array (or scipy sparse matrix) holding the transition 
        matrix; it is transposed internally so that row i of the working copy 
        holds the transitions out of state i (matrix[j, i] = P(i -> j), the 
        column-normalized layout produced by appromatrix).

    visualize: boolean
        If True, the resulting diagram will be shown directly.

    axes : an axes object to draw to; None means drawing to current axes.
        The axis frame is switched off. An explicit "ax" keyword takes 
        precedence for the drawing itself.

    mode: different ways to display the data; combine from I & II:
    ================    =============================================================================
    value               description
    ================    =============================================================================
                        I. edges are drawn:
    "mp_neighbor"       from each node to its most probable neighbor(s)
    "mp_path"           along the most probable path from the first "special" node to the second
    "percolation"       according to a percolation threshold set via "special"

    "eachedge"          all edges are drawn if none of three options above are specified
    "noedge"            overrides edge drawing

                        II. nodes are colored according to:
    "p_eig"             their probability to be occupied during an infinite run (power method)
    "p_stay"            the probability to stay at them for the next time step
    "p_out"             the probability to move out of them at the next time step
    "p_in_raw"          the probability to arrive at them from anywhere else with equal probability   
    "p_in_inf"          the average probability to arrive at them during an infinite run   
    "btw_cent"          their value of betweenness (load) centrality, edge weights -ln(p)
    "d_in"              their degree of incoming edges
    "d_out"             their degree of outgoing edges
    "d_all"             their degree counting edges in both directions
    "special"           whether they are the special nodes (orange) or not (70% grey)
    "t_spec"            the expected time to reach a special state (special states: 0)

    "nonode"            overrides node drawing

    ================    =============================================================================
    Zero-probability transitions never become edges. For "mp_path", 
    networkx raises NetworkXNoPath if the target cannot be reached.

    special : a sequence of two nodes specified by their indices in the matrix
        This will be used as the source and target for the "percolation threshold" and 
        "shortest path" algorithms, and as the absorbing states for "t_spec".

    statenames : a list of labels (keyword only)
        Alias for networkx labels.

    throwex : boolean (keyword only, default True)
        Passed on to eigenone for "p_eig" and "p_in_inf".

    openout : boolean
        If False, the whole network is drawn in one call; if True, nodes, 
        edges and labels are drawn separately and their artists are returned 
        one by one (more convenient for making legends).

    Plus any arguments the networkx drawing functions would take. Keywords a 
    drawing function does not accept are not passed to it.

    Returns:
    --------
    If openout is True, a tuple of:

    nodes : matplotlib artist or False
        The network nodes.

    edges : matplotlib artist(s) or False
        The network edges.

    labels : matplotlib artist or False
        The labels.

    If openout is False:

    axes : matplotlib axes object
        The axes that was drawn to. Nodelabeling and axis plotting is turned 
        off per default.

    """
    if sp.issparse(matrix):
        matrix = matrix.toarray()

    # row i of the working matrix = transitions out of state i
    matrix = np.transpose(matrix)

    if axes is None:
        axes = plt.gca()
    axes.axis("off")
    kwargs.setdefault("ax", axes)

    node_number = matrix.shape[0]
    # networkplot-only options must not reach the networkx drawing functions
    throwex = kwargs.pop("throwex", True)
    statenames = kwargs.pop("statenames", None)

    node_draw = False
    edge_draw = False

    ## edges
    if "mp_neighbor" in mode:
        graph = nx.MultiDiGraph()
        graph.add_nodes_from(range(node_number))
        for i in range(node_number):
            line = np.array(matrix[i])
            line[i] = 0  # probability to stay in the same place is ignored
            peak = np.max(line)
            if peak != 0.:
                for neighbor in np.flatnonzero(np.isclose(line, peak)):
                    graph.add_edge(i, int(neighbor))
        edge_draw = True

    elif "mp_path" in mode:
        # Most probable path = shortest path under weights -log10(p);
        # only transitions with p > 0 are eligible edges.
        weighted = _graph_from_matrix(matrix, transform=lambda p: -np.log10(p),
                                      positive_only=True)
        nodes = nx.dijkstra_path(weighted, special[0], special[1], 'weight')

        graph = nx.MultiDiGraph()
        graph.add_nodes_from(range(node_number))
        graph.add_edges_from(zip(nodes[:-1], nodes[1:]))

        base_color = kwargs.get("node_color", "white")
        if isinstance(base_color, str):
            path_colors = [base_color] * node_number
        else:
            path_colors = list(base_color)  # copy: never mutate the caller's list
        for node in nodes:
            path_colors[node] = "r"
        kwargs["node_color"] = path_colors

        edge_draw = True
        node_draw = True

    elif "percolation" in mode:
        graph = percolation(matrix, special[0], special[1])
        edge_draw = True

    elif "btw_cent" in mode:
        graph = _graph_from_matrix(matrix, transform=lambda p: -np.log(p),
                                   positive_only=True)

    else:
        graph = _graph_from_matrix(matrix)

    ## nodes
    node_color = kwargs.get("node_color", "r")

    if "p_eig" in mode:
        node_color = list(eigenone(np.transpose(matrix), alg="power", throwex=throwex))
        node_draw = True

    elif "p_out" in mode:
        node_color = list(1 - np.diag(matrix))
        node_draw = True

    elif "p_stay" in mode:
        node_color = list(np.diag(matrix))
        node_draw = True

    elif "p_in_raw" in mode:
        node_color = list(np.sum(matrix, axis=0) / float(len(matrix)))
        node_draw = True

    elif "p_in_inf" in mode:
        # NOTE(review): the raw in-flow is divided by the stationary vector
        # here; confirm this matches the documented meaning of "p_in_inf".
        stationary = eigenone(np.transpose(matrix), alg="linalg", throwex=throwex)
        node_color = list(np.sum(matrix, axis=0) / stationary)
        node_draw = True

    elif "btw_cent" in mode:
        centrality = nx.load_centrality(graph, weight='weight')
        node_color = [centrality[node] for node in graph]
        node_draw = True

    elif "d_in" in mode:
        degrees = dict(graph.in_degree())  # dict in networkx 1.x, view in 2.x+
        node_color = [degrees[node] for node in graph]
        node_draw = True

    elif "d_out" in mode:
        degrees = dict(graph.out_degree())
        node_color = [degrees[node] for node in graph]
        node_draw = True

    elif "d_all" in mode:
        degrees = dict(graph.degree())
        node_color = [degrees[node] for node in graph]
        node_draw = True

    elif "special" in mode:
        node_color = ["0.7"] * node_number
        for i in special:
            node_color[i] = "orange"
        node_draw = True

    elif "t_spec" in mode:
        # Expected number of steps before absorption in a special state:
        # row sums of the fundamental matrix (I - Q)^-1 of the transient part Q.
        absorbing = np.unique(special)
        absorbing_set = {int(a) for a in absorbing}
        transient = np.delete(np.delete(matrix, absorbing, axis=0), absorbing, axis=1)
        fundamental = np.linalg.inv(np.identity(transient.shape[0]) - transient)
        times = np.sum(fundamental, axis=1)
        transient_nodes = [i for i in range(node_number) if i not in absorbing_set]
        node_color = [0.] * node_number
        for node, time in zip(transient_nodes, times):
            node_color[node] = float(time)
        node_draw = True

    elif "mp_path" in mode:
        pass  # node colors were already set while building the path

    else:
        node_draw = True

    if "node_color" not in kwargs:
        kwargs["node_color"] = node_color

    if "nonode" in mode:
        node_draw = False

    if "eachedge" in mode:
        edge_draw = True

    if "noedge" in mode:
        edge_draw = False

    ## drawing
    kwargs.setdefault("cmap", "gray_r")

    if "font_color" not in kwargs and kwargs["cmap"] == "gray_r":
        kwargs["font_color"] = "red"

    if statenames is not None:
        kwargs["labels"] = dict(enumerate(statenames))

    kwargs.setdefault("with_labels", False)

    # empty lists (not False) suppress drawing in all networkx versions
    if node_draw is False:
        kwargs["nodelist"] = []
    if edge_draw is False:
        kwargs["edgelist"] = []

    if "pos" not in kwargs:
        kwargs["pos"] = nx.drawing.spring_layout(graph)

    if openout is False:
        _draw_with_supported_kwargs(nx.draw_networkx, graph, args, kwargs)
        if visualize is True:
            plt.show()
        return kwargs["ax"]

    if node_draw is True:
        node_draw = _draw_with_supported_kwargs(nx.draw_networkx_nodes, graph, args, kwargs)
    if edge_draw is True:
        edge_draw = _draw_with_supported_kwargs(nx.draw_networkx_edges, graph, args, kwargs)
    if not kwargs["with_labels"]:
        label_draw = False
    else:
        label_draw = _draw_with_supported_kwargs(nx.draw_networkx_labels, graph, args, kwargs)

    if visualize is True:
        plt.show()

    return (node_draw, edge_draw, label_draw)

def percolation(matrix, source, target, verbose=False):
    """
    Returns a graph where percolation is just possible between source and 
    target.

    Edges (positive matrix entries, row index -> column index) are added in 
    order of decreasing weight, all entries of equal weight at once, until 
    target becomes reachable from source.

    Parameters:
    -----------
    matrix : a 2d numpy array (or scipy sparse matrix) holding the 
        transition matrix; the caller's matrix is not modified

    source : a node, given by its index in the matrix

    target : a node, given by its index in the matrix

    verbose : print out the number of edges retained & minimal edge weight

    Returns:
    --------
    graph : networkx MultiDiGraph
        Contains all nodes, but only edges whose weight is at or above the 
        percolation threshold defined by the two nodes provided. If source 
        and target are identical, the threshold is the self-transition 
        weight of that node. Zero entries never become edges; if target 
        cannot be reached at all, every positive entry is kept and a 
        message is printed.

    """
    if sp.issparse(matrix):
        matrix = matrix.toarray()
    else:
        matrix = np.array(matrix)

    graph = nx.MultiDiGraph()
    graph.add_nodes_from(range(matrix.shape[0]))
    minweight = None

    if source == target:
        rows, cols = np.nonzero((matrix >= matrix[source, source]) & (matrix > 0))
        for i, j in zip(rows, cols):
            graph.add_edge(int(i), int(j), weight=matrix[i, j])
        if len(rows) > 0:
            minweight = matrix[rows, cols].min()

    else:
        rows, cols = np.nonzero(matrix > 0)
        weights = matrix[rows, cols]
        # one pass over the distinct weights, largest first, instead of
        # rescanning the whole matrix for its maximum after every step
        for level in np.unique(weights)[::-1]:
            in_level = weights == level
            for i, j in zip(rows[in_level], cols[in_level]):
                graph.add_edge(int(i), int(j), weight=matrix[i, j])
            minweight = level
            if nx.has_path(graph, source, target):
                break
        else:
            print("Target not reachable from source; all positive edges retained.")

    if verbose is True:
        print("Number of edges retained: %s" % (len(graph.edges())))
        print("Minimum edge weight: %s" % (minweight))

    return graph

def eigenone(matrix, alg="linalg", throwex=False, verbose=False):
    """
    Returns the eigenvector corresponding to a single maximal eigenvalue of 
    one of a matrix.

    Parameters:
    -----------
    matrix : a 2d numpy array or a 2d scipy sparse array

    alg : string, an algorithm to calculate the vector. 
        Either "linalg" for using LAPACK (numpy array, or sparse matrix with 
        fewer than 4 rows) or ARPACK (sparse matrix), or "power" for the 
        power method (stops once successive iterates differ by < 1e-14).

    throwex : boolean
        If true, throw exception in case of bad behavior of the function.
        For the power method, this occurs when there is no convergence after 
        100 steps; for the linalg method, this occurs if the maximal 
        eigenvalue of the matrix is not single or not equal to 1.0 (within 
        1e-12 for LAPACK, 1e-15 for ARPACK).
        If false, the power method asks interactively whether to continue.

    verbose : boolean
        If true, diagnostic intermediate results will be printed out.

    Note:
    -----
    It will not be checked previously if the matrix actually has a single 
        maximal eigenvalue of one. If in doubt, use the "linalg" algorithm 
        and turn on the "throwex" option.

    Returns:
    --------
    vector : 1d numpy array
        The eigenvector corresponding to a single (= of multiplicity one) 
        maximal eigenvalue of one of a matrix, scaled to sum to one.

    Raises:
    -------
    ValueError : unknown alg, or (with throwex) a failed eigenvalue check.
    RuntimeError : (with throwex) the power method did not converge.

    Warning:
    --------
    This function automatically returns the normalized absolute of the 
        desired eigenvector. Small negative values may arise due to numerical
        errors (see LAPACK/ARPACK documentation).

    """
    if alg == "linalg":
        return _eigenone_linalg(matrix, throwex, verbose)
    if alg == "power":
        return _eigenone_power(matrix, throwex, verbose)
    raise ValueError('Unknown algorithm %r: use "linalg" or "power".' % (alg,))


def _eigenone_linalg(matrix, throwex, verbose):
    """Dominant eigenvector via ARPACK (sparse, n >= 4) or LAPACK (otherwise)."""
    if sp.issparse(matrix) and matrix.shape[0] >= 4:
        from scipy.sparse.linalg import eigs
        # ARPACK requires k < n - 1, hence the size condition above
        val, vec = eigs(matrix, k=2, which="LM", maxiter=58)
        atol = 1e-15
    else:
        if sp.issparse(matrix):
            matrix = matrix.toarray()
        val, vec = np.linalg.eig(matrix)
        atol = 1e-12

    # eig() does not sort: pick the two largest (by real part) for display
    top = np.argsort(val.real)[::-1][:2]
    if verbose is True:
        print("The two largest eigenvalues:")
        print(val[top])
    val = val.real
    vec = vec.real
    if verbose is True:
        print("Eigenvectors (real parts only):")
        print(vec[:, top])

    max_val = np.max(val)
    multiplicity = np.count_nonzero(np.isclose(val, max_val, rtol=0, atol=atol))
    if throwex is True:
        if multiplicity != 1:
            raise ValueError("Maximum eigenvalue has multiplicity > 1: %s" % (multiplicity))
        if not np.isclose(max_val, 1.0, rtol=0, atol=atol):
            raise ValueError("Maximum eigenvalue does not equal 1: %.15f" % (max_val))

    vector = vec[:, np.argmax(val)]
    if verbose is True:
        print("absolute vector sum, unscaled: %.10f" % (np.sum(np.abs(vector))))
    vector = np.abs(vector / np.sum(np.abs(vector)))
    if verbose is True:
        print("vector sum, scaled: %.1f" % (np.sum(vector)))
    return vector


def _eigenone_power(matrix, throwex, verbose):
    """Dominant eigenvector via power iteration, in batches of 100 steps."""
    batch = 100
    tolerance = 10e-15  # = 1e-14
    length = matrix.shape[0]
    if verbose is True:
        print(length)

    current = np.ones(length, dtype=float) / length
    total = 0      # iterations performed overall
    in_batch = 0   # iterations since the last "continue?" checkpoint
    dif = 1.
    while dif >= tolerance:
        if in_batch >= batch:
            if throwex is True:
                raise RuntimeError("Power method did not converge after %s iterations." % (total))
            print("Surpassed %s iterations!" % (total))
            cont = input("Continue (Y/N)?")
            if cont == "N" or cont == "n":
                verbose = True  # show the unconverged vector below
                break
            in_batch = 0
        previous = current
        # .dot is a matrix-vector product for numpy and scipy sparse alike
        product = matrix.dot(previous)
        current = product / np.linalg.norm(product)
        dif = np.max(np.abs(previous - current))
        if verbose is True:
            print("Maximum difference between vectors: %.15f" % (dif))
        total += 1
        in_batch += 1

    vector = np.abs(current / np.sum(np.abs(current)))
    if dif < tolerance:
        print("Power method: %s iterations to convergence." % (total))
    else:
        print("Power method did not converge after %s iterations." % (total))
    if verbose is True:
        print(vector)
    return vector
    

def appromatrix(matrix, s, sparse=None, testing=False, verbose=False):
    """
    Returns an approximate (sparse) matrix.

    Small entries of every column are replaced by zeros and each column is 
    re-normalized to sum to one.

    Parameters:
    -----------
    matrix : a 2d numpy array holding the (column-normalized) transition 
        matrix; a scipy sparse matrix is converted to a dense array first

    s : a cutoff value for the column sum in the interval [0.0; 1.0] 
        In each column, entries are kept in decreasing order until their sum 
        exceeds s times the column sum; entries tied with the last kept one 
        are kept too. Regardless of s, the positions of the two largest 
        entries of every column and of the two cyclic neighbours of the 
        diagonal (rows i-1 and i+1 of column i; presumably to keep the chain 
        connected) are always kept (a zero there stays zero). If the 
        diagonal is not all zero but no diagonal entry was kept, the largest 
        diagonal entry is kept.

    sparse : output format for the approximated matrix 
        See sparse matrix classes in the scipy.sparse documentation.
        One of "bsr", "coo", "csc", "csr", "dia", "dok", "lil"; the most 
        frequent will be:
        -   "csc" (compressed column; fast calculations) 
        -   "lil" (list of lists; strong compression)
        Defaults to None, i.e. no sparse format applied (any unknown value 
        also returns a dense array).

    testing : execute efficiency test and an accuracy test 
        If True, the function prints a report and gives out a tuple of values 
        instead of just the matrix: (matrix, efficiency ratio, accuracy p-value)

    verbose : boolean
        If True, print out intermediate results of the approximation process

    Returns:
    --------
    sparsematrix : 2d scipy sparse array or 2d numpy array (float64)
        The approximate matrix, optionally in a sparse format. Columns whose 
        kept entries sum to zero stay zero.

    If testing is True:

    sparsematrix : 2d scipy sparse array or 2d numpy array
        The approximate matrix, optionally in a sparse format.

    efficiency: float
        One minus sparse matrix to original matrix bytesize ratio.

    p_value: float
        P-value of a G-test on the eigenvectors of the sparse vs. original 
        matrix. SciPy's own G-test function is used from SciPy version 0.13.0
        upwards.

    Raises:
    -------
    ValueError : if s is outside [0.0, 1.0].

    """
    try:
        import scipy.stats as st
    except ImportError:
        raise ImportError("Scipy statistics module unavailable.")

    if s < 0.0 or s > 1.0:
        raise ValueError("Value for s out of bounds.")

    if sp.issparse(matrix):
        matrix = matrix.toarray()
    matrix = np.asarray(matrix)

    states_num = matrix.shape[0]
    mask = np.zeros(matrix.shape, dtype=bool)

    for i in range(states_num):
        column = np.array(matrix[:, i])
        col_sum = np.sum(column)
        ranks = np.argsort(column)[::-1]  # row indices, largest value first

        mask[ranks[:2], i] = True

        j = 2
        while j < states_num and np.sum(mask[:, i] * column) <= s * col_sum:
            mask[ranks[j], i] = True
            j += 1

        # keep all entries tied with the last one kept
        while j < states_num and column[ranks[j - 1]] == column[ranks[j]]:
            mask[ranks[j], i] = True
            j += 1

        # cyclic neighbours of the diagonal
        mask[i - 1, i] = True
        mask[(i + 1) % states_num, i] = True

        if verbose is True:
            print("%s values kept in column %s" % (np.sum(mask[:, i]), i))

    if np.trace(matrix) != 0 and not np.any(np.diag(mask)):
        # keep only the single largest diagonal cell (not its whole row)
        k = np.argmax(np.diag(matrix))
        mask[k, k] = True

    kept = np.where(mask, matrix, 0)
    if verbose is True:
        print("%s values kept on the diagonal" % (np.count_nonzero(np.diag(kept))))

    # float copy: in-place normalization of an integer matrix would truncate
    result = kept.astype(float)
    col_sums = result.sum(axis=0)
    nonzero_cols = col_sums != 0
    result[:, nonzero_cols] /= col_sums[nonzero_cols]

    constructors = {
        "bsr": sp.bsr_matrix,
        "coo": sp.coo_matrix,
        "csc": sp.csc_matrix,
        "csr": sp.csr_matrix,
        "dia": sp.dia_matrix,
        "dok": sp.dok_matrix,
        "lil": sp.lil_matrix,
    }
    if isinstance(sparse, str) and sparse in constructors:
        sparseresult = constructors[sparse](result, dtype=np.float64)
    else:
        sparseresult = result

    if testing is not True:
        return sparseresult

    # entries that were non-zero before and are zero now
    newzeros = np.count_nonzero(matrix) - np.count_nonzero(kept)
    zeratio = newzeros / float(states_num) ** 2
    if sparse is not None:
        try:
            sparse_size = sparseresult.data.nbytes  # data only, no index arrays
        except AttributeError:
            import sys
            sparse_size = sys.getsizeof(sparseresult)
    else:
        sparse_size = sparseresult.nbytes

    efficiency = 1. - float(sparse_size) / float(matrix.nbytes)

    print("\nEfficiency Test")
    print("-" * 10)
    print("%s new zeros (%.1f percent of the matrix) were introduced by approximation." % (newzeros, zeratio * 100))
    print("The approximate efficiency (1 - ratio of memory uptake, excluding metadata) is %.4f." % (efficiency))
    print("The data in the sparse matrix takes up ~%s bytes, and the original ~%s bytes.\n \n" % (sparse_size, matrix.nbytes))

    vector0 = eigenone(matrix)
    try:
        vector1 = eigenone(sparseresult)
    except Exception as error:
        print("WARNING - An exception occurred while calculating the sparse matrix eigenvector:")
        print("%s: %s." % (type(error).__name__, error))
        print("Using expanded matrix instead. \n")
        vector1 = eigenone(result)

    try:
        g_value, p_value = st.power_divergence(vector1, vector0, lambda_="log-likelihood")
        if verbose is True:
            print("G-Test with scipy power divergence. \n")
    except Exception:
        # older SciPy without power_divergence, or it rejected the input
        (g_value, p_value) = testvectors(vector0, vector1)
        if verbose is True:
            print("G-Test with testvectors. \n")

    if p_value <= 0.01:
        stars = "**"
    elif p_value <= 0.05:
        stars = "*"
    else:
        stars = ""

    print("Accuracy Test")
    print("-" * 10)
    print("Results of a G-test comparing the approximate eigenvector to the original: \n")
    print("degrees of freedom: \t %s" % (states_num - 1))
    print("G-value: \t \t %.4f" % (g_value))
    print("p-value: \t \t %.4f %s \n \n" % (p_value, stars))

    return (sparseresult, efficiency, p_value)

def testvectors(original, fake, ignorext=True, verbose=False):
    """
    G-test comparing two vectors.
    
    Parameters:
    -----------
    original : 1D numpy array 
        Distribution against which the test will be performed
    
    fake : 1D numpy array 
        Distribution to be tested
    
    ignorext : boolean 
        Ignore zeros and negative values in the vector and compare only 
        positive vector entries
    
    verbose : boolean 
        If True, print out the summands for the G-statistic 
    
    Returns:
    --------
    g_value : float 
        The value of G
    
    p_value : float 
        P-value of the G-test
    
    Note:
    ------
    Compare the scipy.stats.power_divergence() function in scipy >= 0.13.0 
        for a better implementation. P-values for G are determined on a 
        chi-squared distribution.
    
    
    """
    try:
        import scipy.stats as st
    except ImportError:
        raise ImportError("Scipy statistics module unavailable.")
    
    if fake.shape[0] != original.shape[0]:
        raise ValueError("The two vectors need to have the same length!")
    
       
    summands = fake*np.log(fake/original)
    if verbose is True:
        print("The summands of G:")        
        print(summands)
    if ignorext is True:
        g_value = np.nansum(2.*summands[~np.isinf(summands)])
    else:
        g_value = np.sum(2.*summands)
    p_value = st.chisqprob(g_value, fake.shape[0]-1)
    
    return (g_value, p_value) 
