from sqlalchemy import Column, String, Integer, DateTime, JSON, ForeignKey, Float, Boolean, Text
from sqlalchemy.sql import func
from database.session import Base
from sqlalchemy.orm import relationship

class SubscriberProfile(Base):
    """Account profile keyed by ``firebase_uid``.

    The other tables defined here point at ``subscriber_profiles.firebase_uid``
    through foreign keys.
    """

    __tablename__ = "subscriber_profiles"

    # Primary key (string); presumably the Firebase Authentication UID.
    firebase_uid = Column(
        String,
        primary_key=True,
        index=True,
    )
    # Unique and indexed, but not declared NOT NULL.
    email = Column(
        String,
        unique=True,
        index=True,
    )
    display_name = Column(String, nullable=True)
    company_name = Column(String, nullable=True)
    # Populated by the database itself at INSERT time.
    created_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    # Refreshed to now() on UPDATEs emitted through SQLAlchemy. No insert-time
    # default is declared, so this stays NULL until the first such update.
    updated_at = Column(
        DateTime(timezone=True),
        onupdate=func.now(),
    )

class UserRole(Base):
    """Role and permission list attached to a subscriber profile.

    ``firebase_uid`` is both the primary key and a foreign key to
    ``subscriber_profiles``, so each profile has at most one row here.
    """

    __tablename__ = "user_roles"

    firebase_uid = Column(
        String,
        ForeignKey("subscriber_profiles.firebase_uid"),
        primary_key=True,
    )
    # Intended values: admin, subscriber, user. The schema does not enforce
    # this; any string is accepted. New rows default to "user".
    role = Column(String, default="user")
    # A callable default gives every INSERT its own fresh empty list. A literal
    # ``[]`` would be a single list object bound for every row that takes the
    # default and assigned back onto those instances after flush. In-place
    # mutation of one row's list would then alter the default for later rows.
    permissions_list = Column(JSON, default=list)

class Subscription(Base):
    """Subscription state for a subscriber profile.

    The Stripe-related column names suggest Stripe-backed billing.
    ``firebase_uid`` is indexed but not unique, so the schema allows more
    than one subscription row per profile.
    """

    __tablename__ = "subscriptions"

    id = Column(Integer, primary_key=True, index=True)
    firebase_uid = Column(
        String,
        ForeignKey("subscriber_profiles.firebase_uid"),
        index=True,
    )
    stripe_customer_id = Column(String, nullable=True)
    plan_id = Column(String, nullable=True)
    # Free-form status string. Rows inserted without one get "inactive".
    status = Column(String, default="inactive")
    current_period_end = Column(DateTime(timezone=True), nullable=True)

class APIKey(Base):
    """API key belonging to a subscriber profile.

    The key material is kept in ``hashed_key`` (presumably a hash rather
    than the raw key; confirm with the code that issues keys).
    """

    __tablename__ = "api_keys"

    key_id = Column(String, primary_key=True, index=True)
    # Indexed to match the firebase_uid column of the other per-user tables.
    # PostgreSQL and SQLite do not index foreign-key columns on their own, so
    # without this, listing a user's keys scans the whole table.
    # NOTE: existing databases need a migration to gain this index.
    firebase_uid = Column(
        String,
        ForeignKey("subscriber_profiles.firebase_uid"),
        index=True,
    )
    hashed_key = Column(String, nullable=False)
    name = Column(String)
    # Populated by the database at INSERT time.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    # Nothing in this model sets this; presumably the auth path updates it.
    last_used_at = Column(DateTime(timezone=True), nullable=True)
    # Python-side default applied when INSERT omits a value.
    is_active = Column(Boolean, default=True)

class AIUsageLog(Base):
    """Usage record for AI model calls: model name, token counts and cost."""

    __tablename__ = "ai_usage_logs"

    log_id = Column(Integer, primary_key=True, index=True)
    firebase_uid = Column(
        String,
        ForeignKey("subscriber_profiles.firebase_uid"),
        index=True,
    )
    model_used = Column(String)
    # Token counters fall back to 0 when omitted at INSERT.
    tokens_prompt = Column(Integer, default=0)
    tokens_completion = Column(Integer, default=0)
    # NOTE(review): binary Float can accumulate rounding error when summed for
    # billing. Consider Numeric if exact totals matter. The unit/currency is
    # not stated here; confirm with whatever writes these rows.
    cost = Column(Float, default=0.0)
    # Populated by the database at INSERT time.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())

class PromptHistory(Base):
    """Stored user message / AI response pair, grouped by ``session_id``."""

    __tablename__ = "prompt_history"

    prompt_id = Column(Integer, primary_key=True, index=True)
    firebase_uid = Column(
        String,
        ForeignKey("subscriber_profiles.firebase_uid"),
        index=True,
    )
    # Indexed so that a single conversation's rows can be fetched directly.
    session_id = Column(String, index=True)
    # Text (unbounded) rather than String for potentially long content.
    user_message = Column(Text)
    ai_response = Column(Text)
    # Populated by the database at INSERT time.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())

class VectorIndex(Base):
    """Bookkeeping row for a vector collection owned by a subscriber."""

    __tablename__ = "vector_indexes"

    index_id = Column(Integer, primary_key=True, index=True)
    firebase_uid = Column(
        String,
        ForeignKey("subscriber_profiles.firebase_uid"),
        index=True,
    )
    collection_name = Column(String)
    # Starts at 0. Nothing in this model keeps it in sync with the actual
    # collection; presumably the indexing code updates it.
    document_count = Column(Integer, default=0)
    # Backend identifier; rows inserted without one get "chroma".
    vector_db_type = Column(String, default="chroma")

class AuditLog(Base):
    """Audit trail entry: which action a subscriber took on which resource."""

    __tablename__ = "audit_logs"

    log_id = Column(Integer, primary_key=True, index=True)
    firebase_uid = Column(
        String,
        ForeignKey("subscriber_profiles.firebase_uid"),
        index=True,
    )
    action = Column(String)
    resource = Column(String)
    # Optional; stored as a plain string with no format validation.
    ip_address = Column(String, nullable=True)
    # Populated by the database at INSERT time.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())
