#!/usr/bin/env python3
"""
process_receipt.py — OCR pipeline para boletas/tickets
Uso: python3 process_receipt.py <ruta_imagen> <categorias_json> [patrones_json]
Salida: JSON con texto, monto, fecha, proveedor, categoria_id + scores de confianza (0-1)
"""

import sys, os, json, re
from datetime import datetime

try:
    from PIL import Image, ImageFilter, ImageEnhance, ImageOps
    import pytesseract
    import cv2
    import numpy as np
except ImportError as e:
    print(json.dumps({"error": f"Dependencia faltante: {e}"}))
    sys.exit(1)

# ── MESES ─────────────────────────────────────────────────────────────────────
MESES = {
    'enero':'01','febrero':'02','marzo':'03','abril':'04','mayo':'05','junio':'06',
    'julio':'07','agosto':'08','septiembre':'09','octubre':'10','noviembre':'11','diciembre':'12',
    'jan':'01','feb':'02','mar':'03','apr':'04','may':'05','jun':'06',
    'jul':'07','aug':'08','sep':'09','oct':'10','nov':'11','dec':'12',
}

# ── PREPROCESAMIENTO ──────────────────────────────────────────────────────────

def _corregir_exif(img: Image.Image) -> Image.Image:
    """Rota la imagen según metadatos EXIF para que quede derecha."""
    try:
        return ImageOps.exif_transpose(img)
    except Exception:
        return img


def _detectar_boleta(cv_img):
    """
    Detecta el rectángulo de la boleta en la imagen y aplica perspectiva
    para obtener solo el papel plano. Si no se detecta con confianza,
    devuelve la imagen original sin recortar.
    """
    h, w = cv_img.shape[:2]

    # Reducir para procesamiento rápido del contorno (la warp usa la original)
    escala = 1200 / max(h, w)
    if escala < 1:
        small = cv2.resize(cv_img, (int(w * escala), int(h * escala)))
    else:
        small = cv_img.copy()
        escala = 1.0

    sh, sw = small.shape[:2]

    # Convertir a gris y detectar bordes
    gray_s = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)
    blur   = cv2.GaussianBlur(gray_s, (5, 5), 0)
    edges  = cv2.Canny(blur, 30, 100)
    edges  = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=2)

    contornos, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contornos    = sorted(contornos, key=cv2.contourArea, reverse=True)

    quad = None
    for cnt in contornos[:8]:
        peri   = cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, 0.02 * peri, True)
        area   = cv2.contourArea(cnt)
        # Debe ser un cuadrilátero que ocupe al menos 20 % de la imagen
        if len(approx) == 4 and area > sw * sh * 0.20:
            quad = approx
            break

    if quad is None:
        return cv_img  # sin boleta detectada, usar imagen completa

    # Escalar puntos a coordenadas originales
    pts = (quad.reshape(4, 2) / escala).astype(np.float32)

    # Ordenar: tl, tr, br, bl
    s    = pts.sum(axis=1)
    diff = np.diff(pts, axis=1).flatten()
    tl   = pts[np.argmin(s)]
    br   = pts[np.argmax(s)]
    tr   = pts[np.argmin(diff)]
    bl   = pts[np.argmax(diff)]
    pts_ord = np.array([tl, tr, br, bl], dtype=np.float32)

    # Dimensiones del rectángulo destino
    ancho = int(max(
        np.linalg.norm(tr - tl),
        np.linalg.norm(br - bl)
    ))
    alto = int(max(
        np.linalg.norm(bl - tl),
        np.linalg.norm(br - tr)
    ))

    if ancho < 100 or alto < 100:
        return cv_img

    dst = np.array([[0, 0], [ancho, 0], [ancho, alto], [0, alto]], dtype=np.float32)
    M   = cv2.getPerspectiveTransform(pts_ord, dst)
    return cv2.warpPerspective(cv_img, M, (ancho, alto))


def _mejorar_para_ocr(cv_gray):
    """
    Aplica CLAHE + umbralización adaptativa para obtener texto negro sobre
    blanco limpio, ideal para Tesseract. CLAHE normaliza el contraste local
    antes del umbral, lo que mejora imágenes con iluminación desigual o bajo
    contraste (p.ej. tickets de peaje, papel térmico desgastado).
    """
    # Escalar a mínimo 1800px de alto para garantizar buena resolución OCR
    h, w = cv_gray.shape[:2]
    if h < 1800:
        escala = 1800 / h
        cv_gray = cv2.resize(cv_gray, (int(w * escala), int(h * escala)), interpolation=cv2.INTER_CUBIC)

    # CLAHE: mejora contraste local sin sobreexponer zonas ya claras
    clahe   = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
    cv_gray = clahe.apply(cv_gray)

    # Umbral adaptativo: maneja iluminación desigual (sombras en la boleta)
    blur   = cv2.GaussianBlur(cv_gray, (3, 3), 0)
    thresh = cv2.adaptiveThreshold(
        blur, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        blockSize=31,
        C=15
    )
    # Leve sharpening del resultado
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    return cv2.filter2D(thresh, -1, kernel)


def preprocess(img_path: str):
    """
    Pipeline completo:
    1. Carga y corrige orientación EXIF
    2. Detecta y recorta el rectángulo de la boleta
    3. Convierte a B&N con umbral adaptativo
    Devuelve imagen PIL lista para Tesseract.
    """
    # Cargar con PIL para respetar EXIF, luego pasar a OpenCV
    pil_img = _corregir_exif(Image.open(img_path).convert('RGB'))
    cv_img  = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

    # Recortar a la boleta
    cv_crop = _detectar_boleta(cv_img)

    # Convertir a gris y mejorar para OCR
    cv_gray = cv2.cvtColor(cv_crop, cv2.COLOR_BGR2GRAY)
    cv_ocr  = _mejorar_para_ocr(cv_gray)

    # Devolver como PIL
    return Image.fromarray(cv_ocr)


def save_processed(img, orig_path: str) -> str:
    base     = os.path.basename(orig_path)
    name, _  = os.path.splitext(base)
    out_dir  = os.path.normpath(os.path.join(os.path.dirname(orig_path), '..', 'processed'))
    os.makedirs(out_dir, exist_ok=True)

    # Asegurar escala de grises
    if img.mode != 'L':
        img = img.convert('L')

    # Redimensionar a máx 900px de ancho manteniendo proporción
    max_w = 900
    if img.width > max_w:
        ratio = max_w / img.width
        img = img.resize((max_w, int(img.height * ratio)), Image.LANCZOS)

    # Guardar como JPEG (WebP convierte 'L' a RGB en Pillow, perdiendo el B&N)
    out_path = os.path.join(out_dir, name + '_proc.jpg')
    img.save(out_path, 'JPEG', quality=85)
    return 'processed/' + name + '_proc.jpg'

# ── EXTRACCIÓN CON CONFIANZA ──────────────────────────────────────────────────

def extract_monto(text: str, patrones_aprendidos: list = None):
    """Retorna (monto, confianza 0-1)"""
    t = text.lower()

    # Normalizar: reemplazar caracteres OCR que confunden $ (3, S, s, §) cuando
    # van seguidos de número y no están dentro de una palabra
    t = re.sub(r'(?<![a-zA-Z0-9])([3sS§])\s+(?=[\d])', '$ ', t)

    # ── Patrones aprendidos de esta empresa (máxima prioridad) ────────────────
    if patrones_aprendidos:
        for item in sorted(patrones_aprendidos, key=lambda x: x.get('hits', 0), reverse=True):
            etiqueta = re.escape(item['patron'].lower())
            pat = rf'{etiqueta}\s*[:\$]?\s*\$?\s*([\d\.\,]+)'
            m = re.search(pat, t)
            if m:
                v = _parse_monto(m.group(1))
                if v: return v, 0.95

    # ── Patrones universales alta confianza ───────────────────────────────────
    # Nota: [,\s]* entre el separador y el símbolo de moneda maneja formatos
    # como "TOTAL: , $1.200" (common in Transbank/POS receipts)
    patrones_alta = [
        r'total\s+a\s+pagar\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        r'total\s*:?\s*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        # "a pagar" mismo línea con $ explícito
        r'a\s+pagar\s*[:\s]*[,\s]*[\$3sS§]\s*([\d\.\,]+)',
        # "a pagar" con monto en línea siguiente (posible ruido antes del $)
        r'a\s+pagar\s*\n[^\n]*[\$3sS§]\s*([\d\.\,]+)',
        r'importe\s+total\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        r'pagado\s+([\d\.\,]+)\s*(?:clp|cop|usd)?',
        r'monto\s+total\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        r'monto\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
    ]
    patrones_media = [
        r'subtotal\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        r'precio\s+total\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        r'neto\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        r'valor\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        r'cobrado\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        r'cancelado\s*[:\s]*[,\s]*[\$3sS§]?\s*([\d\.\,]+)',
        r'tarifa\s+([\d\.\,]+)\s*(?:clp)?',
    ]
    # Patrones de último recurso: buscar en líneas que contengan $ o CLP
    # para evitar capturar números de referencia/transacción
    patrones_baja = [
        r'[\$]\s*([\d\.\,]+)',          # literal $
        r'[\$3sS§]\s*([\d\.\,]+)',      # $ o confundible
        r'([\d]{1,3}(?:[\.\,]\d{3})+)\s*(?:clp)?',  # formato 1.234 o 1.234 clp
    ]

    for pat in patrones_alta:
        m = re.search(pat, t)
        if m:
            v = _parse_monto(m.group(1), skip_year_check=True)
            if v: return v, 0.90

    for pat in patrones_media:
        m = re.search(pat, t)
        if m:
            v = _parse_monto(m.group(1), skip_year_check=True)
            if v: return v, 0.75

    # Para baja confianza: buscar el ÚLTIMO número con formato de monto en líneas
    # que contengan indicadores de precio, para evitar IDs de transacción
    for linea in reversed(t.split('\n')):
        if not re.search(r'[\$\#]|clp|total|pago|cobro|precio|valor', linea):
            continue
        m = re.search(r'[\$3sS§]\s*([\d\.\,]+)', linea)
        if m:
            v = _parse_monto(m.group(1))
            if v: return v, 0.50

    for pat in patrones_baja:
        # Buscar de derecha a izquierda (el total suele estar al final)
        matches = list(re.finditer(pat, t))
        for m in reversed(matches):
            v = _parse_monto(m.group(1))
            if v: return v, 0.40

    return None, 0.0


def _parse_monto(raw: str, skip_year_check: bool = False):
    raw = raw.strip().rstrip('.,')  # quitar puntuación final (OCR noise: "20,000,")
    # Formato chileno punto-miles:  1.234  o  1.234,56
    if re.match(r'^\d{1,3}(\.\d{3})+(,\d{1,2})?$', raw):
        raw = raw.replace('.', '').replace(',', '.')
    # Formato coma-miles inglés/POS: 20,000  o  1,234
    elif re.match(r'^\d{1,3}(,\d{3})+$', raw):
        raw = raw.replace(',', '')
    else:
        raw = raw.replace(',', '.')
    try:
        v = float(raw)
        # Rango válido para montos en CLP
        if v < 50 or v > 10_000_000:
            return None
        # Rechazar números que parecen años (1900-2099 enteros sin centavos)
        # Solo aplica en contexto de baja confianza: si hay keyword clara, confiamos
        if not skip_year_check and 1900 <= v <= 2099 and v == int(v):
            return None
        return v
    except ValueError:
        return None


def extract_fecha(text: str, patrones_aprendidos: list = None):
    """Retorna (fecha_str 'YYYY-MM-DD', confianza)"""
    hoy = datetime.today()
    fecha_pattern = r'(\d{1,2})\s*[\/\-\.]\s*(\d{1,2})\s*[\/\-\.]\s*(\d{2,4})'

    def _try_fecha(d, mo, y_str):
        if len(y_str) == 2: y_str = '20' + y_str
        y = int(y_str)
        if not (1 <= mo <= 12 and 1 <= d <= 31): return None
        try:
            fecha = datetime(y, mo, d)
            if fecha.date() <= hoy.date():
                return fecha.strftime('%Y-%m-%d')
        except ValueError:
            pass
        return None

    # ── Buscar primero en líneas con etiquetas aprendidas ─────────────────────
    if patrones_aprendidos:
        for item in sorted(patrones_aprendidos, key=lambda x: x.get('hits', 0), reverse=True):
            etiqueta = re.escape(item['patron'].lower())
            for linea in text.lower().split('\n'):
                if not re.search(etiqueta, linea):
                    continue
                for m in re.finditer(fecha_pattern, linea):
                    r = _try_fecha(int(m.group(1)), int(m.group(2)), m.group(3))
                    if r: return r, 0.95

    # dd/mm/yyyy o dd-mm-yyyy o dd.mm.yyyy (espacios opcionales)
    for m in re.finditer(fecha_pattern, text):
        r = _try_fecha(int(m.group(1)), int(m.group(2)), m.group(3))
        if r: return r, 0.90

    # Fecha con ruido OCR entre partes: "18/12tooes2024" → buscar año tras basura
    noisy_pat = r'(\d{1,2})[\/\-](\d{1,2})[^\d\n]{0,12}(\d{4})'
    for m in re.finditer(noisy_pat, text):
        r = _try_fecha(int(m.group(1)), int(m.group(2)), m.group(3))
        if r: return r, 0.70

    # dd de mes de yyyy (texto en español)
    for m in re.finditer(r'(\d{1,2})\s+de\s+(\w+)\s+(?:de\s+)?(\d{4})', text, re.IGNORECASE):
        d, mes_str, y = int(m.group(1)), m.group(2).lower(), m.group(3)
        mo = MESES.get(mes_str)
        if mo:
            r = _try_fecha(d, int(mo), y)
            if r: return r, 0.85

    # dd-mes_abrev-yyyy (p.ej. "18-DIC-2024" en terminales POS)
    for m in re.finditer(r'(\d{1,2})[\/\-]([a-záéíóúñA-Z]{3})[\/\-](\d{2,4})', text, re.IGNORECASE):
        mes_str = m.group(2).lower()
        mo = MESES.get(mes_str)
        if mo:
            r = _try_fecha(int(m.group(1)), int(mo), m.group(3))
            if r: return r, 0.85

    # Solo mes/año
    m = re.search(r'(\d{1,2})[\/\-](\d{4})', text)
    if m:
        try:
            fecha = datetime(int(m.group(2)), int(m.group(1)), 1)
            if fecha.date() <= hoy.date():
                return fecha.strftime('%Y-%m-01'), 0.55
        except ValueError:
            pass

    # Último recurso: solo dd/mm → asumir año actual (o anterior si es futuro)
    m = re.search(r'\b(\d{1,2})[\/\-](\d{1,2})\b', text)
    if m:
        d_val, mo_val = int(m.group(1)), int(m.group(2))
        if 1 <= mo_val <= 12 and 1 <= d_val <= 31:
            y = hoy.year
            try:
                fecha = datetime(y, mo_val, d_val)
                if fecha.date() > hoy.date():
                    fecha = datetime(y - 1, mo_val, d_val)
                return fecha.strftime('%Y-%m-%d'), 0.45
            except ValueError:
                pass

    return None, 0.0


def extract_proveedor(text: str):
    """Retorna (proveedor, confianza)"""
    lines = [l.strip() for l in text.split('\n') if l.strip()]
    for line in lines[:6]:
        # Ignorar líneas que son solo números, fechas o símbolos
        if re.match(r'^[\d\$\.\,\-\/\*\#\@\s]+$', line):
            continue
        # Requiere al menos 3 letras consecutivas (filtra ruido OCR como "oo,", ",,", "| |")
        if not re.search(r'[a-zA-ZáéíóúñÁÉÍÓÚÑ]{3,}', line):
            continue
        # Palabras clave que indican que es el encabezado del negocio
        if re.search(r'(rut|nit|telefono|tel:|fono|www\.|http|boleta|ticket|factura)', line, re.I):
            continue
        return line[:80], 0.70
    return None, 0.0


def suggest_categoria(text: str, categorias: list):
    """Retorna (categoria_id, confianza, keyword_matches)"""
    text_lower = text.lower()
    best_id    = None
    best_score = 0
    best_kw    = []

    for cat in categorias:
        if not cat.get('keywords'):
            continue
        matches = []
        for kw in cat['keywords'].split(','):
            kw = kw.strip().lower()
            if kw and kw in text_lower:
                matches.append(kw)

        # Score: más matches y más hits = mayor confianza
        if matches:
            hits = int(cat.get('hits', 1) or 1)
            score = len(matches) * 0.3 + min(hits / 10, 0.4) + 0.3
            score = min(score, 0.95)
            if score > best_score:
                best_score = score
                best_id    = cat['id']
                best_kw    = matches

    return best_id, round(best_score, 2), best_kw

# ── MAIN ──────────────────────────────────────────────────────────────────────
def main():
    if len(sys.argv) < 2:
        print(json.dumps({"error": "Uso: process_receipt.py <imagen> [categorias_json] [patrones_json]"}))
        sys.exit(1)

    img_path   = sys.argv[1]
    categorias = []
    patrones   = {'monto': [], 'fecha': []}

    if len(sys.argv) >= 3:
        try:
            categorias = json.loads(sys.argv[2])
        except json.JSONDecodeError:
            pass

    if len(sys.argv) >= 4:
        try:
            patrones = json.loads(sys.argv[3])
        except json.JSONDecodeError:
            pass

    if not os.path.exists(img_path):
        print(json.dumps({"error": f"Imagen no encontrada: {img_path}"}))
        sys.exit(1)

    try:
        processed_img = preprocess(img_path)

        config = '--oem 3 --psm 6 -l spa+eng'
        texto  = pytesseract.image_to_string(processed_img, config=config)

        # Segundo intento si el texto es muy corto
        if len(texto.strip()) < 80:
            texto2 = pytesseract.image_to_string(processed_img, config='--oem 3 --psm 4 -l spa+eng')
            if len(texto2.strip()) > len(texto.strip()):
                texto = texto2

        # Limpiar: eliminar bytes nulos y secuencias inválidas que rompen json_decode en PHP
        texto = texto.replace('\x00', '').encode('utf-8', errors='replace').decode('utf-8')

        monto,    monto_conf    = extract_monto(texto, patrones.get('monto', []))
        fecha,    fecha_conf    = extract_fecha(texto, patrones.get('fecha', []))
        proveedor, prov_conf   = extract_proveedor(texto)
        cat_id, cat_conf, kws  = suggest_categoria(texto, categorias)

        # Guardar imagen procesada (si falla por permisos no interrumpe el OCR)
        processed_path = None
        try:
            processed_path = save_processed(processed_img, img_path)
        except Exception:
            pass

        print(json.dumps({
            "ok":                 True,
            "texto_ocr":          texto.strip(),
            "monto":              monto,
            "monto_confianza":    monto_conf,
            "fecha_boleta":       fecha,
            "fecha_confianza":    fecha_conf,
            "proveedor":          proveedor,
            "proveedor_confianza": prov_conf,
            "categoria_id":       cat_id,
            "categoria_confianza": cat_conf,
            "categoria_keywords": kws,
            "imagen_procesada":   processed_path,
        }, ensure_ascii=False))

    except Exception as e:
        print(json.dumps({"error": str(e)}))
        sys.exit(1)


if __name__ == '__main__':
    main()
