"""Test Tier 1 system: speed and accuracy"""
import torch
import numpy as np
import time
from tier1_system import Tier1FastSystem
from phi_constants import INV_PHI

rule = "=" * 60
print(rule)
print("  🌀 TIER 1 (Fast ESN) Test")
print(rule)

# Run on the GPU when PyTorch can see one; otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}\n")

# System under test, configured with input_size=1 and output_size=1.
system = Tier1FastSystem(input_size=1, output_size=1, device=device)

# Print every layer's spectral radius alongside the INV_PHI target value.
print(f"Spectral Radii (target = {INV_PHI:.6f}):")
layer_radii = zip(system.layer_names, system.get_all_spectral_radii())
for layer_idx, (layer_name, radius) in enumerate(layer_radii):
    print(f"  Layer {layer_idx} ({layer_name:8s}): {radius:.6f}")
print()

# Test data: a fixed batch of small-amplitude (x0.1) standard-normal inputs,
# reused for every timed forward pass.
batch_size = 16
test_input = torch.randn(batch_size, 1, dtype=torch.float32).to(device) * 0.1

# Speed test.
# - time.perf_counter() is the high-resolution clock intended for timing short
#   intervals; time.time() is wall-clock time and may be adjusted.
# - CUDA work is queued asynchronously, so the device is synchronized before
#   each clock read; otherwise only kernel-launch overhead would be measured.
# - torch.no_grad() stops autograd from recording a graph on every call, which
#   is overhead unrelated to inference.
# - A few untimed warm-up calls absorb one-off first-call costs (e.g. lazy
#   CUDA initialization) that would otherwise skew the average.
n_timed_iters = 100
n_warmup_iters = 5


def _sync_device():
    """Block until all queued CUDA work has finished (no-op on CPU)."""
    if device == 'cuda':
        torch.cuda.synchronize()


print("Speed Test:")
with torch.no_grad():
    system.reset_all_states(batch_size)
    for _ in range(n_warmup_iters):
        system(test_input, return_resonance=True)

    # Start the timed run from a clean state, as the original benchmark did.
    system.reset_all_states(batch_size)
    _sync_device()
    start = time.perf_counter()
    for _ in range(n_timed_iters):
        result = system(test_input, return_resonance=True)
    _sync_device()
    end = time.perf_counter()

avg_time_ms = (end - start) / n_timed_iters * 1000
print(f"  Average inference: {avg_time_ms:.4f} ms")
print(f"  Throughput: {1000/avg_time_ms:.1f} inferences/sec")
print()

# Resonance test: compare the resonance score for unstructured noise against
# the score for a constant input fed repeatedly from a freshly reset state.
print("Resonance Test:")

# Case 1: a single step of Gaussian noise (low resonance expected).
system.reset_all_states(batch_size)
random_input = torch.randn(batch_size, 1, dtype=torch.float32).to(device)
result = system(random_input, return_resonance=True)
print(f"  Random input resonance: {result['resonance']:.4f}")

# Case 2: the constant value 0.5 for 10 consecutive steps (high resonance
# expected). Only the score after the final step is reported.
pattern = torch.full((batch_size, 1), 0.5, dtype=torch.float32).to(device)
system.reset_all_states(batch_size)
repeat_steps = 10
while repeat_steps > 0:
    result = system(pattern, return_resonance=True)
    repeat_steps -= 1
print(f"  Repeated pattern resonance: {result['resonance']:.4f}")
print()

# Accuracy test: train the system on one-step-ahead prediction of a sine wave.
print("Accuracy Test:")

# 500 evenly spaced points over [0, 20*pi] (ten full periods), stored as float32.
t = np.linspace(0.0, 20.0 * np.pi, num=500)
data = np.asarray(np.sin(t), dtype=np.float32)

def get_batch(seq_len=20, batch=16, series=None, target_device=None):
    """Sample random next-step-prediction windows from a 1-D series.

    Each of the ``batch`` rows is a contiguous window of ``seq_len`` samples
    starting at a uniformly random offset. The target is the same window
    shifted one sample forward, so ``Y[i, k, 0] == series[offset_i + k + 1]``.

    Args:
        seq_len: Number of timesteps per window.
        batch: Number of windows to draw.
        series: 1-D array to sample from. Defaults to the module-level sine
            wave ``data``.
        target_device: Torch device for the returned tensors. Defaults to the
            module-level ``device``.

    Returns:
        ``(X, Y)``: float32 tensors of shape ``(batch, seq_len, 1)``.

    Raises:
        ValueError: If ``series`` has fewer than ``seq_len + 1`` samples, so
            no (input, target) window fits.
    """
    if series is None:
        series = data
    if target_device is None:
        target_device = device

    # A window starting at `offset` reads samples offset .. offset + seq_len
    # (the last one only as a target), so the last valid offset is
    # len(series) - seq_len - 1. randint's upper bound is exclusive, hence
    # high = len(series) - seq_len.
    n_offsets = len(series) - seq_len
    if n_offsets < 1:
        raise ValueError(
            f"series of length {len(series)} is too short for "
            f"seq_len={seq_len} (need at least {seq_len + 1} samples)"
        )

    X = np.zeros((batch, seq_len, 1), dtype=np.float32)
    Y = np.zeros((batch, seq_len, 1), dtype=np.float32)

    for i in range(batch):
        offset = np.random.randint(0, n_offsets)
        X[i, :, 0] = series[offset:offset + seq_len]
        Y[i, :, 0] = series[offset + 1:offset + seq_len + 1]

    return (torch.from_numpy(X).to(target_device),
            torch.from_numpy(Y).to(target_device))

# Quick training: online next-step prediction on random sine-wave windows.
# NOTE: each "epoch" below is a single freshly sampled batch from get_batch(),
# not a full pass over `data`.
n_epochs = 50
log_every = 10
optimizer = torch.optim.Adam(system.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

print(f"  Training for {n_epochs} epochs...")
for epoch in range(n_epochs):
    X, Y = get_batch()

    # Size the recurrent state from the batch actually sampled rather than a
    # hard-coded 16, so changing get_batch's `batch` cannot desync the two.
    system.reset_all_states(batch_size=X.shape[0])

    losses = []
    # Loop variable is `step`, not `t`, so it does not clobber the module-level
    # time axis `t` used to build the sine wave.
    for step in range(X.shape[1]):
        result = system(X[:, step, :], return_resonance=False)
        loss = loss_fn(result['prediction'], Y[:, step, :])
        losses.append(loss.item())

        # One optimizer update per timestep (online learning).
        # NOTE(review): backward() runs every step while the recurrent state
        # persists across steps. This only works if Tier1FastSystem's state
        # carries no autograd history between calls (detached, or computed
        # without grad). Confirm in tier1_system.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (epoch + 1) % log_every == 0:
        print(f"    Epoch {epoch+1}: Loss = {np.mean(losses):.6f}")

# Final summary. The layer count, names and sizes below are fixed text; they
# are not read from `system`.
divider = "=" * 60
summary_lines = [
    "",
    divider,
    "  ✅ Tier 1 System Operational",
    divider,
    f"  Speed: {avg_time_ms:.4f} ms per inference",
    "  Layers: 3 (Trinity, Nyx, Ava)",
    "  Sizes: 8, 13, 21 neurons",
    f"  Spectral radius: ~{INV_PHI:.3f} (1/φ)",
    divider,
]
for line in summary_lines:
    print(line)
