import datetime as dt

from django.db.models import Sum
from django.utils import timezone
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.usage.models import UsageLog
from .serializers import UsageSummarySerializer


class UsageSummaryView(APIView):
    """
    Return usage summary for the authenticated user's company.

    Works with JWT authentication by resolving the company from ``request.user``.

    Query parameters:
        month: optional ``YYYY-MM`` string selecting the calendar month to
            report on (an explicit month is interpreted in UTC). When omitted,
            the current month according to ``timezone.now()`` is used.

    Responses:
        200 with the serialized summary (company, plan, month, total_requests,
        quota, remaining, by_endpoint).
        400 when no company can be resolved for the user, or when ``month``
        is not a valid ``YYYY-MM`` value.
    """

    permission_classes = [IsAuthenticated]

    # Monthly request quota per plan. Plans not listed here get
    # ``quota = None`` and ``remaining = None`` in the response.
    PLAN_QUOTAS = {"free": 1000, "paid": 5000}

    def get(self, request):
        """Build and return the usage summary for the requested month."""
        company = self._resolve_company(request.user)
        if company is None:
            return Response({"detail": "No company associated with user"}, status=400)

        try:
            start_of_month, start_of_next_month = self._month_window(
                request.query_params.get("month")
            )
        except ValueError:
            return Response({"detail": "Invalid month format. Use YYYY-MM."}, status=400)

        # Bound the window on both sides; with only a lower bound a past
        # month would also include every later month's usage.
        logs = UsageLog.objects.filter(
            company=company,
            timestamp__gte=start_of_month,
            timestamp__lt=start_of_next_month,
        )

        total_requests = logs.aggregate(total=Sum("weight"))["total"] or 0
        quota = self.PLAN_QUOTAS.get(company.plan)
        # Not clamped at zero: a negative value shows how far over quota the
        # company is.
        remaining = quota - total_requests if quota is not None else None

        data = {
            "company": company.name,
            "plan": company.plan,
            "month": start_of_month.strftime("%Y-%m"),
            "total_requests": total_requests,
            "quota": quota,
            "remaining": remaining,
            "by_endpoint": self._usage_by_endpoint(logs),
        }

        return Response(UsageSummarySerializer(data).data)

    @staticmethod
    def _resolve_company(user):
        """Return the company associated with ``user``, or None.

        Sources are tried in order, and the first one that yields a company
        wins:

        1. ``user.company`` (direct relationship)
        2. ``user.profile.company``
        3. ``company`` of the first ``user.company_memberships`` row with
           ``is_active=True``
        4. a ``Company`` looked up by ``user.company_id``
        """
        # Option 1: direct relationship. getattr's default mirrors hasattr:
        # a missing attribute (AttributeError) is treated as absent.
        company = getattr(user, "company", None)
        if company is not None:
            return company

        # Option 2: through profile.
        profile = getattr(user, "profile", None)
        company = getattr(profile, "company", None)
        if company is not None:
            return company

        # Option 3: through an active membership.
        memberships = getattr(user, "company_memberships", None)
        if memberships is not None:
            membership = memberships.filter(is_active=True).first()
            if membership is not None and membership.company is not None:
                return membership.company

        # Option 4: user only carries a company_id.
        company_id = getattr(user, "company_id", None)
        if company_id:
            # Kept as a local import, as in the original; presumably to avoid
            # a circular import between apps — confirm before hoisting.
            from apps.companies.models import Company

            return Company.objects.filter(id=company_id).first()

        return None

    @staticmethod
    def _month_window(month_str):
        """Return ``(start, end)`` bounding one calendar month.

        ``start`` is the first instant of the month and ``end`` the first
        instant of the following month, used as an exclusive upper bound.
        When ``month_str`` is empty or None, the current month (from
        ``timezone.now()``) is used. Otherwise ``month_str`` must be
        ``"YYYY-MM"`` and is interpreted in UTC.

        Raises:
            ValueError: if ``month_str`` is malformed or out of range,
                including a month whose successor would fall past year 9999.
        """
        if month_str:
            year, month = map(int, month_str.split("-"))
            start = dt.datetime(year, month, 1, tzinfo=dt.timezone.utc)
        else:
            start = timezone.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0)

        # replace() keeps tzinfo, so both bounds share the same awareness.
        if start.month == 12:
            end = start.replace(year=start.year + 1, month=1)
        else:
            end = start.replace(month=start.month + 1)
        return start, end

    @staticmethod
    def _usage_by_endpoint(logs):
        """Return ``{endpoint_name: summed weight}`` for the ``logs`` queryset.

        Rows whose ``endpoint_name`` is null or empty are merged (summed, not
        overwritten) under the key ``"unknown"``. A group whose weights are
        all null contributes 0.
        """
        # .order_by() clears any default Meta.ordering. On Django versions
        # before 3.1, that ordering would otherwise be added to the GROUP BY
        # and split the groups.
        rows = logs.values("endpoint_name").annotate(total=Sum("weight")).order_by()
        totals = {}
        for row in rows:
            key = row["endpoint_name"] or "unknown"
            totals[key] = totals.get(key, 0) + (row["total"] or 0)
        return totals