
# imports
import typing

# bsfs imports
from bsfs.namespace import ns
from bsfs.utils import errors, URI, typename

# exports: names made available by `from ... import *`
__all__: typing.Sequence[str] = ('Literal', 'Node', 'Predicate', 'Feature')


## code ##

class _Type():
    """Base class of all schema types; a type is identified by its uri.

    Types are linked to their parent type, and these links induce a
    partial order: the order operators report whether one type is a
    superclass (greater-than) or a subclass (less-than) of another.
    Comparisons are only supported within the same type.

    For example, consider the class hierarchy below:

    Vehicle
      Two-wheel
        Bike
        Bicycle

    >>> vehicle = _Type('Vehicle')
    >>> twowheel = _Type('Two-wheel', vehicle)
    >>> bike = _Type('Bike', twowheel)
    >>> bicycle = _Type('Bicycle', twowheel)

    Two-wheel is equivalent to itself only
    >>> twowheel == vehicle
    False
    >>> twowheel == twowheel
    True
    >>> twowheel == bicycle
    False

    Two-wheel is a true subclass of Vehicle
    >>> twowheel < vehicle
    True
    >>> twowheel < twowheel
    False
    >>> twowheel < bicycle
    False

    Two-wheel is a subclass of itself and of Vehicle
    >>> twowheel <= vehicle
    True
    >>> twowheel <= twowheel
    True
    >>> twowheel <= bicycle
    False

    Two-wheel is a true superclass of Bicycle
    >>> twowheel > vehicle
    False
    >>> twowheel > twowheel
    False
    >>> twowheel > bicycle
    True

    Two-wheel is a superclass of itself and of Bicycle
    >>> twowheel >= vehicle
    False
    >>> twowheel >= twowheel
    True
    >>> twowheel >= bicycle
    True

    As with sets, this is not a total order; siblings are unrelated:
    >>> bike < bicycle
    False
    >>> bike > bicycle
    False
    >>> bike == bicycle
    False
    """

    # uri that identifies the class.
    uri: URI

    # immediate parent class, or None for a root class.
    parent: typing.Optional['_Type'] # TODO: for python >=3.11: use typing.Self

    def __init__(
            self,
            uri: URI,
            parent: typing.Optional['_Type'] = None,
            **annotations: typing.Any,
            ):
        # The uri is normalized via URI(); any extra keyword arguments are
        # kept verbatim as a dict of free-form annotations.
        self.uri = URI(uri)
        self.parent = parent
        self.annotations = annotations

    def parents(self) -> typing.Generator['_Type', None, None]:
        """Yield every ancestor, from the immediate parent up to the root."""
        ancestor = self.parent
        while ancestor is not None:
            yield ancestor
            ancestor = ancestor.parent

    def child(
            self,
            uri: URI,
            **kwargs,
            ):
        """Create a new type of the same class as *self*, with *self* as its parent.

        Additional keyword arguments are forwarded to the constructor.
        """
        cls = type(self)
        return cls(uri=uri, parent=self, **kwargs)

    def __str__(self) -> str:
        name = typename(self)
        return f'{name}({self.uri})'

    def __repr__(self) -> str:
        name = typename(self)
        return f'{name}({self.uri}, {self.parent!r})'

    def __hash__(self) -> int:
        # Built from the same members that __eq__ compares.
        key = (type(self), self.uri, self.parent)
        return hash(key)

    # NOTE: For equality and order functions (lt, gt, le, ge) we explicitly want type equality!
    # Consider the statements below, with class Vehicle(_Type) and class TwoWheel(Vehicle):
    # * Vehicle('foo', None) == TwoWheel('foo', None): Instances of different types cannot be equivalent.
    # * Vehicle('foo', None) <= TwoWheel('foo', None): Cannot compare the different types Vehicles and TwoWheel.

    def __eq__(self, other: typing.Any) -> bool:
        """Return True iff *other* has exactly the same type, uri, and parent as *self*."""
        # pylint: disable=unidiomatic-typecheck
        if type(other) is not type(self):
            return False
        return self.uri == other.uri and self.parent == other.parent

    def _relation(self, other: '_Type') -> typing.Optional[str]:
        """Classify how *self* relates to *other* in the hierarchy.

        Returns 'equivalent' if both share the same uri, 'superclass' if *self*
        is an ancestor of *other*, 'subclass' if *self* descends from *other*,
        and None if the two are unrelated. The checks run in exactly this order.
        """
        if self.uri == other.uri:
            return 'equivalent'
        if self in other.parents():
            return 'superclass'
        if other in self.parents():
            return 'subclass'
        return None

    def _compare(self, other: typing.Any, *accepted: str) -> bool:
        """Shared body of the order operators.

        Returns NotImplemented for non-type operands, False for operands of a
        different type, and otherwise whether the relation of *self* to *other*
        is one of *accepted*.
        """
        if not isinstance(other, _Type):
            return NotImplemented
        if not isinstance(other, type(self)): # FIXME: necessary?
            return False
        return self._relation(other) in accepted

    def __lt__(self, other: typing.Any) -> bool:
        """Return True iff *self* is a true subclass of *other*."""
        return self._compare(other, 'subclass')

    def __le__(self, other: typing.Any) -> bool:
        """Return True iff *self* is equivalent to or a subclass of *other*."""
        return self._compare(other, 'equivalent', 'subclass')

    def __gt__(self, other: typing.Any) -> bool:
        """Return True iff *self* is a true superclass of *other*."""
        return self._compare(other, 'superclass')

    def __ge__(self, other: typing.Any) -> bool:
        """Return True iff *self* is equivalent to or a superclass of *other*."""
        return self._compare(other, 'equivalent', 'superclass')


class Vertex(_Type):
    """Type of a graph vertex; concrete vertex types are Nodes or Literals."""

    # parent vertex type, or None for a root.
    parent: typing.Optional['Vertex']

    def __init__(self, uri: URI, parent: typing.Optional['Vertex'], **kwargs):
        # Unlike the base class, the parent must be given explicitly.
        super().__init__(uri=uri, parent=parent, **kwargs)


class Node(Vertex):
    """Vertex type for nodes."""

    # parent node type, or None for a root.
    parent: typing.Optional['Node']

    def __init__(self, uri: URI, parent: typing.Optional['Node'], **kwargs):
        # The parent must be given explicitly; extra keywords become annotations.
        super().__init__(uri=uri, parent=parent, **kwargs)


class Literal(Vertex):
    """Vertex type for literals."""

    # parent literal type, or None for a root.
    parent: typing.Optional['Literal']

    def __init__(self, uri: URI, parent: typing.Optional['Literal'], **kwargs):
        # The parent must be given explicitly; extra keywords become annotations.
        super().__init__(uri=uri, parent=parent, **kwargs)


class Feature(Literal):
    """Literal type describing a feature vector.

    Besides the common type members, a feature records the number of
    vector dimensions, the vector datatype, and the distance measure
    used to compare two vectors.
    """

    # Number of feature vector dimensions.
    dimension: int
    # Feature vector datatype.
    dtype: URI
    # Distance measure to compare feature vectors.
    distance: URI

    def __init__(
            self,
            # Type members
            uri: URI,
            parent: typing.Optional[Literal],
            # Feature members
            dimension: int,
            dtype: URI,
            distance: URI,
            **kwargs,
            ):
        super().__init__(uri=uri, parent=parent, **kwargs)
        # Coerce the feature members: int for the dimension, URI for the rest.
        self.dimension = int(dimension)
        self.dtype = URI(dtype)
        self.distance = URI(distance)

    def __hash__(self) -> int:
        base = super().__hash__()
        return hash((base, self.dimension, self.dtype, self.distance))

    def __eq__(self, other: typing.Any) -> bool:
        """Return True iff *other* is an equivalent type with identical feature members."""
        if not super().__eq__(other):
            return False
        return self.dimension == other.dimension \
            and self.dtype == other.dtype \
            and self.distance == other.distance

    def child(
            self,
            uri: URI,
            dimension: typing.Optional[int] = None,
            dtype: typing.Optional[URI] = None,
            distance: typing.Optional[URI] = None,
            **kwargs,
            ):
        """Return a child feature type; feature members left as None are inherited from *self*."""
        return super().child(
            uri=uri,
            dimension=self.dimension if dimension is None else dimension,
            dtype=self.dtype if dtype is None else dtype,
            distance=self.distance if distance is None else distance,
            **kwargs,
            )

class Predicate(_Type):
    """Predicate type: relates a *domain* Node type to a *range* Vertex type."""

    # source type.
    domain: Node
    # destination type.
    range: Vertex
    # maximum cardinality of type, stored as a bool.
    unique: bool

    def __init__(
            self,
            # Type members
            uri: URI,
            parent: typing.Optional['Predicate'],
            # Predicate members
            domain: Node,
            range: Vertex, # pylint: disable=redefined-builtin
            unique: bool,
            **kwargs,
            ):
        # Validate the endpoints before any state is set.
        if not isinstance(domain, Node):
            raise TypeError(domain)
        # Besides Nodes and Literals, only the root vertex itself is an admissible range.
        if not isinstance(range, (Node, Literal)) and range != ROOT_VERTEX:
            raise TypeError(range)
        super().__init__(uri, parent, **kwargs)
        self.domain = domain
        self.range = range
        self.unique = bool(unique)

    def __hash__(self) -> int:
        base = super().__hash__()
        return hash((base, self.domain, self.unique, self.range))

    def __eq__(self, other: typing.Any) -> bool:
        """Return True iff *other* is an equivalent type with identical domain, range, and uniqueness."""
        if not super().__eq__(other):
            return False
        return self.domain == other.domain \
            and self.range == other.range \
            and self.unique == other.unique

    def child(
            self,
            uri: URI,
            domain: typing.Optional[Node] = None,
            range: typing.Optional[Vertex] = None, # pylint: disable=redefined-builtin
            unique: typing.Optional[bool] = None,
            **kwargs,
            ):
        """Return a child predicate that specializes *self*.

        Members left as None are inherited from *self*. The domain must be
        equal to or a subclass of self.domain. Unless self.range is the root
        vertex, the range must likewise be equal to or a subclass of self.range.

        Raises errors.ConsistencyError if either constraint is violated.
        """
        domain = self.domain if domain is None else domain
        if not domain <= self.domain:
            raise errors.ConsistencyError(f'{domain} must be a subclass of {self.domain}')
        range = self.range if range is None else range
        # NOTE: The root predicate has a Vertex as range, which is neither a parent of the root
        # Node nor Literal. Hence, that test is skipped since a child should be allowed to
        # specialize from Vertex to anything.
        if self.range != ROOT_VERTEX and not range <= self.range:
            raise errors.ConsistencyError(f'{range} must be a subclass of {self.range}')
        unique = self.unique if unique is None else unique
        return super().child(uri=uri, domain=domain, range=range, unique=unique, **kwargs)


# essential vertices
# Roots of the vertex hierarchy; none of them has a parent.
ROOT_VERTEX = Vertex(uri=ns.bsfs.Vertex, parent=None)
ROOT_NODE = Node(uri=ns.bsfs.Node, parent=None)
ROOT_LITERAL = Literal(uri=ns.bsfs.Literal, parent=None)

# Literal categories directly below the root literal
# (child() builds a Literal whose parent is ROOT_LITERAL).
ROOT_BLOB = ROOT_LITERAL.child(ns.bsl.BinaryBlob)
ROOT_NUMBER = ROOT_LITERAL.child(ns.bsl.Number)
ROOT_TIME = ROOT_LITERAL.child(ns.bsl.Time)
ROOT_ARRAY = ROOT_LITERAL.child(ns.bsl.Array)

# Root feature type below the array literal: dimension 1, dtype f16, euclidean distance.
ROOT_FEATURE = Feature(
    uri=ns.bsl.Array.Feature,
    parent=ROOT_ARRAY,
    dimension=1,
    dtype=ns.bsfs.dtype().f16,
    distance=ns.bsd.euclidean,
    )

# essential predicates
# Root predicate: from the root node to the root vertex, not unique.
ROOT_PREDICATE = Predicate(
    uri=ns.bsfs.Predicate, parent=None,
    domain=ROOT_NODE, range=ROOT_VERTEX, unique=False,
    )

## EOF ##
