from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import asyncio
import json
import httpx
from typing import AsyncIterator

# ASGI application instance; all routes below are registered on it.
app = FastAPI()

# Enable CORS for your frontend
# Only the local dev frontend (port 3001) may call this API from a browser;
# credentials (cookies/auth headers) are allowed for those origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3001", "http://127.0.0.1:3001"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration for Ollama
OLLAMA_URL = "http://localhost:11434/api/generate"  # local Ollama generate endpoint
MODEL_NAME = "llama2"  # Change to your model

async def stream_ollama_response(prompt: str) -> AsyncIterator[str]:
    """Yield response tokens from Ollama as they arrive.

    Opens a streaming POST to the Ollama generate endpoint and yields each
    token (the "response" fragment) from the newline-delimited JSON stream.

    Args:
        prompt: The user prompt forwarded to the model.

    Yields:
        Individual response tokens, in generation order.

    Raises:
        httpx.HTTPStatusError: if Ollama returns a non-2xx status
            (e.g. 404 when the model has not been pulled).
        httpx.HTTPError: on connection or timeout failures.
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        async with client.stream(
            "POST",
            OLLAMA_URL,
            json={
                "model": MODEL_NAME,
                "prompt": prompt,
                "stream": True,
            },
        ) as response:
            # Fail fast on HTTP errors instead of silently yielding nothing
            # (an error body has no "response" key, so it used to be skipped).
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Tolerate malformed/partial lines rather than aborting.
                    continue
                if "response" in data:
                    # Send each token immediately
                    yield data["response"]
                # Ollama marks the final chunk with "done": true — stop here
                # instead of waiting for the server to close the connection.
                if data.get("done"):
                    break

@app.post("/api/chat/stream")
async def chat_stream(request: dict):
    """HTTP streaming endpoint - sends Server-Sent Events.

    Expects a JSON body like ``{"message": "..."}`` and streams back SSE
    frames: one ``{"token": ...}`` event per generated token, then
    ``{"done": true}``. If the upstream Ollama call fails, an
    ``{"error": ...}`` event is emitted before the done event so the
    client is not left with a silently truncated response.
    """
    message = request.get("message", "")

    async def event_generator():
        try:
            async for token in stream_ollama_response(message):
                # One SSE frame per token so the client renders incrementally.
                yield f"data: {json.dumps({'token': token})}\n\n"
        except httpx.HTTPError as exc:
            # Surface upstream failures to the client instead of just
            # dropping the stream mid-response.
            yield f"data: {json.dumps({'error': str(exc)})}\n\n"
        yield f"data: {json.dumps({'done': True})}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Tell nginx-style reverse proxies not to buffer this response,
            # otherwise tokens arrive in one burst instead of streaming.
            "X-Accel-Buffering": "no",
        },
    )

@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """WebSocket endpoint - even faster, bidirectional.

    Protocol: the client sends ``{"message": ...}``; the server replies
    with a series of ``{"type": "token", "content": ...}`` messages
    followed by ``{"type": "done"}``. An upstream Ollama failure produces
    ``{"type": "error", "content": ...}`` and keeps the connection open
    so the client can retry without reconnecting.
    """
    await websocket.accept()
    print("WebSocket client connected")

    try:
        while True:
            # Receive message from frontend
            data = await websocket.receive_json()
            message = data.get("message", "")

            # Stream response back token by token
            try:
                async for token in stream_ollama_response(message):
                    await websocket.send_json({
                        "type": "token",
                        "content": token
                    })
            except httpx.HTTPError as exc:
                # Previously any Ollama error crashed the handler and
                # dropped the socket; report it and await the next message.
                await websocket.send_json({
                    "type": "error",
                    "content": str(exc)
                })
                continue

            # Signal completion
            await websocket.send_json({
                "type": "done"
            })

    except WebSocketDisconnect:
        print("WebSocket client disconnected")

@app.get("/health")
async def health():
    """Liveness probe: always reports that the service is up."""
    payload = {"status": "ok"}
    return payload

if __name__ == "__main__":
    # Imported lazily so merely importing this module (e.g. under another
    # ASGI server) does not require uvicorn.
    import uvicorn
    # 0.0.0.0 exposes the server on all network interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
