#!/usr/bin/env python3
"""
Transfer Learning - Adapt knowledge across domains
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sentence_transformers import SentenceTransformer

class TransferLearning:
    """Nearest-neighbour label transfer between text domains.

    Labelled examples from a *source* task are embedded with a pre-trained
    SentenceTransformer. A query from a new (target) domain receives the
    label of the most cosine-similar stored source example.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        """Load the sentence encoder on GPU when available, else CPU.

        Args:
            model_name: SentenceTransformer model to load. Defaults to the
                model this class has always used.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Loading pre-trained model...")
        self.encoder = SentenceTransformer(model_name, device=self.device)
        print("✅ Transfer learning ready!")

        # task_name -> list of (embedding tensor moved to CPU, label) pairs.
        self.source_knowledge = {}

    def learn_source_task(self, examples, task_name):
        """Learn from source domain.

        Embeds all example texts in a single batched ``encode`` call and
        stores ``(embedding, label)`` pairs under ``task_name``. Anything
        previously stored for that task is replaced.

        Args:
            examples: Iterable of ``(text, label)`` pairs.
            task_name: Key under which the examples are stored.
        """
        pairs = list(examples)
        stored = []
        if pairs:
            texts = [text for text, _ in pairs]
            # One batched forward pass instead of one call per example.
            embeddings = self.encoder.encode(texts, convert_to_tensor=True).cpu()
            # Iterating the 2-D result yields one 1-D embedding per text.
            stored = [(emb, label) for emb, (_, label) in zip(embeddings, pairs)]
        # Assign only after encoding succeeded, so a failure never leaves a
        # half-filled task behind.
        self.source_knowledge[task_name] = stored

    def transfer_to_target(self, query, source_task):
        """Transfer knowledge to new domain.

        Args:
            query: Text from the target domain to label.
            source_task: Name of a task previously passed to
                ``learn_source_task``.

        Returns:
            ``(label, score)`` of the stored source example whose embedding
            is most cosine-similar to ``query``. ``score`` is a Python float,
            and ties go to the earliest-stored example. Returns ``(None, 0)``
            when ``source_task`` is unknown or has no stored examples.
        """
        examples = self.source_knowledge.get(source_task)
        if not examples:
            # Check before encoding so an unusable task costs no model call.
            return None, 0

        query_emb = self.encoder.encode(query, convert_to_tensor=True).cpu()
        bank = torch.stack([emb for emb, _ in examples])
        # (1, d) vs (n, d) broadcasts to n similarities in one call.
        scores = F.cosine_similarity(query_emb.unsqueeze(0), bank, dim=1)
        # argmax returns the first maximal index, matching the old strict '>' scan.
        best = int(torch.argmax(scores))
        return examples[best][1], scores[best].item()

def test_transfer():
    """Smoke-test label transfer from generic sentiment to product reviews.

    Returns:
        True when at least two of the three target queries receive the
        expected label, otherwise False.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TESTING TRANSFER LEARNING")
    print(banner)

    learner = TransferLearning()

    # Source domain: short, generic sentiment phrases.
    sentiment_examples = [
        ("This is great", "positive"),
        ("This is bad", "negative"),
        ("Love it", "positive"),
        ("Hate it", "negative"),
    ]
    learner.learn_source_task(sentiment_examples, "sentiment")

    # Target domain: product-review phrasing not seen in the source set.
    cases = [
        ("Amazing product", "positive"),
        ("Terrible quality", "negative"),
        ("Wonderful experience", "positive"),
    ]

    correct = 0
    for text, want in cases:
        got, confidence = learner.transfer_to_target(text, "sentiment")
        print(f"\nQuery: {text}")
        print(f"Predicted: {got} ({confidence:.3f})")
        if got != want:
            continue
        print("✅ Correct!")
        correct += 1

    total = len(cases)
    print(f"\n{correct}/{total} ({100*correct/total:.1f}%)")

    ok = correct >= 2
    if ok:
        print("✅ Transfer working!")
    return ok

def main():
    """Script entry point: run the transfer smoke test and report success."""
    if not test_transfer():
        return
    print("\n✅ CAPABILITY #14 COMPLETE")


if __name__ == "__main__":
    main()
