"""Spatial color features.
"""
# standard imports
import typing

# external imports
import PIL.Image
import numpy as np

# bsie imports
from bsie.utils import bsfs, node, ns

# inner-module imports
from .. import base

# constants
# Base feature class URI; each configured extractor derives its own
# instance-specific subclass from it (see ColorsSpatial.__init__).
FEATURE_NAME = ns.bsf.ColorsSpatial()

# exports
__all__: typing.Sequence[str] = ('ColorsSpatial', )


## code ##

class ColorsSpatial(base.Extractor):
    """Determine dominant colors of subregions in the image.

    Computes the dominant color of increasingly smaller subregions of the image.
    The image is converted to HSV and box-downsampled to a ``width`` x ``height``
    grid; every cell contributes its three band values to the feature vector.
    The grid size is then divided by ``exp`` (and floored) in both directions,
    and the process repeats until either side drops below one cell. The final
    feature vector concatenates all levels, largest grid first, with the bands
    of each cell kept adjacent.
    """

    CONTENT_READER = 'bsie.reader.image.Image'

    # Initial (largest) grid width, in cells.
    width: int

    # Initial (largest) grid height, in cells.
    height: int

    # Shrink factor applied to the grid size between levels; must be > 1.
    exp: float

    # Principal predicate's URI.
    _predicate_name: bsfs.URI

    def __init__(
            self,
            width: int = 32,
            height: int = 32,
            exp: float = 4.,
            ):
        """Configure the initial grid size and the per-level shrink factor.

        Raises:
            ValueError: If ``width`` or ``height`` is smaller than one (the
                feature vector would be empty and extraction would fail), or
                if ``exp`` is not greater than one (the grid would never shrink
                and computing the dimension would not terminate).
        """
        # validate early: a zero-sized grid yields no levels at all, which
        # would make np.vstack fail in extract.
        if width < 1 or height < 1:
            raise ValueError(
                f'width and height must be at least 1, got {width} and {height}')
        # exp <= 1 would loop forever; `not exp > 1` also rejects NaN
        if not exp > 1:
            raise ValueError(f'exp must be greater than 1, got {exp}')
        # instance identifier
        uuid = bsfs.uuid.UCID.from_dict({
            'width': width,
            'height': height,
            'exp': exp,
            })
        # determine symbol names
        instance_name = getattr(FEATURE_NAME, uuid)
        predicate_name = getattr(ns.bse, 'colors_spatial_' + uuid)
        # get vector dimension
        dimension = self.dimension(width, height, exp)
        # initialize parent with the schema
        # NOTE(review): the description literal below misspells "Dominant";
        # it is schema text emitted at runtime, so it is left unchanged here.
        super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + f'''
            <{FEATURE_NAME}> rdfs:subClassOf bsa:Feature ;
                # annotations
                rdfs:label "Spatially dominant colors"^^xsd:string ;
                schema:description "Domiant colors of subregions in an image."^^xsd:string ;
                bsfs:distance <https://schema.bsfs.io/core/distance#euclidean> ;
                bsfs:dtype xsd:integer .

            <{instance_name}> rdfs:subClassOf <{FEATURE_NAME}> ;
                bsfs:dimension "{dimension}"^^xsd:integer ;
                # annotations
                <{FEATURE_NAME}/args#width> "{width}"^^xsd:integer ;
                <{FEATURE_NAME}/args#height> "{height}"^^xsd:integer ;
                <{FEATURE_NAME}/args#exp> "{exp}"^^xsd:float .

            <{predicate_name}> rdfs:subClassOf bsfs:Predicate ;
                rdfs:domain bsn:Entity ;
                rdfs:range <{instance_name}> ;
                bsfs:unique "true"^^xsd:boolean .

            '''))
        # assign extra members
        self.width = width
        self.height = height
        self.exp = exp
        self._predicate_name = predicate_name

    def __repr__(self) -> str:
        return f'{bsfs.typename(self)}({self.width}, {self.height}, {self.exp})'

    def __eq__(self, other: typing.Any) -> bool:
        return super().__eq__(other) \
          and self.width == other.width \
          and self.height == other.height \
          and self.exp == other.exp

    def __hash__(self) -> int:
        return hash((super().__hash__(), self.width, self.height, self.exp))

    @staticmethod
    def _levels(
            width: int,
            height: int,
            exp: float,
            ) -> typing.Iterator[typing.Tuple[int, int]]:
        """Yield the (width, height) grid size of every level, largest first.

        After each level both sides are divided by *exp* and floored; iteration
        stops as soon as either side drops below one. This single schedule is
        shared by `dimension` and `extract` so they always agree.

        Raises:
            ValueError: If *exp* is not greater than one (including NaN),
                since the grid would then never shrink.
        """
        if not exp > 1:
            raise ValueError(f'exp must be greater than 1, got {exp}')
        while width >= 1 and height >= 1:
            yield width, height
            width = int(np.floor(width / exp))
            height = int(np.floor(height / exp))

    @staticmethod
    def dimension(width: int, height: int, exp: float) -> int:
        """Return the feature vector dimension.

        This is the total number of grid cells over all levels times three
        (one value per HSV band).

        Raises:
            ValueError: If *exp* is not greater than one.
        """
        # FIXME: a closed-form expression would avoid iterating the levels
        num_cells = sum(w * h for w, h in ColorsSpatial._levels(width, height, exp))
        return int(num_cells * 3) # per band

    def extract(
            self,
            subject: node.Node,
            content: PIL.Image.Image,
            principals: typing.Iterable[bsfs.schema.Predicate],
            ) -> typing.Iterator[typing.Tuple[node.Node, bsfs.schema.Predicate, typing.Any]]:
        """Yield the spatial color feature of *content* for *subject*.

        Nothing is yielded unless this extractor's predicate is among
        *principals*. Otherwise exactly one triple is yielded; its value is a
        flat tuple holding the band values of every cell of every level.
        """
        predicate = self.schema.predicate(self._predicate_name)
        # check principals
        if predicate not in principals:
            # nothing to do; abort
            return

        # convert to HSV
        content = content.convert('HSV')
        num_bands = len(content.getbands()) # it's three since we converted to HSV before

        features = []
        for width, height in self._levels(self.width, self.height, self.exp):
            # downsample; BOX weights every source pixel of a cell equally
            img = content.resize((width, height), resample=PIL.Image.Resampling.BOX)
            # one row per cell (row-major), one column per band
            features.append(
                np.array(img.getdata()).reshape((width * height, num_bands)))

        # combine bands and convert features to tuple
        value = tuple(np.vstack(features).reshape(-1))
        # return triple with feature vector as value
        yield subject, predicate, value

## EOF ##
