from __future__ import print_function
from qelimutils import *
from itertools import ifilter


class TermValue(object):
    """Elimination value meaning "the variable equals exactly ``term``".

    ``term`` is stored as given; this class never inspects it.
    """

    def __init__(self, term):
        self.term = term

    def __repr__(self):
        # Readable representation for debugging / logging.
        return '%s(%r)' % (self.__class__.__name__, self.term)

class InfValue(object):
    """Marker elimination value standing for "plus infinity".

    It carries no data; consumers recognise it with ``isinstance``.
    """

    def __repr__(self):
        # Readable representation for debugging / logging.
        return '%s()' % (self.__class__.__name__,)

class TermEpsilonValue(object):
    """Elimination value meaning "``term`` minus an infinitesimal epsilon".

    ``term`` is stored as given; this class never inspects it.
    """

    def __init__(self, term):
        self.term = term

    def __repr__(self):
        # Readable representation for debugging / logging.
        return '%s(%r)' % (self.__class__.__name__, self.term)


class Bound(object):
    """A bound on a variable, extracted from one atom of a formula.

    Attributes:
        is_eq: True when the bound represents ``var = bound``; False when
            it represents the strict upper bound ``var < bound``.
        bound: the bounding term.
        atom: the atom the bound was derived from.
    """

    def __init__(self, is_eq, bound, atom):
        self.is_eq = is_eq
        self.bound = bound
        self.atom = atom

    def __repr__(self):
        # Readable representation for debugging / logging.
        return '%s(is_eq=%r, bound=%r, atom=%r)' % (
            self.__class__.__name__, self.is_eq, self.bound, self.atom)


def collect_bounds(formula, var):
    """Lazily yield the upper bounds on ``var`` found in ``formula``.

    Each atom of ``formula`` is turned into a linear form.  When ``var``
    has a positive coefficient, the form is rebuilt with the second
    ``to_linear`` argument set to True (presumably the negated atom --
    confirm against qelimutils) and the coefficient is read again.  Atoms
    where the resulting coefficient is not negative contribute nothing.

    For each remaining atom a strict bound ``Bound(False, t, atom)`` is
    yielded, followed by an equality bound ``Bound(True, t, atom)`` when
    the linear form is not strict.
    """
    for atom in get_atoms(formula):
        lin = to_linear(atom, False)
        k = coeff(lin, var)
        if k > 0:
            lin = to_linear(atom, True)
            k = coeff(lin, var)
        if not k < 0:
            # var does not occur with a negative coefficient: no bound here.
            continue

        # Normalise the coefficient of var to magnitude 1, then drop var so
        # that what is left is the bounding term.
        lin = scale(lin, abs(1/k))
        del lin.vars[var]
        limit = to_term(lin, True)

        # (var < limit): limit is an upper bound.
        yield Bound(False, limit, atom)
        if not lin.strict:
            # (var <= limit) is split into (var < limit) | (var = limit).
            yield Bound(True, limit, atom)


def find_bound_value(bounds, var, model):
    """Choose the elimination value for ``var`` that agrees with ``model``.

    Returns:
        * ``TermValue(t)`` for the first equality bound whose model value
          equals the model value of ``var``;
        * otherwise ``TermEpsilonValue(t)`` for the strict bound with the
          smallest model value that is still greater than the model value
          of ``var`` (the first such bound wins on ties);
        * otherwise ``InfValue()``.
    """
    target = model[var].constant_value()
    lowest = None  # (term, model value) of the tightest strict bound so far
    for candidate in bounds:
        candidate_val = model[candidate.bound].constant_value()
        if candidate.is_eq:
            if target == candidate_val:
                return TermValue(candidate.bound)
            continue
        if target < candidate_val and (
                lowest is None or candidate_val < lowest[1]):
            lowest = (candidate.bound, candidate_val)

    if lowest is None:
        return InfValue()
    return TermEpsilonValue(lowest[0])


def subst_value(formula, var, value):
    """Substitute the elimination ``value`` for ``var`` in ``formula``.

    A ``TermValue`` is substituted directly for ``var``.  Any other value
    is handled atom by atom: every atom in which ``var`` has a non-zero
    coefficient is replaced by a formula that no longer mentions ``var``
    (a constant for ``InfValue``, a comparison against ``value.term``
    otherwise).  The substituted formula is simplified before being
    returned.
    """
    if isinstance(value, TermValue):
        return formula.substitute({var: value.term}).simplify()

    at_infinity = isinstance(value, InfValue)
    replacements = {}
    for atom in get_atoms(formula):
        lin = to_linear(atom, False)
        k = coeff(lin, var)
        if k != 0:
            # Normalise the coefficient of var to magnitude 1 and remove
            # var; `rest` is the remainder of the linear form.
            lin = scale(lin, abs(1/k))
            del lin.vars[var]
            rest = to_term(lin, True)
            if k > 0:
                # 0 <= var + rest  -->  var >= -rest
                #   +oo: (+oo >= -rest) always holds
                #   eps: (value - epsilon >= -rest) --> (value > -rest)
                new_atom = (TRUE() if at_infinity
                            else LT(Times(Real(-1), rest), value.term))
            else:
                # 0 <= -var + rest  -->  var <= rest
                #   +oo: (+oo <= rest) never holds
                #   eps: (value - epsilon <= rest) --> (value <= rest)
                new_atom = (FALSE() if at_infinity
                            else LE(value.term, rest))
            replacements[atom] = new_atom
    return formula.substitute(replacements).simplify()


def vts_model(formula, to_elim, model):
    """Eliminate each variable of ``to_elim`` from ``formula`` using ``model``.

    Variables are processed in iteration order.  For each one, the bounds
    are collected from the current formula, a value consistent with
    ``model`` is picked, and that value is substituted in.  The formula
    obtained after the last substitution is returned.
    """
    result = formula
    for var in to_elim:
        # Materialise the bounds before the model is consulted.
        candidates = list(collect_bounds(result, var))
        result = subst_value(
            result, var, find_bound_value(candidates, var, model))
    return result


def vts(formula, to_elim):
    """Eliminate the variables ``to_elim`` from ``formula``.

    The formula is put through ``nnf`` and asserted in a solver.  While
    the solver finds a model, ``vts_model`` builds one disjunct from that
    model and the negation of the disjunct is asserted before solving
    again.  The simplified disjunction of all disjuncts is returned.
    """
    normalized = nnf(formula)
    disjuncts = []
    with Solver() as solver:
        solver.add_assertion(normalized)
        while True:
            if not solver.solve():
                break
            projection = vts_model(normalized, to_elim, solver.get_model())
            disjuncts.append(projection)
            # Exclude this disjunct so the next model, if any, lies
            # outside it.
            solver.add_assertion(Not(projection))
    return Or(*disjuncts).simplify()


