#  Copyright (c) 2020, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from collections import namedtuple
from typing import Optional, Union

import numpy as _np
import numpy as np
import sympy as sm

import coremltools.converters.mil.backend.mil.helper as mil_helper
import coremltools.proto.MIL_pb2 as _mil_pm

from .get_type_info import get_type_info
from .type_bool import bool as types_bool
from .type_bool import is_bool
from .type_complex import complex64 as types_complex64
from .type_complex import complex128 as types_complex128
from .type_complex import is_complex
from .type_dict import is_dict
from .type_double import fp16 as types_fp16
from .type_double import fp32 as types_fp32
from .type_double import fp64 as types_fp64
from .type_double import is_float
from .type_int import SUB_BYTE_DTYPE_METADATA_KEY
from .type_int import int4 as types_int4
from .type_int import int8 as types_int8
from .type_int import int16 as types_int16
from .type_int import int32 as types_int32
from .type_int import int64 as types_int64
from .type_int import (
    is_int,
    is_sub_byte,
    np_int4_dtype,
    np_uint1_dtype,
    np_uint2_dtype,
    np_uint3_dtype,
    np_uint4_dtype,
    np_uint6_dtype,
)
from .type_int import uint1 as types_uint1
from .type_int import uint2 as types_uint2
from .type_int import uint3 as types_uint3
from .type_int import uint4 as types_uint4
from .type_int import uint6 as types_uint6
from .type_int import uint8 as types_uint8
from .type_int import uint16 as types_uint16
from .type_int import uint32 as types_uint32
from .type_int import uint64 as types_uint64
from .type_list import is_list
from .type_str import str as types_str
from .type_unknown import unknown

# Maps each builtin type to the numpy type used to represent it. The sub-byte
# integer types map to the dedicated `np_*_dtype` objects imported from
# `.type_int`. This is the table behind `nptype_from_builtin`.
_TYPES_TO_NPTYPES = {
    types_bool: np.bool_,
    types_int4: np_int4_dtype,
    types_int8: np.int8,
    types_int16: np.int16,
    types_int32: np.int32,
    types_int64: np.int64,
    types_uint1: np_uint1_dtype,
    types_uint2: np_uint2_dtype,
    types_uint3: np_uint3_dtype,
    types_uint4: np_uint4_dtype,
    types_uint6: np_uint6_dtype,
    types_uint8: np.uint8,
    types_uint16: np.uint16,
    types_uint32: np.uint32,
    types_uint64: np.uint64,
    types_fp16: np.float16,
    types_fp32: np.float32,
    types_fp64: np.float64,
    types_complex64: np.complex64,
    types_complex128: np.complex128,
    types_str: np.str_,
}

# String names of the numpy scalar types. The keys are numpy scalar classes
# (not dtype instances). Sub-byte types have no entry here; their names are
# added to the inverse map `_STRINGS_TO_NPTYPES` separately.
_NPTYPES_TO_STRINGS = {
    np.bool_: "bool",
    np.int8: "int8",
    np.int16: "int16",
    np.int32: "int32",
    np.int64: "int64",
    np.uint8: "uint8",
    np.uint16: "uint16",
    np.uint32: "uint32",
    np.uint64: "uint64",
    np.float16: "fp16",
    np.float32: "fp32",
    np.float64: "fp64",
    np.complex64: "complex64",
    np.complex128: "complex128",
    np.str_: "string",
}

# String name of every builtin type. This is the table behind
# `builtin_to_string`, and its inverse (`_STRINGS_TO_TYPES`) backs
# `string_to_builtin`, so the names must stay unique.
_TYPES_TO_STRINGS = {
    types_bool: "bool",
    types_int4: "int4",
    types_int8: "int8",
    types_int16: "int16",
    types_int32: "int32",
    types_int64: "int64",
    types_uint1: "uint1",
    types_uint2: "uint2",
    types_uint3: "uint3",
    types_uint4: "uint4",
    types_uint6: "uint6",
    types_uint8: "uint8",
    types_uint16: "uint16",
    types_uint32: "uint32",
    types_uint64: "uint64",
    types_fp16: "fp16",
    types_fp32: "fp32",
    types_fp64: "fp64",
    types_complex64: "complex64",
    types_complex128: "complex128",
    types_str: "string",
}

# Resolution of each builtin type: 1 for bool and integer types, and numpy's
# `finfo(...).resolution` for floating-point types.
# uint32, uint64, complex and string types have no entry, so
# `builtin_to_resolution` raises KeyError for them.
_TYPES_TO_RESOLUTION = {
    types_bool: 1,
    types_int4: 1,
    types_int8: 1,
    types_uint1: 1,
    types_uint2: 1,
    types_uint3: 1,
    types_uint4: 1,
    types_uint6: 1,
    types_uint8: 1,
    types_int16: 1,
    types_uint16: 1,
    types_int32: 1,
    types_int64: 1,
    types_fp16: np.finfo(np.float16).resolution,
    types_fp32: np.finfo(np.float32).resolution,
    types_fp64: np.finfo(np.float64).resolution,
}

# (low, high) pair describing the inclusive value range of a type.
RangeTuple = namedtuple("RangeTuple", "low high")

# Value range of each builtin type. Sub-byte ranges are derived by
# right-shifting the 8-bit limits by the number of missing bits, e.g. int4 is
# (-128 >> 4, 127 >> 4) == (-8, 7) and uint4 is (0, 255 >> 4) == (0, 15).
# uint32, uint64, complex and string types have no entry, so
# `builtin_to_range` raises KeyError for them.
_TYPES_TO_RANGE = {
    types_bool: RangeTuple(0, 1),
    types_int4: RangeTuple(np.iinfo(np.int8).min >> 4, np.iinfo(np.int8).max >> 4),
    types_int8: RangeTuple(np.iinfo(np.int8).min, np.iinfo(np.int8).max),
    types_uint1: RangeTuple(np.iinfo(np.uint8).min >> 7, np.iinfo(np.uint8).max >> 7),
    types_uint2: RangeTuple(np.iinfo(np.uint8).min >> 6, np.iinfo(np.uint8).max >> 6),
    types_uint3: RangeTuple(np.iinfo(np.uint8).min >> 5, np.iinfo(np.uint8).max >> 5),
    types_uint4: RangeTuple(np.iinfo(np.uint8).min >> 4, np.iinfo(np.uint8).max >> 4),
    types_uint6: RangeTuple(np.iinfo(np.uint8).min >> 2, np.iinfo(np.uint8).max >> 2),
    types_uint8: RangeTuple(np.iinfo(np.uint8).min, np.iinfo(np.uint8).max),
    types_int16: RangeTuple(np.iinfo(np.int16).min, np.iinfo(np.int16).max),
    types_uint16: RangeTuple(np.iinfo(np.uint16).min, np.iinfo(np.uint16).max),
    types_int32: RangeTuple(np.iinfo(np.int32).min, np.iinfo(np.int32).max),
    types_int64: RangeTuple(np.iinfo(np.int64).min, np.iinfo(np.int64).max),
    types_fp16: RangeTuple(np.finfo(np.float16).min, np.finfo(np.float16).max),
    types_fp32: RangeTuple(np.finfo(np.float32).min, np.finfo(np.float32).max),
    types_fp64: RangeTuple(np.finfo(np.float64).min, np.finfo(np.float64).max),
}

# Maps each builtin type to its data-type enum value in the MIL protobuf
# module (`MIL_pb2`). Complex types have no entry in this table.
BUILTIN_TO_PROTO_TYPES = {
    # bool:
    types_bool: _mil_pm.BOOL,

    # fp
    types_fp16: _mil_pm.FLOAT16,
    types_fp32: _mil_pm.FLOAT32,
    types_fp64: _mil_pm.FLOAT64,

    # int
    types_uint1: _mil_pm.UINT1,
    types_uint2: _mil_pm.UINT2,
    types_uint3: _mil_pm.UINT3,
    types_uint4: _mil_pm.UINT4,
    types_uint6: _mil_pm.UINT6,
    types_uint8: _mil_pm.UINT8,
    types_int4: _mil_pm.INT4,
    types_int8: _mil_pm.INT8,

    types_uint16: _mil_pm.UINT16,
    types_int16: _mil_pm.INT16,

    types_uint32: _mil_pm.UINT32,
    types_int32: _mil_pm.INT32,

    types_uint64: _mil_pm.UINT64,
    types_int64: _mil_pm.INT64,

    # str
    types_str: _mil_pm.STRING,
}

def np_dtype_to_py_type(np_dtype):
    """
    Return the Python scalar type (``int``, ``bool``, ``float`` or ``complex``)
    that corresponds to a numpy dtype.

    Raises:
        NotImplementedError: if ``np_dtype`` is not one of the handled dtypes
            (int32/int64, bool, float32/float64, complex64/complex128).
    """
    # Membership is tested by equality on purpose: a dict lookup would not
    # work, because hash(np.int32) != hash(val.dtype).
    groups = (
        ((np.int32, np.int64), int),
        ((bool, np.bool_), bool),
        ((np.float32, np.float64), float),
        ((np.complex64, np.complex128), complex),
    )
    for members, py_type in groups:
        if np_dtype in members:
            return py_type
    raise NotImplementedError('{} is not supported'.format(np_dtype))

# Inverse lookup tables, built by swapping the keys and values of the forward
# tables defined in this module.
PROTO_TO_BUILTIN_TYPE = {v: k for k, v in BUILTIN_TO_PROTO_TYPES.items()}
_STRINGS_TO_TYPES = {v: k for k, v in _TYPES_TO_STRINGS.items()}
_STRINGS_TO_NPTYPES = {v: k for k, v in _NPTYPES_TO_STRINGS.items()}
# `_NPTYPES_TO_STRINGS` has no sub-byte entries, so the sub-byte names are
# registered explicitly here.
_STRINGS_TO_NPTYPES.update(
    {
        "int4": np_int4_dtype,
        "uint1": np_uint1_dtype,
        "uint2": np_uint2_dtype,
        "uint3": np_uint3_dtype,
        "uint4": np_uint4_dtype,
        "uint6": np_uint6_dtype,
    }
)

def string_to_builtin(s):
    """
    Look up the builtin type registered under the string name ``s``.

    Raises:
        KeyError: if ``s`` is not a known type name.
    """
    builtin_type = _STRINGS_TO_TYPES[s]
    return builtin_type


def builtin_to_string(builtin_type):
    """
    Return the string name of a builtin type.

    Any dict type is reported as ``"dict"``; every other type is looked up in
    the type-name table and raises ``KeyError`` if it is not registered there.
    """
    return "dict" if is_dict(builtin_type) else _TYPES_TO_STRINGS[builtin_type]


def string_to_nptype(s: str):
    """
    Look up the numpy type registered under the string name ``s``.

    Raises:
        KeyError: if ``s`` is not a known type name.
    """
    np_type = _STRINGS_TO_NPTYPES[s]
    return np_type


def nptype_from_builtin(btype):
    """
    Return the numpy type that represents the builtin type ``btype``.

    Raises:
        KeyError: if ``btype`` has no numpy counterpart registered.
    """
    np_type = _TYPES_TO_NPTYPES[btype]
    return np_type


def builtin_to_resolution(builtin_type: type):
    """
    Return the resolution registered for ``builtin_type``.

    Raises:
        KeyError: if no resolution is registered for ``builtin_type``.
    """
    resolution = _TYPES_TO_RESOLUTION[builtin_type]
    return resolution


def builtin_to_range(builtin_type: type) -> RangeTuple:
    """
    Return the ``RangeTuple(low, high)`` registered for ``builtin_type``.

    Raises:
        KeyError: if no range is registered for ``builtin_type``.
    """
    value_range = _TYPES_TO_RANGE[builtin_type]
    return value_range

def promote_types(dtype1, dtype2):
    """
    Get the smallest type to which the given scalar types can be cast.

    Args:
        dtype1 (builtin):
        dtype2 (builtin):

    Returns:
        A builtin datatype.

    Examples:
        >>> promote_types(int32, int64)
            builtin('int64')

        >>> promote_types(fp16, fp32)
            builtin('fp32')

        >>> promote_types(fp16, int32)
            builtin('fp16')
    """
    np_first = nptype_from_builtin(dtype1)
    np_second = nptype_from_builtin(dtype2)

    # A float paired with a signed integer keeps the float type. This
    # sidesteps numpy's own rule, which would widen the result:
    # >> np.promote_types(np.float32, np.int32)
    # dtype('float64')
    if np.issubdtype(np_first, np.floating) and np.issubdtype(np_second, np.signedinteger):
        return numpy_type_to_builtin_type(np_first)
    if np.issubdtype(np_second, np.floating) and np.issubdtype(np_first, np.signedinteger):
        return numpy_type_to_builtin_type(np_second)

    # Every other combination follows numpy's promotion rules.
    return numpy_type_to_builtin_type(np.promote_types(np_first, np_second))


def promote_dtypes(dtypes):
    """
    Get the smallest promoted dtype to which all scalar dtypes in ``dtypes``
    can be cast.

    The dtypes are deduplicated in first-seen order and then combined
    pairwise with ``promote_types`` from right to left, i.e.
    ``promote_types(d0, promote_types(d1, ... dN))``. Keeping the input order
    makes the result deterministic: ``promote_types`` is not associative for
    every combination (e.g. mixes of float, unsigned and signed integers), so
    an arbitrary set-iteration order could change the answer between runs.

    Args:
        dtypes: list or tuple of builtin dtypes, with at least one element.

    Returns:
        A builtin datatype.

    Raises:
        ValueError: if ``dtypes`` is not a list/tuple or is empty.

    Examples:
        >>> promote_dtypes([int32, int64, int16])
            builtin('int64')

        >>> promote_dtypes([fp16, fp32, fp64])
            builtin('fp64')

        >>> promote_dtypes([fp16, int32, int64])
            builtin('fp16')

    """
    if not isinstance(dtypes, (list, tuple)) or len(dtypes) < 1:
        raise ValueError("dtypes needs to be a list/tuple of at least 1 element")

    # Deduplicate to avoid redundant work. dict.fromkeys keeps first-seen
    # order, unlike set(), whose iteration order is arbitrary.
    unique_dtypes = list(dict.fromkeys(dtypes))

    # Iterative right-fold: no recursion, so very long inputs cannot hit the
    # recursion limit.
    promoted = unique_dtypes[-1]
    for dtype in reversed(unique_dtypes[:-1]):
        promoted = promote_types(dtype, promoted)
    return promoted


def is_primitive(btype):
    """
    Is the indicated builtin type a primitive?

    Primitives are the bool type, the str type, and any float, int or complex
    type.
    """
    # bool and str are single types, so they are matched by identity.
    if btype is types_bool or btype is types_str:
        return True
    return is_float(btype) or is_int(btype) or is_complex(btype)


def is_scalar(btype):
    """
    Is the given builtin type a scalar boolean, integer, float, string or
    complex type?
    """
    verdict = False
    # Checks run in this order and stop at the first one that accepts btype.
    for check in (is_bool, is_int, is_float, is_str, is_complex):
        verdict = check(btype)
        if verdict:
            break
    return verdict


def is_tensor(tensor_type):
    """
    Return True if ``get_type_info`` reports ``tensor_type`` as "tensor".

    ``None``, and anything ``get_type_info`` rejects with ``TypeError``, yield
    False.
    """
    if tensor_type is not None:
        try:
            kind = get_type_info(tensor_type).name
        except TypeError:
            pass
        else:
            return kind == "tensor"
    return False


def is_str(t):
    """
    Return True if ``get_type_info`` reports ``t`` as "str".

    ``None``, and anything ``get_type_info`` rejects with ``TypeError``, yield
    False.
    """
    if t is not None:
        try:
            kind = get_type_info(t).name
        except TypeError:
            pass
        else:
            return kind == "str"
    return False


def is_tuple(t):
    """
    Return True if ``get_type_info`` reports ``t`` as "tuple".

    ``None``, and anything ``get_type_info`` rejects with ``TypeError``, yield
    False.
    """
    if t is not None:
        try:
            kind = get_type_info(t).name
        except TypeError:
            pass
        else:
            return kind == "tuple"
    return False


def is_dict(t):
    """
    Return True if ``get_type_info`` reports ``t`` as "dict".

    ``None``, and anything ``get_type_info`` rejects with ``TypeError``, yield
    False.
    """
    # NOTE(review): this definition shadows the `is_dict` imported from
    # `.type_dict` at the top of this module; confirm that is intended.
    if t is not None:
        try:
            kind = get_type_info(t).name
        except TypeError:
            pass
        else:
            return kind == "dict"
    return False


def is_builtin(t):
    """
    Return a truthy value if ``t`` is a scalar, tensor, str or tuple type.
    """
    verdict = False
    # Checks run in this order and stop at the first one that accepts t.
    for check in (is_scalar, is_tensor, is_str, is_tuple):
        verdict = check(t)
        if verdict:
            break
    return verdict


def _numpy_dtype_instance_to_builtin_type(np_dtype: np.dtype) -> Optional[type]:
    """
    Resolve a numpy dtype *instance* to a builtin type via direct lookups.

    Args:
        np_dtype: a ``np.dtype`` instance.

    Returns:
        The builtin type stored in the dtype's metadata under
        ``SUB_BYTE_DTYPE_METADATA_KEY`` if present; otherwise the builtin type
        registered for the dtype's scalar class; otherwise ``None``.
    """
    # Sub-byte dtypes carry their builtin type in the dtype metadata.
    metadata_dict = np_dtype.metadata
    if metadata_dict is not None and SUB_BYTE_DTYPE_METADATA_KEY in metadata_dict:
        return metadata_dict[SUB_BYTE_DTYPE_METADATA_KEY]

    # `_NPTYPES_TO_STRINGS` is keyed by numpy scalar classes. A dtype instance
    # hashes differently from its scalar class (hash(np.int32) !=
    # hash(np.dtype("int32"))), so the table must be probed with
    # `np_dtype.type`; probing with the dtype instance itself never matches.
    type_name = _NPTYPES_TO_STRINGS.get(np_dtype.type)
    if type_name is not None:
        return string_to_builtin(type_name)
    return None


def numpy_type_to_builtin_type(nptype) -> type:
    """
    Map a numpy type to the equivalent builtin `types` class.

    Accepts numpy dtype instances, numpy scalar classes, and the Python native
    classes ``bool``, ``int``, ``float``, ``complex`` and ``str``.

    Raises:
        TypeError: if no builtin type corresponds to ``nptype``.
    """
    if isinstance(nptype, np.dtype):
        # dtype instances get a direct lookup first; this is where sub-byte
        # dtypes (identified through their metadata) are resolved.
        resolved = _numpy_dtype_instance_to_builtin_type(nptype)
        if resolved is not None:
            return resolved

    # A dtype object not resolved above is reduced to its scalar class so the
    # issubclass checks below can be applied.
    if issubclass(type(nptype), np.dtype):
        nptype = nptype.type

    # Both the Python bool and the numpy bool map to the builtin bool.
    if issubclass(nptype, (bool, np.bool_)):
        return types_bool

    # Fixed-width numpy integers are matched before the generic `int`
    # catch-all below, so that each keeps its exact width and signedness.
    if issubclass(nptype, np.uint8):
        return types_uint8
    if issubclass(nptype, np.int8):
        return types_int8
    if issubclass(nptype, np.uint16):
        return types_uint16
    if issubclass(nptype, np.int16):
        return types_int16
    if issubclass(nptype, np.uint32):
        return types_uint32
    if issubclass(nptype, np.int32):
        return types_int32
    if issubclass(nptype, np.uint64):
        return types_uint64
    if issubclass(nptype, np.int64):
        return types_int64
    if issubclass(nptype, int) or nptype == int:
        # Catch all int
        return types_int32
    if issubclass(nptype, np.object_):
        # symbolic shape is considered int32
        return types_int32

    if issubclass(nptype, np.float16):
        return types_fp16
    # The exact Python `float` class is treated as fp32.
    if issubclass(nptype, (np.float32, np.single)) or nptype == float:
        return types_fp32
    if issubclass(nptype, (np.float64, np.double)):
        return types_fp64

    if issubclass(nptype, np.complex64):
        return types_complex64
    # NOTE(review): the Python `complex` class maps to complex128 here, while
    # `type_to_builtin_type` maps it to complex64 - confirm which is intended.
    if issubclass(nptype, (np.complex128, complex)):
        return types_complex128

    if issubclass(nptype, (str, np.bytes_, np.str_)):
        return types_str

    raise TypeError(f"Unsupported numpy type: {nptype}.")


def type_to_builtin_type(type):
    """
    Try to get the builtin type equivalent to a numpy or Python type.

    Types whose ``__module__`` is numpy are delegated to
    ``numpy_type_to_builtin_type``. Otherwise a few generic Python types are
    recognised: ``bool``, ``int``, ``str``, ``float`` and ``complex``.

    Raises:
        TypeError: if no builtin type can be determined.
    """
    # Infer from numpy type if it is one
    if type.__module__ == np.__name__:
        return numpy_type_to_builtin_type(type)

    # bool must be tested before int, because bool is a subclass of int.
    if issubclass(type, bool):
        return types_bool
    if issubclass(type, int):
        return types_int32
    if issubclass(type, str):
        return types_str
    if issubclass(type, float):
        return types_fp32
    if issubclass(type, complex):
        return types_complex64
    raise TypeError("Could not determine builtin type for " + str(type))


def numpy_val_to_builtin_val(npval):
    """
    Wrap a numpy (or Python scalar) value in an instance of its builtin type.

    Scalars (per ``np.isscalar``) get the builtin type inferred from their
    Python/numpy class; everything else becomes a tensor type built from the
    value's dtype and shape.

    Returns:
        A ``(value, type)`` pair: an instance of the builtin type with its
        ``val`` attribute set to ``npval``, and the builtin type itself.
    """
    if np.isscalar(npval):
        ret_type = type_to_builtin_type(type(npval))
    else:
        element_type = numpy_type_to_builtin_type(npval.dtype)
        from . import tensor as types_tensor

        ret_type = types_tensor(element_type, npval.shape)

    wrapped = ret_type()
    wrapped.val = npval
    return wrapped, ret_type


def is_subtype_tensor(type1, type2):
    """
    Return True if tensor type ``type1`` is a subtype of tensor type ``type2``.

    The primitive types must match and the ranks must be equal. Dimensions are
    then compared pairwise: equal dimensions always match, a symbolic dimension
    of ``type1`` only matches a symbolic dimension of ``type2``, and any
    dimension of ``type1`` matches a symbolic dimension of ``type2``.
    """
    # requires primitive types match
    if type1.get_primitive() != type2.get_primitive():
        return False

    dims1 = type1.get_shape()
    dims2 = type2.get_shape()
    # Same rank
    if len(dims1) != len(dims2):
        return False

    for dim1, dim2 in zip(dims1, dims2):
        if dim1 == dim2:
            continue

        # tensor with shape (3, s0) is not a subtype of tensor with shape (3,
        # 1), but is a subtype of tensor with shape (3, s1)
        dim1_symbolic = issubclass(type(dim1), sm.Basic)
        dim2_symbolic = issubclass(type(dim2), sm.Basic)
        if dim1_symbolic:
            if dim2_symbolic:
                continue
            return False
        if not dim2_symbolic and dim1 != dim2:
            return False
    return True


def is_subtype(type1, type2):
    """
    Return True if type1 is a subtype of type2. False otherwise.

    Rules, checked in order: everything is a subtype of ``unknown``; a list
    type is only matched by another list type whose element type is a subtype;
    two tensor types are compared with ``is_subtype_tensor``; any other pair
    must be equal.
    """
    if type2 == unknown:
        # any class is a subclass of unknown (None) type.
        result = True
    elif is_list(type2):
        result = is_list(type1) and is_subtype(type1.T[0], type2.T[0])
    elif is_tensor(type1) and is_tensor(type2):
        result = is_subtype_tensor(type1, type2)
    else:
        result = type1 == type2
    return result


def _numpy_val_to_bytes(val: Union[np.ndarray, np.generic]) -> bytes:
    """
    Serialize a numpy value to raw bytes.

    Values whose builtin type is a sub-byte type are first packed with
    ``pack_elements_into_bits`` using that type's bit width; all other values
    are serialized as-is with ``tobytes()``.
    """
    # Import here to avoid circular import.
    from coremltools.optimize.coreml import _utils as optimize_utils

    element_type = numpy_type_to_builtin_type(val.dtype)
    if not is_sub_byte(element_type):
        return val.tobytes()

    packed = optimize_utils.pack_elements_into_bits(val, element_type.get_bitwidth())
    return packed.tobytes()

def np_val_to_py_type(val):
    """Convert numpy val to python primitive equivalent. Ex:

    Given: val = np.array([True, False])
    Returns: (True, False)

    Given: val = np.array(32, dtype=np.int32)
    Returns 32

    Values that are not numpy arrays or numpy scalars are returned unchanged.
    Values whose builtin type is listed in
    ``mil_helper.IMMEDIATE_VALUE_TYPES_IN_BYTES`` are returned as bytes.
    """
    if not isinstance(val, (_np.ndarray, _np.generic)):
        return val

    element_type = numpy_type_to_builtin_type(val.dtype)
    if element_type in mil_helper.IMMEDIATE_VALUE_TYPES_IN_BYTES:
        return _numpy_val_to_bytes(val)

    if val.dtype in (_np.uint16, _np.int16):
        # TODO (rdar://111797203): Serialize to byte after MIL changes to read from byte field.
        val = val.astype(np.int32)

    to_py = np_dtype_to_py_type(val.dtype)
    # numpy scalars and 0-d arrays become a single Python value; everything
    # else is flattened into a tuple of Python values.
    if isinstance(val, _np.generic) or val.shape == ():
        return to_py(val)
    return tuple(to_py(item) for item in val.flatten())


def infer_complex_dtype(real_dtype, imag_dtype):
    """
    Infers the complex dtype from real and imaginary part's dtypes.

    The two dtypes are promoted with ``promote_types``; fp32 yields complex64
    and fp64 yields complex128.

    Raises:
        ValueError: if the promoted dtype is neither fp32 nor fp64.
    """
    promoted = promote_types(real_dtype, imag_dtype)
    if promoted == types_fp32:
        return types_complex64
    if promoted == types_fp64:
        return types_complex128
    raise ValueError(
        f"Unsupported real/imag dtype ({real_dtype}/{imag_dtype}) to construct a "
        f"complex dtype."
    )


def infer_fp_dtype_from_complex(complex_dtype):
    """
    Infers the fp dtype of real and imaginary part from the complex dtype.

    complex64 yields fp32 and complex128 yields fp64.

    Raises:
        ValueError: if ``complex_dtype`` is neither complex64 nor complex128.
    """
    if complex_dtype == types_complex64:
        return types_fp32
    if complex_dtype == types_complex128:
        return types_fp64
    raise ValueError(f"Unsupported complex dtype ({complex_dtype}).")


def get_nbits_int_builtin_type(nbits: int, signed: bool = True) -> type:
    """
    Get the nbits int built-in type.

    Args:
        nbits: bit width of the integer type, e.g. 4 or 8.
        signed: if True (default) return the signed type ``int{nbits}``,
            otherwise the unsigned type ``uint{nbits}``.

    Returns:
        The builtin type registered under that name.

    Raises:
        KeyError: if no builtin type with that name is registered.
    """
    type_prefix = "" if signed else "u"
    return string_to_builtin(f"{type_prefix}int{nbits}")
