#!/usr/bin/env python3
#
"""Convert image to NXT data source file."""
#
# Copyright (C) 2024 Nicolas Schodet
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
import argparse
import os
import os.path
import struct
import sys

import tomllib
from PIL import Image, ImageChops


def print_hex(indent, bytes_or_fmt, *args, wrap=8, file=None):
    """Print bytes as comma separated hexadecimal values, for a C initializer.

    When extra positional arguments are given, bytes_or_fmt is a struct format
    string used to pack them; otherwise bytes_or_fmt is the data to print as is.
    At most wrap bytes are printed per line, each line prefixed with indent and
    terminated by a trailing comma.
    """
    payload = struct.pack(bytes_or_fmt, *args) if args else bytes_or_fmt
    for start in range(0, len(payload), wrap):
        cells = [format(octet, "#04x") for octet in payload[start : start + wrap]]
        print(indent + ", ".join(cells) + ",", file=file)


def data_from_image(img):
    """Create data in NXT bitmap format from image.

    NXT pixels are stored as bytes holding 8 pixel high columns, laid out band
    after band (all columns of the first 8 rows, then the next 8 rows...). PIL
    cannot write that layout, so the image is transposed (original columns become
    rows), packed one bit per pixel with the "1;IR" raw mode (inverted, reversed
    bit order), and the resulting matrix of bytes is transposed back.

    Raise RuntimeError if the image height is not a multiple of 8.
    """
    if img.height % 8:
        raise RuntimeError("Height must be multiple of 8")
    # Each row of the transposed image is one original column; since the original
    # height is a multiple of 8, every row packs into whole bytes.
    columns_img = img.transpose(Image.Transpose.TRANSPOSE)
    packed = columns_img.tobytes("raw", "1;IR")
    stride = columns_img.width // 8
    column_bytes = [packed[pos : pos + stride] for pos in range(0, len(packed), stride)]
    # zip() regroups the bytes band by band before flattening.
    return bytes(byte for band in zip(*column_bytes) for byte in band)


def crop_image(img):
    """Crop blank borders to reduce image size.

    The kept area is the bounding box of the non-white pixels (getbbox finds
    non-zero pixels, so it is computed on the inverted image). Vertical bounds are
    widened to multiples of 8 because NXT data is organized in 8 pixel high bands;
    horizontal bounds are kept tight.

    Return a tuple (cropped_image, (x_offset, y_offset)), the offsets being the
    position of the cropped area in the original image.
    """
    bbox = ImageChops.invert(img).getbbox()
    if bbox is None:
        # No non-white pixel, keep the whole image. Special case for Test2 image.
        return img, (0, 0)
    left, upper, right, lower = bbox
    # Round upper down and lower up to a multiple of 8, without going past the
    # image bottom: cropping outside the image would pad with zero (black) rows,
    # which would end up as lit pixels in the NXT data.
    upper = upper // 8 * 8
    lower = min((lower + 7) // 8 * 8, img.height)
    return img.crop((left, upper, right, lower)), (left, upper)


def convert_bitmap(info, img_file, out_file, crop=False):
    """Convert to BMPMAP format.

    info: parsed TOML description, must provide "start_x" and "start_y".
    img_file: path of the input image; its base name is used as C identifier.
    out_file: path of the C source file to write.
    crop: if true, remove blank borders first and shift the start position by the
        crop offset.

    Raise RuntimeError if the (cropped) image height is not a multiple of 8, and
    KeyError if a required key is missing from info.
    """
    # Use a context manager so the input file is closed once the data is read.
    with Image.open(img_file) as src:
        if crop:
            img, (crop_x, crop_y) = crop_image(src)
        else:
            img, (crop_x, crop_y) = src, (0, 0)
        data = data_from_image(img)
        width, height = img.width, img.height
    start_x = info["start_x"]
    start_y = info["start_y"]
    # NOTE(review): the identifier is derived from the image file name, whereas
    # convert_icon derives it from the output file name; confirm which is intended.
    basename = os.path.basename(os.path.splitext(img_file)[0])
    with open(out_file, "w") as f:
        # Total size is the 8 header bytes written below plus the pixel data.
        print(f"#define {basename}_size {len(data)+8}", file=f)
        print(f"const BMPMAP {basename} =", file=f)
        print("{", file=f)
        # Header: format identifier, data size, start x/y, width, height.
        print_hex("  ", ">H", 0x0200, file=f)
        print_hex("  ", ">H", len(data), file=f)
        print_hex("  ", "B", start_x + crop_x, file=f)
        print_hex("  ", "B", start_y + crop_y, file=f)
        print_hex("  ", "B", width, file=f)
        print_hex("  ", "B", height, file=f)
        print("  {", file=f)
        print_hex("    ", data, file=f)
        print("  }", file=f)
        print("};", file=f)


def convert_icon(info, img_file, out_file):
    """Convert to ICON format.

    The image is a grid of items, each item_pixels_x by item_pixels_y pixels.

    info: parsed TOML description, must provide "item_pixels_x" and
        "item_pixels_y", the size of one item in pixels.
    img_file: path of the input image.
    out_file: path of the C source file to write; its base name is used as C
        identifier.

    Raise RuntimeError if the image height is not a multiple of 8, if the item size
    is not positive, or if the image size is not a multiple of the item size. Raise
    KeyError if a required key is missing from info.
    """
    # Use a context manager so the input file is closed once the data is read.
    with Image.open(img_file) as img:
        data = data_from_image(img)
        width, height = img.width, img.height
    item_pixels_x = info["item_pixels_x"]
    item_pixels_y = info["item_pixels_y"]
    if item_pixels_x <= 0 or item_pixels_y <= 0:
        raise RuntimeError("Item size must be positive")
    # Otherwise the floor division below would silently drop partial items, and the
    # item grid in the header would not match the data row width.
    if width % item_pixels_x != 0 or height % item_pixels_y != 0:
        raise RuntimeError("Image size must be a multiple of item size")
    basename = os.path.basename(os.path.splitext(out_file)[0])
    with open(out_file, "w") as f:
        print(f"const ICON {basename} =", file=f)
        print("{", file=f)
        # Header: format identifier, data size, item counts, item size in pixels.
        print_hex("  ", ">H", 0x0400, file=f)
        print_hex("  ", ">H", len(data), file=f)
        print_hex("  ", "B", width // item_pixels_x, file=f)
        print_hex("  ", "B", height // item_pixels_y, file=f)
        print_hex("  ", "B", item_pixels_x, file=f)
        print_hex("  ", "B", item_pixels_y, file=f)
        print("  {", file=f)
        print_hex("    ", data, file=f)
        print("  }", file=f)
        print("};", file=f)


def main():
    """Parse the command line, convert the image and write the output file.

    On any error, remove the output file so no stale or partial file is left
    behind, print the error on stderr and exit with status 1.
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("info", help="input TOML file")
    p.add_argument("image", help="input image")
    # The output is always written to (and removed from) a file, so it is
    # mandatory: without it, open(None) and os.remove(None) fail with a TypeError
    # traceback instead of a usage message.
    p.add_argument(
        "-o", "--output", metavar="FILE", required=True, help="output header file"
    )
    options = p.parse_args()

    try:
        with open(options.info, "rb") as f:
            info = tomllib.load(f)

        fmt = info.get("format")
        if fmt == "bitmap":
            convert_bitmap(info, options.image, options.output)
        elif fmt == "icon":
            convert_icon(info, options.image, options.output)
        else:
            raise RuntimeError(f"Unknown format: {fmt!r}")
    except Exception as e:
        try:
            os.remove(options.output)
        except FileNotFoundError:
            pass
        print(e, file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
