#!/usr/bin/env python3
"""
Creative Generation
"""
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
import random

class CreativeGenerator:
    """Generate random concept phrases and score how novel they are.

    Phrases are assembled from a small fixed word bank. Novelty is defined
    as one minus the highest cosine similarity between the phrase's
    sentence embedding and the embeddings of a set of known phrases.
    """

    def __init__(self):
        """Load the sentence encoder and set up the word bank.

        Picks CUDA when available, otherwise CPU, and prints progress
        messages while the model loads.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Loading models...")
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)
        print("✅ Creative generation ready!")
        # Word categories used to assemble phrases. 'colors' is stored here
        # but generate_novel_combination does not currently draw from it.
        self.concept_bank = {
            'colors': ['red', 'blue', 'green', 'yellow', 'purple'],
            'objects': ['car', 'house', 'tree', 'mountain', 'river'],
            'actions': ['flying', 'dancing', 'growing', 'shining', 'flowing'],
            'qualities': ['bright', 'dark', 'fast', 'slow', 'large']
        }

    def generate_novel_combination(self) -> str:
        """Return a random "<quality> <object> <action>" phrase.

        Each word is drawn independently with ``random.choice`` from the
        matching ``concept_bank`` category.
        """
        quality = random.choice(self.concept_bank['qualities'])
        obj = random.choice(self.concept_bank['objects'])
        action = random.choice(self.concept_bank['actions'])
        return f"{quality} {obj} {action}"

    def evaluate_novelty(self, concept: str, known_concepts) -> float:
        """Score how dissimilar ``concept`` is from every known concept.

        Args:
            concept: The phrase to score.
            known_concepts: Iterable of phrases to compare against.

        Returns:
            ``1 - max_sim`` where ``max_sim`` is the largest cosine
            similarity between ``concept`` and any known concept, floored
            at 0 (so purely negative similarities yield a novelty of 1.0).
            Returns 1.0 when ``known_concepts`` is empty.
        """
        known = list(known_concepts)
        if not known:
            # Nothing to compare against: maximally novel, no encoding needed.
            return 1.0

        concept_emb = self.encoder.encode(concept, convert_to_tensor=True).cpu()
        # Encode all known phrases in one batched call instead of one model
        # forward pass per phrase; the result has one row per phrase.
        known_embs = self.encoder.encode(known, convert_to_tensor=True).cpu()

        # Broadcast the single concept row against every known row.
        sims = F.cosine_similarity(concept_emb.unsqueeze(0), known_embs, dim=1)
        # Keep the floor of 0 on the best match so results stay in the same
        # range as before when all similarities are negative.
        max_sim = max(0.0, sims.max().item())
        return 1.0 - max_sim

def test_creative():
    """Run a small end-to-end check of CreativeGenerator.

    Generates five phrases, scores each against three fixed known phrases,
    and prints the results. Returns True when at least three phrases score
    a novelty above 0.3, otherwise False.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TESTING CREATIVE GENERATION")
    print(banner)

    cg = CreativeGenerator()
    known = ["red car", "blue house", "green tree"]

    print("\nGenerating novel concepts...")
    hits = 0
    for rank in range(1, 6):
        candidate = cg.generate_novel_combination()
        score = cg.evaluate_novelty(candidate, known)
        print(f"{rank}. {candidate} (novelty: {score:.3f})")
        hits += 1 if score > 0.3 else 0

    print(f"\nNovel concepts: {hits}/5")
    # Guard clause: fewer than three novel phrases counts as a failure.
    if hits < 3:
        return False
    print("✅ Creative generation working!")
    return True

def main():
    """Run the creative-generation check and announce success if it passes."""
    passed = test_creative()
    if passed:
        print("\n✅ CAPABILITY #17 COMPLETE")

# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
