"""Filter AST.

Note that it is easily possible to construct an AST that is inconsistent with
a given schema. Furthermore, it is possible to construct a semantically invalid
AST that cannot be parsed correctly or that includes contradicting statements.
The AST nodes do not (and cannot) check such issues.

For example, consider the following AST:

>>> Any(ns.bse.collection,
...     And(
...         Equals('hello'),
...         Is('hello world'),
...         Any(ns.bse.tag, Equals('world')),
...         Any(ns.bst.label, Equals('world')),
...         All(ns.bst.label, Not(Equals('world'))),
...     )
... )

This AST has multiple issues that are not verified upon its creation:
* A condition on a non-literal.
* A Filter on a literal.
* Conditions exclude each other.
* The predicates along the branch have incompatible domains and ranges.

"""
# imports
from collections import abc
import typing

# bsfs imports
from bsfs.utils import URI, typename, normalize_args

# exports
# Public expression classes of this module. The factory helpers defined at the
# end of the module (IsIn, IsNotIn, Between, Includes, Excludes) are not listed.
__all__ : typing.Sequence[str] = (
    # base classes
    'FilterExpression',
    'PredicateExpression',
    # predicate expressions
    'OneOf',
    'Predicate',
    # branching
    'All',
    'Any',
    # aggregators
    'And',
    'Or',
    # value matchers
    'Equals',
    'Substring',
    'EndsWith',
    'StartsWith',
    # range matchers
    'GreaterThan',
    'LessThan',
    # distance matchers
    'Distance',
    # misc
    'Has',
    'Is',
    'Not',
    )


## code ##

# pylint: disable=too-few-public-methods # Many expressions use mostly magic methods

class _Expression(abc.Hashable):
    """Common root of every AST node.

    The base implementation carries no fields: an expression equals *other*
    whenever *other* is an instance of the expression's own type, and it
    hashes by its type. Subclasses extend both with their own members.
    """

    def __repr__(self) -> str:
        """Return the expression's string representation."""
        name = typename(self)
        return f'{name}()'

    def __hash__(self) -> int:
        """Return the expression's integer representation."""
        node_type = type(self)
        return hash(node_type)

    def __eq__(self, other: typing.Any) -> bool:
        """Return True if *self* and *other* are equivalent."""
        own_type = type(self)
        return isinstance(other, own_type)


class FilterExpression(_Expression):
    """Base class of all filter expressions (branches, aggregators, matchers)."""


class PredicateExpression(_Expression):
    """Base class of all expressions that describe predicates to follow."""


class _Branch(FilterExpression):
    """Follow a predicate and evaluate a child filter on what it leads to.

    *predicate* may be given as a bare `URI`, which is wrapped into a
    (forward) `Predicate`. Raises TypeError if *predicate* is neither a URI
    nor a PredicateExpression, or if *expr* is not a FilterExpression.
    """

    # predicate to follow.
    predicate: PredicateExpression

    # child expression to evaluate.
    expr: FilterExpression

    def __init__(
            self,
            predicate: typing.Union[PredicateExpression, URI],
            expr: FilterExpression,
            ):
        # a bare URI is shorthand for a single, non-reversed predicate
        if isinstance(predicate, URI):
            predicate = Predicate(predicate)
        if not isinstance(predicate, PredicateExpression):
            raise TypeError(predicate)
        if not isinstance(expr, FilterExpression):
            raise TypeError(expr)
        self.predicate, self.expr = predicate, expr

    def __repr__(self) -> str:
        return '{}({}, {})'.format(typename(self), self.predicate, self.expr)

    def __hash__(self) -> int:
        key = (super().__hash__(), self.predicate, self.expr)
        return hash(key)

    def __eq__(self, other) -> bool:
        if not super().__eq__(other):
            return False
        return self.predicate == other.predicate and self.expr == other.expr


class Any(_Branch):
    """Match if at least one triple along the predicate satisfies the child expression."""


class All(_Branch):
    """Match if there is at least one triple along the predicate and all of them satisfy the child expression."""


class _Agg(FilterExpression, abc.Collection):
    """Combine several filter expressions into one.

    Arguments are unfolded with `normalize_args` (per the signature, children
    may be passed individually or as iterables/iterators) and stored in a set,
    so duplicate children collapse. Raises TypeError if any child is not a
    FilterExpression.
    """

    # child expressions
    expr: typing.Set[FilterExpression]

    def __init__(
            self,
            *expr: typing.Union[FilterExpression,
                                typing.Iterable[FilterExpression],
                                typing.Iterator[FilterExpression]]
            ):
        children = set(normalize_args(*expr))
        for child in children:
            if not isinstance(child, FilterExpression):
                raise TypeError(expr)
        # FIXME: Require at least one child expression?
        self.expr = children

    def __contains__(self, expr: typing.Any) -> bool:
        """Return True if *expr* is one of the children."""
        children = self.expr
        return expr in children

    def __iter__(self) -> typing.Iterator[FilterExpression]:
        """Iterate over the children (in set order)."""
        children = self.expr
        return iter(children)

    def __len__(self) -> int:
        """Return how many (distinct) children there are."""
        children = self.expr
        return len(children)

    def __repr__(self) -> str:
        return '{}({})'.format(typename(self), self.expr)

    def __hash__(self) -> int:
        # sets are unordered; sort by repr to get an order-independent key
        ordered = tuple(sorted(self.expr, key=repr))
        return hash((super().__hash__(), ordered))

    def __eq__(self, other) -> bool:
        if super().__eq__(other):
            return self.expr == other.expr
        return False


class And(_Agg):
    """Match if every child expression matches."""


class Or(_Agg):
    """Match if at least one child expression matches."""


class Not(FilterExpression):
    """Negate a child filter expression.

    Raises TypeError if *expr* is not a FilterExpression.
    """

    # expression to negate
    expr: FilterExpression

    def __init__(self, expr: FilterExpression):
        if isinstance(expr, FilterExpression):
            self.expr = expr
        else:
            raise TypeError(expr)

    def __repr__(self) -> str:
        return '{}({})'.format(typename(self), self.expr)

    def __hash__(self) -> int:
        parent = super().__hash__()
        return hash((parent, self.expr))

    def __eq__(self, other: typing.Any) -> bool:
        if not super().__eq__(other):
            return False
        return self.expr == other.expr


class Has(FilterExpression):
    """Match if the predicate occurs a given number of times.

    *predicate* may be a bare URI, which is wrapped into a `Predicate`.
    *count* defaults to ``GreaterThan(1, strict=False)`` (at least once); an
    int *n* is shorthand for ``Equals(n)``; anything else must already be a
    FilterExpression on the count. Raises TypeError on invalid arguments.
    """

    # predicate to follow.
    predicate: PredicateExpression

    # condition on the number of occurrences
    count: FilterExpression

    def __init__(
            self,
            predicate: typing.Union[PredicateExpression, URI],
            count: typing.Optional[typing.Union[FilterExpression, int]] = None,
            ):
        if isinstance(predicate, URI):
            predicate = Predicate(predicate)
        if not isinstance(predicate, PredicateExpression):
            raise TypeError(predicate)
        # translate the shorthands into count expressions
        if count is None:
            count = GreaterThan(1, strict=False)
        elif isinstance(count, int):
            count = Equals(count)
        if not isinstance(count, FilterExpression):
            raise TypeError(count)
        self.predicate = predicate
        self.count = count

    def __repr__(self) -> str:
        return '{}({}, {})'.format(typename(self), self.predicate, self.count)

    def __hash__(self) -> int:
        key = (super().__hash__(), self.predicate, self.count)
        return hash(key)

    def __eq__(self, other) -> bool:
        if not super().__eq__(other):
            return False
        return self.predicate == other.predicate and self.count == other.count


class _Value(FilterExpression):
    """Base class of matchers that compare against a single target value.

    The value is stored as given; no type checks are applied. Hashing
    requires the value to be hashable.
    """

    # target value.
    value: typing.Any

    def __init__(self, value: typing.Any):
        self.value = value

    def __repr__(self) -> str:
        return '{}({})'.format(typename(self), self.value)

    def __hash__(self) -> int:
        parent = super().__hash__()
        return hash((parent, self.value))

    def __eq__(self, other) -> bool:
        if not super().__eq__(other):
            return False
        return self.value == other.value


class Is(_Value):
    """Match a node by its URI (*value*)."""


class Equals(_Value):
    """Literal equals *value* exactly.
    NOTE: *value* must correspond to the literal's type.
    """


class Substring(_Value):
    """Value matches a substring (approximate match).
    NOTE: *value* must be a string.
    """


class StartsWith(_Value):
    """Value begins with the given string."""


class EndsWith(_Value):
    """Value ends with the given string."""


class Distance(FilterExpression):
    """Distance to a reference is (strictly) below a threshold. Assumes a Feature literal.

    *threshold* is coerced to float and *strict* to bool. The reference is
    stored as given. For hashing and equality it is normalized by
    `_reference_key`: iterable references (lists, tuples, 1-d arrays) become
    tuples. As a result they compare element-wise, and equal expressions
    always hash equally.
    """

    # FIXME:
    # (a) pass a node/predicate as anchor instead of a value.
    #     Then we don't need to materialize the reference.
    # (b) pass a FilterExpression (_Bounded) instead of a threshold.
    #     Then, we could also query values greater than a threshold.

    # reference value (presumably a feature vector -- TODO confirm accepted types).
    reference: typing.Any

    # distance threshold.
    threshold: float

    # True: the distance must be strictly below the threshold (exclusive bound);
    # False: the threshold itself is included (inclusive bound).
    strict: bool

    def __init__(
            self,
            reference: typing.Any,
            threshold: float,
            strict: bool = False,
            ):
        self.reference = reference
        self.threshold = float(threshold)
        self.strict = bool(strict)

    @staticmethod
    def _reference_key(reference: typing.Any) -> typing.Any:
        """Return *reference* in a normalized form for hashing and comparison.

        Iterables are converted to a tuple. Comparing array-like references
        directly may yield an array instead of a bool, and it would also
        disagree with hashing (e.g. list vs. tuple). Non-iterable references
        are returned unchanged.
        """
        try:
            return tuple(reference)
        except TypeError:
            # not iterable (e.g. a scalar): use the reference as-is
            return reference

    def __repr__(self) -> str:
        return f'{typename(self)}({self.reference}, {self.threshold}, {self.strict})'

    def __hash__(self) -> int:
        return hash((super().__hash__(), self._reference_key(self.reference),
                     self.threshold, self.strict))

    def __eq__(self, other) -> bool:
        # compare the normalized references so that __eq__ agrees with __hash__
        return super().__eq__(other) \
           and self._reference_key(self.reference) == self._reference_key(other.reference) \
           and self.threshold == other.threshold \
           and self.strict == other.strict


class _Bounded(FilterExpression):
    """Value is bounded by a threshold. Assumes a Number literal.

    *threshold* is coerced to float and *strict* to bool. A strict bound
    (the default) excludes the threshold itself; with ``strict=False`` the
    threshold is included.
    """

    # bound.
    threshold: float

    # True: exclusive (strict) bound; False: inclusive bound.
    strict: bool

    def __init__(self, threshold: float, strict: bool = True):
        self.threshold = float(threshold)
        self.strict = bool(strict)

    def __repr__(self) -> str:
        return '{}({}, {})'.format(typename(self), self.threshold, self.strict)

    def __hash__(self) -> int:
        parent = super().__hash__()
        return hash((parent, self.threshold, self.strict))

    def __eq__(self, other) -> bool:
        if not super().__eq__(other):
            return False
        return self.threshold == other.threshold and self.strict == other.strict


class LessThan(_Bounded):
    """Value is smaller than the threshold (strictly, unless strict=False). Assumes a Number literal."""


class GreaterThan(_Bounded):
    """Value is larger than the threshold (strictly, unless strict=False). Assumes a Number literal."""


class Predicate(PredicateExpression):
    """A single predicate, optionally followed in reverse direction.

    Raises TypeError if *predicate* is not a URI. *reverse* is coerced to
    bool, so ``None`` counts as False.
    """

    # predicate URI
    predicate: URI

    # reverse the predicate's direction
    reverse: bool

    def __init__(
            self,
            predicate: URI,
            reverse: typing.Optional[bool] = False,
            ):
        if isinstance(predicate, URI):
            self.predicate = predicate
            self.reverse = bool(reverse)
        else:
            raise TypeError(predicate)

    def __repr__(self) -> str:
        return '{}({}, {})'.format(typename(self), self.predicate, self.reverse)

    def __hash__(self) -> int:
        key = (super().__hash__(), self.predicate, self.reverse)
        return hash(key)

    def __eq__(self, other) -> bool:
        if not super().__eq__(other):
            return False
        return self.predicate == other.predicate and self.reverse == other.reverse


class OneOf(PredicateExpression, abc.Collection):
    """A set of predicate alternatives.

    The predicates' domains must be ascendants or descendants of each other.
    The overall domain is the most specific one.

    The predicates' ranges must be ascendants or descendants of each other.
    The overall range is the most generic one.

    Bare URIs are wrapped into `Predicate` instances. Raises AttributeError
    if no alternative is given, and TypeError if an alternative is neither a
    URI nor a PredicateExpression.
    """

    # predicate alternatives
    expr: typing.Set[PredicateExpression]

    def __init__(self, *expr: typing.Union[PredicateExpression, URI]):
        raw = set(normalize_args(*expr)) # type: ignore [arg-type] # this is getting too complex...
        if not raw:
            raise AttributeError('expected at least one expression, found none')
        alternatives = set()
        for item in raw:
            alternatives.add(Predicate(item) if isinstance(item, URI) else item)
        if any(not isinstance(alt, PredicateExpression) for alt in alternatives):
            raise TypeError(expr)
        self.expr = alternatives

    def __contains__(self, expr: typing.Any) -> bool:
        """Return True if *expr* is one of the alternatives."""
        alternatives = self.expr
        return expr in alternatives

    def __iter__(self) -> typing.Iterator[PredicateExpression]:
        """Iterate over the alternatives (in set order)."""
        alternatives = self.expr
        return iter(alternatives)

    def __len__(self) -> int:
        """Return how many (distinct) alternatives there are."""
        alternatives = self.expr
        return len(alternatives)

    def __repr__(self) -> str:
        return '{}({})'.format(typename(self), self.expr)

    def __hash__(self) -> int:
        # sets are unordered; sort by repr to get an order-independent key
        ordered = tuple(sorted(self.expr, key=repr))
        return hash((super().__hash__(), ordered))

    def __eq__(self, other) -> bool:
        if super().__eq__(other):
            return self.expr == other.expr
        return False


# Helpers
# invalid-name is disabled since they explicitly mimic an expression

def IsIn(*values) -> FilterExpression: # pylint: disable=invalid-name
    """Match any of the given URIs.

    A single URI yields a plain `Is`; several are combined with `Or`.
    Raises AttributeError if no URI is given.
    """
    uris = normalize_args(*values)
    count = len(uris)
    if count == 0:
        raise AttributeError('expected at least one value, found none')
    if count > 1:
        return Or(Is(uri) for uri in uris)
    return Is(uris[0])

def IsNotIn(*values) -> FilterExpression: # pylint: disable=invalid-name
    """Match none of the given URIs, i.e. the negation of `IsIn`.

    Raises AttributeError if no URI is given (propagated from `IsIn`).
    """
    matches_any = IsIn(*values)
    return Not(matches_any)


def Between( # pylint: disable=invalid-name
        lo: float = float('-inf'),
        hi: float = float('inf'),
        lo_strict: bool = True,
        hi_strict: bool = True,
        ) -> FilterExpression :
    """Match numerical values between *lo* and *hi*.

    Bounds are excluded by default; set *lo_strict* or *hi_strict* to False
    to include the respective bound. An infinite bound is omitted:
    * A half-bounded range yields a single `GreaterThan` or `LessThan`.
    * A bounded range yields ``And(GreaterThan(...), LessThan(...))``.
    * If *lo* equals *hi* and both bounds are inclusive, the result is
      ``Equals(lo)``.

    Raises:
        ValueError:
            * if either bound is NaN;
            * if *lo* is infinite and *hi* is +inf (no bound at all);
            * if *lo* is greater than *hi*;
            * if *lo* equals *hi* while either bound is strict.
    """
    # NaN is the only value that is not equal to itself; a NaN bound would
    # silently produce a filter that can never match.
    if lo != lo or hi != hi: # pylint: disable=comparison-with-itself
        raise ValueError('bounds cannot be NaN')
    if abs(lo) == hi == float('inf'):
        raise ValueError('range cannot be INF on both sides')
    if lo > hi:
        raise ValueError(f'lower bound ({lo}) cannot be greater than upper bound ({hi})')
    if lo == hi and not lo_strict and not hi_strict:
        return Equals(lo)
    if lo == hi: # either bound is strict
        raise ValueError('bounds cannot be equal when either is strict')
    if lo != float('-inf') and hi != float('inf'):
        return And(GreaterThan(lo, lo_strict), LessThan(hi, hi_strict))
    if lo != float('-inf'):
        return GreaterThan(lo, lo_strict)
    # hi != float('inf'):
    return LessThan(hi, hi_strict)


def Includes(*values, approx: bool = False) -> FilterExpression: # pylint: disable=invalid-name
    """Match any of the given *values*.

    Values are matched exactly with `Equals`, or with `Substring` if
    *approx* is set. A single value yields one matcher; several are combined
    with `Or`. Raises AttributeError if no value is given.
    """
    candidates = normalize_args(*values)
    matcher = Substring if approx else Equals
    if len(candidates) == 0:
        raise AttributeError('expected at least one value, found none')
    if len(candidates) > 1:
        return Or(matcher(value) for value in candidates)
    return matcher(candidates[0])


def Excludes(*values, approx: bool = False) -> FilterExpression: # pylint: disable=invalid-name
    """Match none of the given *values*, i.e. the negation of `Includes`.

    Uses `Substring` instead of `Equals` if *approx* is set. Raises
    AttributeError if no value is given (propagated from `Includes`).
    """
    matches_any = Includes(*values, approx=approx)
    return Not(matches_any)


## EOF ##
