# apps/usage/middleware.py
import logging
import re

from django.core.exceptions import ValidationError
from django.db import models
from django.http import JsonResponse
from django.utils import timezone
from django.utils.deprecation import MiddlewareMixin

from apps.apps_.models import APIKey, App
from apps.companies.models import Company
from apps.usage.models import UsageLog

# Endpoint map: path prefix → (endpoint_name, weight).
#
# Ordering matters for first-match lookups: every more specific prefix must be
# listed before any broader prefix that also matches it (e.g.
# "/api/v1/reviews/apps/" before "/api/v1/reviews/"), otherwise the specific
# entry is shadowed and never selected.
ENDPOINT_MAP = {
    # Companies & Apps
    "/api/v1/companies/apps/create/": ("app_create", 9),
    "/api/v1/companies/apps/create-api-key/": ("apikey_create", 9),
    "/api/v1/companies/apps/": ("apps", 9),

    # Branding (must precede the catch-all "/api/v1/reviews/" entry below)
    "/api/v1/reviews/apps/": ("reviews_app_branding", 9),

    # Reviews
    "/api/v1/reviews/link/": ("reviews_link", 9),
    "/api/v1/reviews/tour/": ("tour_reviews", 6),
    "/api/v1/reviews/submit/": ("reviews_submit", 9),
    "/api/v1/reviews/analytics/": ("reviews_analytics", 9),
    "/api/v1/reviews/": ("reviews", 6),
}


class UsageTrackingMiddleware(MiddlewareMixin):
    """Attribute API calls to a company, enforce free-plan limits, and log usage.

    ``process_request`` resolves the endpoint (name + weight) from
    ``ENDPOINT_MAP``, identifies the calling company from an ``X-API-KEY``
    header or the already-authenticated ``request.user``, and short-circuits
    with a 403 ``JsonResponse`` when a free-plan restriction or the monthly
    quota would be violated. ``process_response`` writes one ``UsageLog`` row
    per response for every request that was attributed to a company.
    """

    # Monthly weighted-call budget for companies on the "free" plan.
    FREE_MONTHLY_QUOTA = 1000

    # Branding endpoints: /api/v1/reviews/apps/<uuid>/branding[/]
    # Compiled once instead of on every request.
    _BRANDING_PATH_RE = re.compile(
        r"^/api/v1/reviews/apps/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/branding/?$"
    )

    _logger = logging.getLogger(__name__)

    def process_request(self, request):
        """Annotate ``request`` with usage metadata and enforce plan limits.

        Sets ``request._api_key_obj``, ``request._company_obj``,
        ``request._endpoint_name`` and ``request._weight`` (default 1).

        Returns:
            ``None`` to let the request continue to the view, or a 403
            ``JsonResponse`` when the API key is unknown or a free-plan
            restriction / the monthly quota blocks the call.
        """
        request._api_key_obj = None
        request._company_obj = None
        request._endpoint_name = None
        request._weight = 1

        path = request.path

        endpoint = self._resolve_endpoint(path)
        if endpoint is not None:
            request._endpoint_name, request._weight = endpoint

        denial = self._identify_company(request)
        if denial is not None:
            return denial

        # Plan restrictions only apply to tracked endpoints of known companies.
        company = request._company_obj
        if company and request._endpoint_name and company.plan == "free":
            return self._enforce_free_plan(request, company, path)

        return None  # continue to view

    @staticmethod
    def _resolve_endpoint(path):
        """Return ``(name, weight)`` for the longest ENDPOINT_MAP prefix of *path*, or None."""
        # Longest match wins so specific routes (e.g. "/api/v1/reviews/apps/")
        # are never shadowed by broader ones ("/api/v1/reviews/"), regardless
        # of the dict's insertion order.
        best = max(
            (prefix for prefix in ENDPOINT_MAP if path.startswith(prefix)),
            key=len,
            default=None,
        )
        return None if best is None else ENDPOINT_MAP[best]

    @staticmethod
    def _identify_company(request):
        """Populate ``request._api_key_obj`` / ``request._company_obj``.

        Returns a 403 ``JsonResponse`` when an ``X-API-KEY`` header is present
        but matches no ``APIKey``; otherwise ``None``.
        """
        api_key_value = request.headers.get("X-API-KEY")
        if api_key_value:
            try:
                api_key = APIKey.objects.get(key=api_key_value)
            except APIKey.DoesNotExist:
                return JsonResponse({"detail": "Invalid API Key"}, status=403)
            request._api_key_obj = api_key
            request._company_obj = api_key.app.company
            return None

        # JWT auth fallback. getattr() avoids an AttributeError when this
        # middleware runs without (or before) AuthenticationMiddleware.
        # NOTE(review): request.user only reflects auth done by earlier
        # middleware; token auth performed inside DRF views is not visible
        # yet — confirm JWT callers are actually attributed here.
        user = getattr(request, "user", None)
        if user and user.is_authenticated:
            request._company_obj = getattr(user, "company", None)
        return None

    def _enforce_free_plan(self, request, company, path):
        """Apply free-plan restrictions; return a 403 ``JsonResponse`` or None."""
        # Restrict to 1 app
        if path.startswith("/api/v1/companies/apps/create/"):
            if App.objects.filter(company=company).exists():
                return JsonResponse(
                    {"detail": "Free plan allows only 1 app. Upgrade to create more."},
                    status=403,
                )

        # Restrict branding APIs to paid plans only
        if self._BRANDING_PATH_RE.match(path):
            return JsonResponse(
                {
                    "detail": "Branding APIs are only available on the paid plan. Please upgrade.",
                    "error_code": "UPGRADE_REQUIRED"
                },
                status=403,
            )

        # Restrict to 1 API key per app
        if path.startswith("/api/v1/companies/apps/create-api-key/"):
            if self._app_already_has_key(request, company):
                return JsonResponse(
                    {"detail": "Free plan allows only 1 API key per app. Upgrade to create more."},
                    status=403,
                )

        # Monthly quota enforcement: sum of weights logged since the start of
        # the current month (in timezone.now()'s timezone) plus this call.
        now = timezone.now()
        start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        used = (
            UsageLog.objects.filter(company=company, timestamp__gte=start_of_month)
            .aggregate(total=models.Sum("weight"))["total"]
            or 0
        )
        if used + request._weight > self.FREE_MONTHLY_QUOTA:
            return JsonResponse(
                {"detail": "API quota exceeded. Upgrade to paid plan."},
                status=403,
            )
        return None

    @staticmethod
    def _app_already_has_key(request, company):
        """Return True if ``app_id`` names one of *company*'s apps that already has an API key."""
        # NOTE(review): request.POST is only populated for form-encoded bodies;
        # an app_id sent as a JSON body is not seen here — confirm clients.
        app_id = request.POST.get("app_id") or request.GET.get("app_id")
        if not app_id:
            return False
        try:
            app = App.objects.get(id=app_id, company=company)
        except (App.DoesNotExist, ValueError, ValidationError):
            # Unknown or malformed app_id: defer to the view to report it
            # rather than failing here with an unhandled 500.
            return False
        return APIKey.objects.filter(app=app).exists()

    def process_response(self, request, response):
        """Record a ``UsageLog`` row for company-attributed requests; return *response*.

        Recording is best-effort: a failure to write the row is logged and never
        breaks the response.
        """
        # NOTE(review): MiddlewareMixin also calls this for responses returned
        # by process_request (e.g. quota 403s), so blocked calls are logged with
        # their weight too — confirm that's intended.
        company = getattr(request, "_company_obj", None)
        if company:
            try:
                UsageLog.objects.create(
                    company=company,
                    api_key=getattr(request, "_api_key_obj", None),
                    path=request.path,
                    method=request.method,
                    status_code=response.status_code,
                    weight=getattr(request, "_weight", 1),
                    endpoint_name=getattr(request, "_endpoint_name", None),
                )
            except Exception:
                # Never fail the user's request because accounting failed, but
                # leave a trace so lost usage rows are visible.
                self._logger.exception(
                    "Failed to record usage for %s %s", request.method, request.path
                )
        return response
