#!/usr/bin/env python3
"""Analyze voice from QQ voice message - extract characteristics and save."""
import sys, os, json, subprocess, tempfile

def download_voice(url, output_path):
    """Download a QQ voice file from a temporary URL.

    Args:
        url: Source URL of the voice clip.
        output_path: Local path the payload is written to (overwritten).

    Returns:
        Size in bytes of the written file, or 0 if anything went wrong
        (malformed URL, network error, unwritable destination).
    """
    import shutil
    import urllib.request

    try:
        # Building the request inside the try keeps a malformed URL on the
        # same "return 0" failure path as a network error.
        req = urllib.request.Request(url, headers={
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        })
        # Both context managers guarantee the connection and the file are
        # closed; copyfileobj streams in chunks instead of buffering the
        # whole payload in memory.
        with urllib.request.urlopen(req, timeout=20) as resp, \
                open(output_path, 'wb') as f:
            shutil.copyfileobj(resp, f)
        return os.path.getsize(output_path)
    except Exception as e:
        # Deliberate best-effort boundary: callers only look at the size.
        print(f"Download failed: {e}")
        return 0

def convert_to_wav(amr_path, wav_path):
    """Convert an AMR file to a mono 22.05 kHz WAV using ffmpeg.

    Args:
        amr_path: Path of the input audio file.
        wav_path: Path of the WAV file to create (overwritten, ``-y``).

    Returns:
        True if ffmpeg exited with status 0, False otherwise -- including
        when ffmpeg is not installed or does not finish within 30 seconds.
    """
    cmd = ['ffmpeg', '-y', '-i', amr_path, '-ar', '22050', '-ac', '1', wav_path]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            errors='replace',  # ffmpeg may echo non-UTF-8 bytes (e.g. file names)
            timeout=30,
        )
    except FileNotFoundError:
        print("ffmpeg error: executable not found on PATH")
        return False
    except subprocess.TimeoutExpired:
        print("ffmpeg error: timed out after 30 seconds")
        return False

    if result.returncode != 0:
        # Surface the last line of ffmpeg's diagnostics; it usually names
        # the actual cause (missing input, unknown format, ...).
        stderr_lines = result.stderr.strip().splitlines()
        if stderr_lines:
            print(f"ffmpeg error: {stderr_lines[-1]}")
        return False
    return True

def analyze_voice(wav_path, *, min_onset_gap_s=0.1):
    """Extract basic acoustic features from a WAV file.

    Args:
        wav_path: Path of a WAV file readable by ``scipy.io.wavfile.read``.
            Multi-channel audio is averaged down to a single channel.
        min_onset_gap_s: Minimum spacing, in seconds, between two counted
            onsets. It is converted to spectrogram frames before being
            handed to ``find_peaks``.

    Returns:
        A dict of plain Python numbers/strings: duration, sample rate,
        estimated fundamental frequency, spectral centroid / rolloff /
        bandwidth, percentage of energy below 300 Hz, between 300 Hz and
        2 kHz, and above 2 kHz, zero-crossing rate, RMS and peak amplitude,
        dynamic range in dB, onsets per second, and a nested
        ``voice_characteristics`` dict of coarse labels.

    Raises:
        ValueError: If the file contains no audio samples.
    """
    import numpy as np
    from scipy import signal as ss
    from scipy.io import wavfile

    sr, data = wavfile.read(wav_path)
    if np.issubdtype(data.dtype, np.integer):
        # Scale integer PCM by the dtype's maximum value.
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    else:
        # Float WAVs are taken as-is; np.iinfo would raise on them.
        data = data.astype(np.float32)
    if data.ndim > 1:
        data = data.mean(axis=1)
    if data.size == 0:
        raise ValueError(f"No audio samples in {wav_path}")

    duration = len(data) / sr

    # Basic stats
    rms = float(np.sqrt(np.mean(data ** 2)))
    peak = float(np.max(np.abs(data)))

    # Pitch: lag of the strongest autocorrelation value inside the
    # 50-500 Hz search band. The FFT method avoids the O(n^2) cost of a
    # direct correlation on clips that are seconds long.
    samples64 = data.astype(np.float64)
    autocorr = ss.correlate(samples64, samples64, mode='full', method='fft')
    autocorr = autocorr[len(data) - 1:]  # keep non-negative lags only
    min_lag = int(sr / 500)  # shortest period searched -> 500 Hz
    max_lag = int(sr / 50)   # longest period searched  -> 50 Hz
    f0 = 0.0
    # autocorr[0] is the signal energy; without energy there is no pitch.
    if autocorr[0] > 0 and min_lag < max_lag < len(autocorr):
        best_lag = int(np.argmax(autocorr[min_lag:max_lag])) + min_lag
        if best_lag > 0:
            f0 = sr / best_lag

    # Spectrogram; the window shrinks for clips shorter than 1024 samples
    # so that noverlap always stays below nperseg.
    nperseg = min(1024, len(data))
    noverlap = nperseg // 2
    hop_samples = nperseg - noverlap
    freqs, _, Sxx = ss.spectrogram(data, sr, nperseg=nperseg, noverlap=noverlap)
    frame_energy = np.sum(Sxx, axis=0)
    total_spectral_energy = np.sum(Sxx)

    # Spectral centroid (brightness)
    spectral_centroid = np.sum(freqs[:, np.newaxis] * Sxx, axis=0) / (frame_energy + 1e-10)
    avg_centroid = float(np.mean(spectral_centroid))

    # Spectral rolloff: frequency below which 85% of a frame's energy lies
    cumsum = np.cumsum(Sxx, axis=0)
    rolloff_targets = 0.85 * cumsum[-1, :]
    rolloff_freqs = []
    for col in range(cumsum.shape[1]):
        idx = np.searchsorted(cumsum[:, col], rolloff_targets[col])
        if idx < len(freqs):
            rolloff_freqs.append(freqs[idx])
    avg_rolloff = float(np.mean(rolloff_freqs)) if rolloff_freqs else 0.0

    # Spectral bandwidth (spread around the centroid)
    bandwidth = np.sqrt(
        np.sum((freqs[:, np.newaxis] - spectral_centroid) ** 2 * Sxx, axis=0)
        / (frame_energy + 1e-10)
    )
    avg_bandwidth = float(np.mean(bandwidth))

    # Energy share per band. The "high" band is everything from 2 kHz up,
    # so the three percentages add up to the whole spectrum.
    low_freq_idx = np.searchsorted(freqs, 300)
    mid_freq_idx = np.searchsorted(freqs, 2000)
    low_energy = float(np.sum(Sxx[:low_freq_idx, :]) / (total_spectral_energy + 1e-10) * 100)
    mid_energy = float(np.sum(Sxx[low_freq_idx:mid_freq_idx, :]) / (total_spectral_energy + 1e-10) * 100)
    high_energy = float(np.sum(Sxx[mid_freq_idx:, :]) / (total_spectral_energy + 1e-10) * 100)

    # Zero crossing rate
    zcr = float(np.mean(np.abs(np.diff(np.sign(data)))) / 2)

    # Energy envelope: RMS of 20 ms frames taken every 10 ms
    frame_size = max(1, int(sr * 0.02))
    hop = max(1, int(sr * 0.01))
    energy_frames = np.array([
        np.sqrt(np.mean(data[start:start + frame_size] ** 2))
        for start in range(0, len(data) - frame_size, hop)
    ])
    # Dynamic range compares the loudest frame with the quietest non-silent
    # one; all-silent or too-short clips have no such frame.
    audible = energy_frames[energy_frames > 0]
    if audible.size:
        dynamic_range = float(20 * np.log10(np.max(energy_frames) / (np.min(audible) + 1e-10)))
    else:
        dynamic_range = 0.0

    # Onset envelope: per-frame energy at or above 2 kHz, scaled to [0, 1]
    onset_env = np.sum(Sxx[mid_freq_idx:, :], axis=0).astype(np.float64)
    onset_env = onset_env / (np.max(onset_env) + 1e-10)

    # onset_env is indexed by spectrogram frame, so the minimum peak
    # spacing must be expressed in frames (never below 1, which
    # find_peaks rejects).
    min_gap_frames = max(1, int(round(min_onset_gap_s * sr / hop_samples)))
    onset_frames = ss.find_peaks(onset_env, height=0.3, distance=min_gap_frames)[0]
    speech_rate = len(onset_frames) / duration if duration > 0 else 0.0

    # Every value is converted to a native Python type so the result can be
    # serialised to JSON as numbers.
    return {
        'duration': round(duration, 2),
        'sample_rate': int(sr),
        'fundamental_freq': round(float(f0), 1),
        'spectral_centroid': round(avg_centroid, 1),
        'spectral_rolloff': round(avg_rolloff, 1),
        'spectral_bandwidth': round(avg_bandwidth, 1),
        'energy_low_300hz': round(low_energy, 1),
        'energy_mid_2khz': round(mid_energy, 1),
        'energy_high_4khz': round(high_energy, 1),
        'zero_crossing_rate': round(zcr, 4),
        'rms_energy': round(rms, 4),
        'peak_amplitude': round(peak, 4),
        'dynamic_range_db': round(dynamic_range, 1),
        'speech_rate_onsets_per_sec': round(float(speech_rate), 2),
        'voice_characteristics': {
            'pitch': '高' if f0 > 180 else ('中' if f0 > 120 else '低'),
            'timbre': '明亮' if avg_centroid > 2000 else ('中性' if avg_centroid > 1500 else '厚重'),
            'rhythm': '快' if speech_rate > 3 else ('中' if speech_rate > 2 else '慢'),
            'power': '强' if rms > 0.1 else ('中' if rms > 0.05 else '柔和'),
        }
    }

def _json_native(value):
    """Fallback JSON encoder: numpy scalars/arrays become native values."""
    to_list = getattr(value, 'tolist', None)
    if callable(to_list):
        # numpy scalars and arrays both expose tolist(), which yields plain
        # Python numbers (or nested lists of them).
        return to_list()
    return str(value)


def save_voice_profile(analysis, output_path):
    """Save a voice analysis dict as a UTF-8 JSON profile.

    Args:
        analysis: Dict of analysis results; numpy scalars are written as
            JSON numbers, any other non-serialisable value as its ``str()``.
        output_path: Destination file path (overwritten).
    """
    # Serialise before opening the file so a serialisation error cannot
    # leave a truncated profile behind.
    payload = json.dumps(analysis, ensure_ascii=False, indent=2, default=_json_native)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(payload)
    print(f"Voice profile saved to {output_path}")

if __name__ == '__main__':
    # Resolve the input clip: an explicit URL or path from the command line,
    # otherwise the most recently modified AMR file left in /tmp.
    cli_args = sys.argv[1:]
    if cli_args:
        source = cli_args[0]
        if not source.startswith('http'):
            amr_path = source
        else:
            amr_path = '/tmp/host_voice_live.amr'
            size = download_voice(source, amr_path)
            if size < 100:
                print("Download failed or file too small")
                sys.exit(1)
            print(f"Downloaded: {amr_path} ({size} bytes)")
    else:
        import glob
        candidates = glob.glob('/tmp/host_voice_*.amr')
        if not candidates:
            print("No AMR files found. Pass a URL or file path.")
            sys.exit(1)
        # Newest file by modification time.
        amr_path = max(candidates, key=os.path.getmtime)
        size = os.path.getsize(amr_path)
        if size < 100:
            # A tiny file is treated as a failed download.
            print(f"File too small ({size} bytes), likely failed download: {amr_path}")
            sys.exit(1)
        print(f"Found: {amr_path} ({size} bytes)")

    # Transcode to WAV for analysis.
    wav_path = '/tmp/host_voice_analysis.wav'
    converted = convert_to_wav(amr_path, wav_path)
    if not converted:
        print("Conversion failed")
        sys.exit(1)

    wav_size = os.path.getsize(wav_path)
    print(f"Converted to WAV: {wav_path} ({wav_size} bytes)")

    # Run the analysis and persist the profile.
    analysis = analyze_voice(wav_path)
    save_voice_profile(analysis, '/tmp/host_voice_profile.json')

    # Human-readable summary.
    divider = "=" * 50
    print("\n" + divider)
    print("🎤 主播声音分析报告")
    print(divider)
    print(f"时长: {analysis['duration']}秒")
    print(f"采样率: {analysis['sample_rate']}Hz")
    print(f"基频(音高): {analysis['fundamental_freq']}Hz ({analysis['voice_characteristics']['pitch']})")
    print(f"音色中心: {analysis['spectral_centroid']}Hz ({analysis['voice_characteristics']['timbre']})")
    print(f"频谱带宽: {analysis['spectral_bandwidth']}Hz")
    print(f"能量分布: 低频{analysis['energy_low_300hz']}% / 中频{analysis['energy_mid_2khz']}% / 高频{analysis['energy_high_4khz']}%")
    print(f"语速: {analysis['speech_rate_onsets_per_sec']}音节/秒 ({analysis['voice_characteristics']['rhythm']})")
    print(f"动态范围: {analysis['dynamic_range_db']}dB")
    print(f"力量感: {analysis['voice_characteristics']['power']}")
    print(f"综合特征: {analysis['voice_characteristics']}")
    print(divider)

    # Keep a copy of the converted WAV for later use.
    import shutil
    shutil.copy(wav_path, '/tmp/host_voice_permanent.wav')
    print("\n永久保存: /tmp/host_voice_permanent.wav")
    print("分析文件: /tmp/host_voice_profile.json")
