"""
Subscription Service — Firestore-based usage tracking.
Redis has been removed. All rate limiting uses Firestore atomic increments.
"""
import logging
from datetime import datetime, timezone
from typing import Optional

logger = logging.getLogger("chatbot.subscription_service")

# ── Tier Definitions ───────────────────────────────────────────────────────────
# Plan limits keyed by lower-case tier name. Per-tier fields:
#   max_domains             – how many domains a subscriber may register
#   max_daily_messages      – daily message cap (-1 is treated as unlimited)
#   advanced_support_limit  – daily advanced-support quota (0 = none, -1 = unlimited)
#   supports_custom_prompts – whether custom chatbot system prompts are allowed
#   label / price_monthly   – display name and monthly price
TIER_LIMITS: dict = {
    "free": {
        "max_domains": 1,
        "max_daily_messages": 50,
        "advanced_support_limit": 0,
        "supports_custom_prompts": False,
        "label": "Free",
        "price_monthly": 0,
    },
    "starter": {
        "max_domains": 2,
        "max_daily_messages": 100,
        "advanced_support_limit": 20,
        "supports_custom_prompts": True,
        "label": "Starter",
        "price_monthly": 9,
    },
    "pro": {
        "max_domains": 5,
        "max_daily_messages": 1000,
        "advanced_support_limit": 300,
        "supports_custom_prompts": True,
        "label": "Pro",
        "price_monthly": 29,
    },
    "enterprise": {
        "max_domains": 9999,
        "max_daily_messages": 999999,
        "advanced_support_limit": -1,   # unlimited
        "supports_custom_prompts": True,
        "label": "Enterprise",
        "price_monthly": 99,
    },
}


def get_tier_limits(tier: Optional[str]) -> dict:
    """
    Returns the limits dict for ``tier``, falling back to the free tier.

    Matching is case-insensitive and ignores surrounding whitespace. ``None``,
    non-string, blank, or unknown tier names all resolve to ``"free"``. A
    subscriber record with a missing tier therefore degrades to the most
    restrictive plan instead of raising ``AttributeError``.

    The returned dict is the shared entry inside ``TIER_LIMITS``; callers must
    not mutate it.
    """
    if not isinstance(tier, str):
        return TIER_LIMITS["free"]
    return TIER_LIMITS.get(tier.strip().lower(), TIER_LIMITS["free"])


def can_create_domain(current_domain_count: int, tier: str) -> bool:
    """Return True while the subscriber is still below their tier's domain cap."""
    allowed = get_tier_limits(tier)["max_domains"]
    return current_domain_count < allowed


def supports_custom_persona(tier: str) -> bool:
    """Return whether the given tier may supply its own chatbot system prompt."""
    limits = get_tier_limits(tier)
    return limits["supports_custom_prompts"]


def can_use_advanced_support(tier: str) -> bool:
    """
    Return True when the tier has any advanced (AI-generated) support.

    Only a quota of exactly 0 means "no access"; -1 (unlimited) and any
    positive quota both count as access.
    """
    quota = get_tier_limits(tier)["advanced_support_limit"]
    has_access = quota != 0
    return has_access


# ── Firestore Usage Tracking ──────────────────────────────────────────────────

def _today_key() -> str:
    return datetime.utcnow().strftime("%Y-%m-%d")


async def get_daily_usage(uid: str, db) -> dict:
    """
    Reads today's usage counters for a subscriber from Firestore.

    Reads ``usage/{uid}/daily/{YYYY-MM-DD}`` (UTC date) and returns
    ``{"message_count": int, "advanced_count": int}``. A missing document, a
    missing field, and a field stored as null all read as 0.

    Best-effort: any error is logged with its traceback and zeros are returned.
    Limit checks built on this value therefore fail open when the read fails.

    NOTE(review): ``db`` looks like a synchronous Firestore client (the result
    of ``.get()`` is used directly). This coroutine therefore blocks the event
    loop for the round trip; consider offloading to a thread.
    """
    try:
        snap = (
            db.collection("usage")
            .document(uid)
            .collection("daily")
            .document(_today_key())
            .get()
        )
        data = (snap.to_dict() or {}) if snap.exists else {}
    except Exception as e:
        logger.error("Usage read error for %s: %s", uid, e, exc_info=True)
        return {"message_count": 0, "advanced_count": 0}
    # `or 0` also covers counters explicitly stored as null, which would
    # otherwise break the `>=` comparisons done by the limit checks.
    return {
        "message_count": data.get("message_count") or 0,
        "advanced_count": data.get("advanced_count") or 0,
    }


def _increment_daily_counter(uid: str, db, field: str) -> int:
    """
    Adds 1 to ``field`` on today's usage document and returns the value read back.

    Shared implementation for the two public increment helpers below. Errors
    are not caught here; callers log them and choose the fallback value.
    """
    # Imported inside the function (as before), so a missing client library
    # surfaces as an exception handled by the caller, not at module import.
    from google.cloud.firestore_v1 import transforms

    ref = db.collection("usage").document(uid).collection("daily").document(_today_key())
    # Increment is applied server-side, so concurrent increments are not lost.
    # merge=True creates the document on first use and leaves the other
    # counter field untouched.
    ref.set({field: transforms.Increment(1)}, merge=True)
    # This read is a separate round trip, not part of the write. Increments
    # from concurrent requests that land in between are included, so the
    # value can exceed what this call alone produced.
    snap = ref.get()
    if not snap.exists:
        return 1
    return (snap.to_dict() or {}).get(field, 1)


async def increment_message_count(uid: str, db) -> int:
    """
    Atomically increments the daily message counter for a subscriber.

    Returns the counter value read back after the increment (it may already
    include concurrent increments), or 0 if the write/read failed. A caller
    comparing this result against a limit would fail open on errors.
    """
    try:
        return _increment_daily_counter(uid, db, "message_count")
    except Exception as e:
        logger.error("Usage increment error for %s: %s", uid, e, exc_info=True)
        return 0


async def increment_advanced_count(uid: str, db) -> int:
    """
    Atomically increments the advanced-support counter for a subscriber.

    Returns the counter value read back after the increment (it may already
    include concurrent increments), or 0 if the write/read failed.
    """
    try:
        return _increment_daily_counter(uid, db, "advanced_count")
    except Exception as e:
        logger.error("Advanced usage increment error for %s: %s", uid, e, exc_info=True)
        return 0


async def is_daily_limit_exceeded(uid: str, tier: str, db) -> bool:
    """
    Return True once today's message count has reached the tier's daily cap.

    A cap of -1 means unlimited and returns False without touching Firestore.
    If the usage read fails, usage is reported as zero, so this returns False.
    """
    cap = get_tier_limits(tier)["max_daily_messages"]
    if cap == -1:
        return False
    counters = await get_daily_usage(uid, db)
    return counters["message_count"] >= cap


async def is_advanced_limit_exceeded(uid: str, tier: str, db) -> bool:
    """
    Return True when the subscriber has no advanced-support quota left today.

    Sentinel quotas short-circuit without a Firestore read: 0 means the tier
    has no advanced support at all (always exceeded), and -1 means unlimited
    (never exceeded). Otherwise today's advanced_count is compared to the quota.
    """
    quota = get_tier_limits(tier)["advanced_support_limit"]
    if quota == 0:
        return True     # tier has no advanced support (e.g. free)
    if quota == -1:
        return False    # unlimited (e.g. enterprise)
    counters = await get_daily_usage(uid, db)
    return counters["advanced_count"] >= quota
