"""@namespace IMP.isd.gmm_tools
   Tools for handling Gaussian Mixture Models.
"""

from __future__ import print_function
import IMP
import IMP.core
import IMP.algebra
import IMP.atom
import IMP.em
import numpy as np
import numpy.linalg
import sys,os
import itertools

try:
    import sklearn.mixture
    nosklearn=False
except:
    nosklearn=True
from math import exp,sqrt,copysign

def decorate_gmm_from_text(in_fn,
                           ps,
                           mdl,
                           transform=None,
                           radius_scale=1.0,
                           mass_scale=1.0):
    """Read the output of write_gmm_to_text and decorate particles with it.

    Every data line of the file describes one Gaussian component as
    '|num|weight|mean|covariance matrix|'. The n-th data line is applied
    to ps[n]; the particle is decorated (or updated, if already decorated)
    as IMP.core.Gaussian, IMP.atom.Mass and IMP.core.XYZR.

    in_fn:         path of the GMM text file to read
    ps:            list of particles to decorate. If the file holds more
                   components than len(ps), new particles named 'GMM<n>'
                   are created in mdl and appended to this list in place
    mdl:           IMP Model used when new particles must be created
    transform:     optional transformation applied to each Gaussian
    radius_scale:  multiplies the XYZR radius, which is the square root of
                   the largest variance of the Gaussian
    mass_scale:    multiplies the weight read from the file before it is
                   stored as the particle mass

    Lines starting with '#' and blank lines are ignored.
    """
    ncomp = 0
    added_ps = itertools.count(1)
    with open(in_fn, 'r') as inf:
        for l in inf:
            # Skip comments and blank/whitespace-only lines; a blank line
            # has no '|' separated fields and would otherwise raise
            # IndexError below.
            stripped = l.strip()
            if not stripped or stripped.startswith('#'):
                continue
            if ncomp >= len(ps):
                ps.append(IMP.Particle(mdl, "GMM%d" % next(added_ps)))
            p = ps[ncomp]
            # A leading '|' makes fields[0] empty, so the payload starts
            # at index 1: num, weight, mean, covariance.
            fields = l.split('|')
            weight = float(fields[2])
            center = list(map(float, fields[3].split()))
            covar = np.array(list(map(float,
                                      fields[4].split()))).reshape((3, 3))
            shape = IMP.algebra.get_gaussian_from_covariance(
                covar.tolist(), IMP.algebra.Vector3D(center))
            if not IMP.core.Gaussian.get_is_setup(p):
                g = IMP.core.Gaussian.setup_particle(p, shape)
            else:
                g = IMP.core.Gaussian(p)
                g.set_gaussian(shape)
            if not IMP.atom.Mass.get_is_setup(p):
                IMP.atom.Mass.setup_particle(p, weight * mass_scale)
            else:
                IMP.atom.Mass(p).set_mass(weight * mass_scale)
            rmax = sqrt(max(g.get_variances())) * radius_scale
            if not IMP.core.XYZR.get_is_setup(p):
                IMP.core.XYZR.setup_particle(p, rmax)
            else:
                IMP.core.XYZR(p).set_radius(rmax)
            if transform is not None:
                IMP.core.transform(IMP.core.RigidBody(p), transform)
            ncomp += 1

def write_gmm_to_text(ps,out_fn, comments=[]):
    """Dump a list of Gaussians to a '|'-separated text file.

    ps:        particles decorated as both IMP.core.Gaussian and
               IMP.atom.Mass
    out_fn:    path of the text file to (over)write
    comments:  strings written first, one per line, each prefixed by '# '

    After the comments a column header line is written, followed by one
    line per particle holding its index in ps, its mass, the three
    coordinates of the Gaussian center and the nine entries of the
    covariance matrix in row order.
    """
    print('will write GMM text to',out_fn)
    line_template = '|{}|{}|{} {} {}|{} {} {} {} {} {} {} {} {}|\n'
    with open(out_fn, 'w') as outf:
        for text in comments:
            outf.write(''.join(('# ', text, '\n')))
        outf.write('#|num|weight|mean|covariance matrix|\n')
        for index, particle in enumerate(ps):
            gaussian = IMP.core.Gaussian(particle).get_gaussian()
            values = [index, IMP.atom.Mass(particle).get_mass()]
            # Flatten the covariance matrix row by row.
            flat_covar = []
            for row in IMP.algebra.get_covariance(gaussian):
                flat_covar.extend(row)
            values.extend(gaussian.get_center())
            values.extend(flat_covar)
            outf.write(line_template.format(*values))

def gmm2map(to_draw,voxel_size,bounding_box=None,origin=None, fast=False, factor=2.5):
    """Rasterize a GMM onto a grid and return it as a density map.

    to_draw:       non-empty list of particles (IMP.Particle or Hierarchy)
                   or of IMP.core.Gaussian decorators; only the type of the
                   first element is inspected
    voxel_size:    voxel size handed to the rasterizer
    bounding_box:  box to rasterize into. When None, a box is built around
                   a sphere derived from the input (see below)
    origin:        if given, set as the origin of the resulting map
    fast:          use IMP.algebra.get_rasterized_fast instead of
                   IMP.algebra.get_rasterized
    factor:        extra argument forwarded to get_rasterized_fast only

    Returns the density map, or None (after printing a message) when the
    first element of to_draw is of an unsupported type.
    """
    first_type = type(to_draw[0])
    if first_type in (IMP.Particle, IMP.atom.Hierarchy, IMP.core.Hierarchy):
        ps = to_draw
    elif first_type == IMP.core.Gaussian:
        ps = [gaussian.get_particle() for gaussian in to_draw]
    else:
        print('ps must be Particles or Gaussians')
        return

    if bounding_box is None:
        if len(ps) > 1:
            # Several components: triple the sphere enclosing all centers.
            coords = [IMP.core.XYZ(p).get_coordinates() for p in ps]
            enclosing = IMP.algebra.get_enclosing_sphere(coords)
            padded = IMP.algebra.Sphere3D(enclosing.get_center(),
                                          enclosing.get_radius() * 3)
        else:
            # Single component: radius is three times the largest variance
            # (the variance itself, not its square root).
            single = IMP.core.Gaussian(ps[0]).get_gaussian()
            padded = IMP.algebra.Sphere3D(single.get_center(),
                                          max(single.get_variances()) * 3)
        bounding_box = IMP.algebra.get_bounding_box(padded)

    # Collect the Gaussian shape and the mass (used as weight) per particle.
    components = [(IMP.core.Gaussian(p).get_gaussian(),
                   IMP.atom.Mass(p).get_mass()) for p in ps]
    shapes = [shape for shape, _ in components]
    weights = [mass for _, mass in components]

    print('rasterizing')
    if not fast:
        grid = IMP.algebra.get_rasterized(shapes, weights, voxel_size,
                                          bounding_box)
    else:
        grid = IMP.algebra.get_rasterized_fast(shapes, weights, voxel_size,
                                               bounding_box, factor)
    print('creating map')
    density = IMP.em.create_density_map(grid)
    if origin is not None:
        density.set_origin(origin)
    return density
def write_gmm_to_map(to_draw,out_fn,voxel_size,bounding_box=None,origin=None, fast=False, factor=2.5):
    """Write a density map (MRC) rasterized from a GMM.

    to_draw:       list of particles or of IMP.core.Gaussian decorators
    out_fn:        path of the MRC file to write
    voxel_size:    voxel size of the rasterized map
    bounding_box:  optional box to rasterize into (see gmm2map)
    origin:        optional origin to set on the map
    fast:          use the fast rasterizer
    factor:        forwarded to the fast rasterizer; ignored when fast is
                   False
    """
    # factor must be forwarded explicitly, otherwise gmm2map silently falls
    # back to its own default and the caller's value has no effect.
    d1 = gmm2map(to_draw, voxel_size, bounding_box, origin, fast, factor)
    print('will write GMM map to',out_fn)
    IMP.em.write_map(d1, out_fn, IMP.em.MRCReaderWriter())

def write_sklearn_gmm_to_map(gmm,out_fn,apix=0,bbox=None,dmap_model=None):
    """Write a density map directly from a fitted sklearn GMM (kinda slow).

    gmm:         fitted scikit-learn mixture model
    out_fn:      path of the MRC file to write
    apix:        voxel size used, together with bbox, to create the map
                 when dmap_model is None
    bbox:        bounding box used to create the map when dmap_model is None
    dmap_model:  if given, the new map is created from this one instead of
                 from bbox/apix

    Every voxel is set to exp(log-likelihood) of the voxel location under
    the mixture model.
    """
    # create density
    if dmap_model is not None:
        d1 = IMP.em.create_density_map(dmap_model)
    else:
        d1 = IMP.em.create_density_map(bbox, apix)

    # fill it with values from the GMM
    print('getting coords')
    nvox = d1.get_number_of_voxels()
    apos = [list(d1.get_location_by_voxel(nv)) for nv in range(nvox)]

    print('scoring')
    scores = gmm.score(apos)
    if np.ndim(scores) == 0:
        # Newer scikit-learn estimators return a single averaged
        # log-likelihood from score(); the per-sample values needed here
        # come from score_samples() instead.
        scores = gmm.score_samples(apos)

    print('assigning')
    for nv, score in enumerate(scores):
        d1.set_value(nv, exp(score))
    print('will write GMM map to',out_fn)
    IMP.em.write_map(d1, out_fn, IMP.em.MRCReaderWriter())

def draw_points(pts,out_fn,trans=IMP.algebra.get_identity_transformation_3d(),
                                use_colors=False):
    """Write points to a file in Chimera 'bild' format.

    pts:         sequence of 3D coordinates
    out_fn:      path of the bild file to (over)write
    trans:       transformation applied to every point before writing
    use_colors:  if True the first point is drawn in red and the remaining
                 points are colored in pairs, cycling through a small
                 palette. This only really makes sense for ellipses;
                 otherwise it'll look weird.

    An empty pts produces an empty file.
    """
    colors = ['0 1 0', '0 0 1', '0 1 1']
    with open(out_fn, 'w') as outf:
        start = 0
        # Only touch pts[0] when it exists and is actually drawn separately.
        if use_colors and len(pts) > 0:
            # write first point in red
            pt = trans.get_transformed(IMP.algebra.Vector3D(pts[0]))
            outf.write('.color 1 0 0\n.dotat %.2f %.2f %.2f\n'
                       % (pt[0], pt[1], pt[2]))
            # remaining points start out green
            outf.write('.color 0 1 0\n')
            start = 1

        for nt, t in enumerate(pts[start:]):
            if use_colors and nt % 2 == 0:
                # Integer division keeps the index an int on Python 3; the
                # modulo wraps around instead of running off the palette.
                outf.write('.color %s\n' % colors[(nt // 2) % len(colors)])
            pt = trans.get_transformed(IMP.algebra.Vector3D(t))
            outf.write('.dotat %.2f %.2f %.2f\n' % (pt[0], pt[1], pt[2]))



def create_gmm_for_bead(mdl,
                        particle,
                        n_components,
                        sampled_points=100000,
                        num_iter=100):
    """Fit a GMM with full covariances to the density of a single bead.

    mdl:             IMP Model in which the Gaussian particles are created
    particle:        the bead; it must carry a mass
    n_components:    number of Gaussians to fit
    sampled_points:  number of points sampled from the simulated density
    num_iter:        number of EM iterations

    Side effects: writes the simulated density to 'test_intermed.mrc' and
    the sampled points to 'pts.bild' (both relative paths).

    Returns a tuple (list of new Gaussian particles, simulated density map).
    """
    print('fitting bead with',n_components,'gaussians')

    # Simulate a density map for the bead alone.
    mass_key = IMP.atom.Mass.get_mass_key()
    dmap = IMP.em.SampledDensityMap([particle], 1.0, 1.0,
                                    mass_key, 3, IMP.em.SPHERE)
    IMP.em.write_map(dmap, 'test_intermed.mrc')

    # Draw points from that density and keep a copy for inspection.
    pts = IMP.isd.sample_points_from_density(dmap, sampled_points)
    draw_points(pts, 'pts.bild')

    # Fit the mixture; the component weights are scaled by the bead mass.
    density_particles = []
    bead_mass = IMP.atom.Mass(particle).get_mass()
    fit_gmm_to_points(points=pts,
                      n_components=n_components,
                      mdl=mdl,
                      ps=density_particles,
                      num_iter=num_iter,
                      covariance_type='full',
                      mass_multiplier=bead_mass)
    return density_particles, dmap

def sample_and_fit_to_particles(model,
                                fragment_particles,
                                num_components,
                                sampled_points=1000000,
                                simulation_res=0.5,
                                voxel_size=1.0,
                                num_iter=100,
                                covariance_type='full',
                                multiply_by_total_mass=True,
                                output_map=None,
                                output_txt=None):
    """Simulate a density from particles and fit a GMM to it.

    model:                   IMP Model in which the Gaussians are created
    fragment_particles:      particles (with mass) to simulate density from
    num_components:          number of Gaussians to fit
    sampled_points:          number of points sampled from the density
    simulation_res:          resolution of the simulated density map
    voxel_size:              voxel size of the simulated (and output) map
    num_iter:                number of EM iterations
    covariance_type:         covariance type passed to fit_gmm_to_points
    multiply_by_total_mass:  if True, scale the component weights by the
                             summed mass of the distinct input particles;
                             otherwise leave the weights unscaled
    output_map:              if given, write the fitted GMM as a density
                             map to this path
    output_txt:              if given, write the fitted GMM as text to this
                             path

    Returns the list of newly created Gaussian particles.
    """
    density_particles = []
    # Bind a neutral multiplier up front so the fit call below works even
    # when multiply_by_total_mass is False.
    mass_multiplier = 1.0
    if multiply_by_total_mass:
        # set() ensures a particle listed twice is only counted once
        mass_multiplier = sum(IMP.atom.Mass(p).get_mass()
                              for p in set(fragment_particles))
        print('add_component_density: will multiply by mass',mass_multiplier)

    # simulate density from ps, then calculate points to fit
    print('add_component_density: sampling points')
    dmap = IMP.em.SampledDensityMap(fragment_particles, simulation_res,
                                    voxel_size,
                                    IMP.atom.Mass.get_mass_key(), 3)
    dmap.calcRMS()
    pts = IMP.isd.sample_points_from_density(dmap, sampled_points)

    # fit GMM
    print('add_component_density: fitting GMM to',len(pts),'points')
    fit_gmm_to_points(points=pts,
                      n_components=num_components,
                      mdl=model,
                      ps=density_particles,
                      num_iter=num_iter,
                      covariance_type=covariance_type,
                      mass_multiplier=mass_multiplier)

    if output_txt is not None:
        write_gmm_to_text(density_particles, output_txt)
    if output_map is not None:
        write_gmm_to_map(to_draw=density_particles,
                         out_fn=output_map,
                         voxel_size=voxel_size,
                         bounding_box=IMP.em.get_bounding_box(dmap))

    return density_particles

def _gmm_initial_precisions(force_radii, n_components, covariance_type):
    """Build a precisions_init array shaped for the given covariance type."""
    precision = 1. / force_radii
    if covariance_type == 'spherical':
        # one scalar precision per component
        return np.array([precision] * n_components)
    if covariance_type == 'tied':
        # a single matrix shared by all components
        return np.eye(3) * precision
    if covariance_type == 'full':
        # one 3x3 matrix per component
        return np.array([np.eye(3) * precision for _ in range(n_components)])
    # 'diag': one precision per coordinate and component
    return np.array([[precision] * 3 for _ in range(n_components)])


def _gmm_component_covariance(covars, ng, covariance_type):
    """Return the covariance of component ng as a nested 3x3 list."""
    # A tied model stores one matrix for all components.
    covar = np.asarray(covars if covariance_type == 'tied' else covars[ng],
                       dtype=float)
    if covar.ndim == 0:
        # a single variance shared by all three axes
        covar = np.eye(3) * float(covar)
    elif covar.ndim == 1:
        # per-axis variances
        covar = np.diag(covar)
    return covar.tolist()


def fit_gmm_to_points(points,
                      n_components,
                      mdl,
                      ps=None,
                      num_iter=100,
                      covariance_type='full',
                      min_covar=0.001,
                      init_centers=None,
                      force_radii=-1.0,
                      force_weight=-1.0,
                      mass_multiplier=1.0):
    """fit a GMM to some points. Will return the score and the Akaike score.
    Akaike information criterion for the current model fit. It is a measure
    of the relative quality of the GMM that takes into account the
    parsimony and the goodness of the fit.
    if no particles are provided, they will be created

    points:            list of coordinates (python)
    n_components:      number of gaussians to create
    mdl:               IMP Model
    ps:                list of particles to be decorated as Gaussian, Mass
                       and XYZR. Missing particles are created in mdl and
                       appended to this list in place. If None, a fresh
                       list is used for this call only
    num_iter:          number of EM iterations
    covariance_type:   covar type for the gaussians, as understood by
                       sklearn (e.g. 'full', 'diag', 'spherical')
    min_covar:         assign a minimum value to covariance term. That is
                       used to have more spherical shaped gaussians. Only
                       passed on to the older sklearn GMM class
    init_centers:      initial coordinates of the GMM (None or empty to let
                       sklearn choose)
    force_radii:       fix the radii (spheres only); -1.0 disables. With the
                       newer sklearn GaussianMixture this only sets the
                       initial precisions
    force_weight:      fix the weights; -1.0 disables. With the newer
                       sklearn GaussianMixture this only sets the initial
                       weights
    mass_multiplier:   multiply the weights of all the gaussians by this value

    Returns the tuple (score, akaikescore).
    """
    # A fresh list per call: a mutable default would keep the particles of
    # earlier calls and try to set them up a second time.
    if ps is None:
        ps = []
    # Explicit length test; comparing a numpy array to [] is ambiguous.
    has_init_centers = init_centers is not None and len(init_centers) > 0

    new_sklearn = False
    try:
        from sklearn.mixture import GMM
    except ImportError:
        from sklearn.mixture import GaussianMixture
        new_sklearn = True

    print('creating GMM with n_components',n_components,'n_iter',num_iter,'covar type',covariance_type)
    if new_sklearn:
        # aic() calls size() on points, so it needs to be
        # a numpy array, not a list
        points = np.array(points)
        weights_init = precisions_init = None
        if force_radii != -1.0:
            print('warning: radii can no longer be forced, but setting '
                  'initial values to ', force_radii)
            precisions_init = _gmm_initial_precisions(force_radii,
                                                      n_components,
                                                      covariance_type)
        if force_weight != -1.0:
            print('warning: weights can no longer be forced, but setting '
                  'initial values to ', force_weight)
            weights_init = np.array([force_weight]*n_components)

        gmm = GaussianMixture(n_components=n_components,
                              max_iter=num_iter,
                              covariance_type=covariance_type,
                              weights_init=weights_init,
                              precisions_init=precisions_init,
                              means_init=init_centers if has_init_centers
                                                      else None)
    else:
        # params/init_params list what the old GMM class may update and
        # initialize: 'm'eans, 'c'ovariances, 'w'eights.
        params = 'm'
        init_params = 'm'
        if force_radii == -1.0:
            params += 'c'
            init_params += 'c'
        else:
            covariance_type = 'spherical'
            print('forcing spherical with radii',force_radii)

        if force_weight == -1.0:
            params += 'w'
            init_params += 'w'
        else:
            print('forcing weights to be',force_weight)

        gmm = GMM(n_components=n_components, n_iter=num_iter,
                  covariance_type=covariance_type, min_covar=min_covar,
                  params=params, init_params=init_params)
        if force_weight != -1.0:
            gmm.weights_ = np.array([force_weight]*n_components)
        if force_radii != -1.0:
            gmm.covars_ = np.array([[force_radii]*3 for i in range(n_components)])
        if has_init_centers:
            gmm.means_ = init_centers
    print('fitting')
    model = gmm.fit(points)
    score = gmm.score(points)
    akaikescore = model.aic(points)

    # convert format to core::Gaussian
    if new_sklearn:
        covars = gmm.covariances_
    else:
        covars = gmm.covars_
    for ng in range(n_components):
        covar = _gmm_component_covariance(covars, ng, covariance_type)
        center = list(gmm.means_[ng])
        weight = mass_multiplier*gmm.weights_[ng]
        if ng >= len(ps):
            ps.append(IMP.Particle(mdl))
        shape = IMP.algebra.get_gaussian_from_covariance(
            covar, IMP.algebra.Vector3D(center))
        g = IMP.core.Gaussian.setup_particle(ps[ng], shape)
        IMP.atom.Mass.setup_particle(ps[ng], weight)
        # radius is one standard deviation along the widest axis
        IMP.core.XYZR.setup_particle(ps[ng], sqrt(max(g.get_variances())))

    return (score,akaikescore)

def fit_dirichlet_gmm_to_points(points,
                      n_components,
                      mdl,
                      ps=None,
                      num_iter=100,
                      covariance_type='full',
                      mass_multiplier=1.0):
    """fit a GMM with a Dirichlet process prior to some points.
    if no particles are provided, they will be created

    points:            list of coordinates (python)
    n_components:      number of gaussians to create
    mdl:               IMP Model
    ps:                list of particles to be decorated as Gaussian, Mass
                       and XYZR. Missing particles are created in mdl and
                       appended to this list in place. If None, a fresh
                       list is used for this call only
    num_iter:          number of EM iterations
    covariance_type:   covar type for the gaussians, as understood by
                       sklearn (e.g. 'full', 'diag', 'spherical')
    mass_multiplier:   multiply the weights of all the gaussians by this value

    Returns the list of decorated particles (ps itself when it was given).
    """
    # A fresh list per call: a mutable default would keep the particles of
    # earlier calls and try to set them up a second time.
    if ps is None:
        ps = []

    new_sklearn = True
    try:
        from sklearn.mixture import BayesianGaussianMixture
    except ImportError:
        from sklearn.mixture import DPGMM
        new_sklearn = False

    # create and fit GMM
    print('using dirichlet prior')
    if new_sklearn:
        gmm = BayesianGaussianMixture(
                weight_concentration_prior_type='dirichlet_process',
                n_components=n_components, max_iter=num_iter,
                covariance_type=covariance_type)
    else:
        gmm = DPGMM(n_components=n_components, n_iter=num_iter,
                    covariance_type=covariance_type)

    gmm.fit(points)

    # convert format to core::Gaussian
    # The older class exposes the precisions under a different name.
    precisions = gmm.precisions_ if new_sklearn else gmm.precs_
    for ng in range(n_components):
        # A tied model stores one matrix for all components.
        invcovar = np.asarray(
            precisions if covariance_type == 'tied' else precisions[ng],
            dtype=float)
        if invcovar.ndim == 2:
            covar = np.linalg.inv(invcovar)
        elif invcovar.ndim == 1:
            # per-axis precisions: invert element-wise, matrix inversion
            # is not defined for a 1-D array
            covar = np.diag(1.0 / invcovar)
        else:
            # a single precision shared by all three axes
            covar = np.eye(3) / float(invcovar)
        covar = covar.tolist()
        center = list(gmm.means_[ng])
        weight = mass_multiplier*gmm.weights_[ng]
        if ng >= len(ps):
            ps.append(IMP.Particle(mdl))
        shape = IMP.algebra.get_gaussian_from_covariance(
            covar, IMP.algebra.Vector3D(center))
        g = IMP.core.Gaussian.setup_particle(ps[ng], shape)
        IMP.atom.Mass.setup_particle(ps[ng], weight)
        # radius is one standard deviation along the widest axis
        IMP.core.XYZR.setup_particle(ps[ng], sqrt(max(g.get_variances())))

    return ps
