import math
import nodes
import node_helpers
import torch
import torchaudio
import comfy.model_management
import comfy.utils
import numpy as np
import logging
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io

import scipy.signal
import scipy.ndimage
import scipy.fft
import scipy.sparse

# Audio Processing Functions - Derived from librosa (https://github.com/librosa/librosa)
# Copyright (c) 2013--2023, librosa development team.

def mel_to_hz(mels, htk=False):
    """Convert mel-scale values to frequencies in Hz.

    With ``htk=True`` the HTK formula is used.  Otherwise the (Slaney-style)
    scale is piecewise: linear below 1000 Hz and logarithmic from 1000 Hz
    upwards.  Accepts a scalar or an array-like.
    """
    mel_values = np.asanyarray(mels)
    if htk:
        return 700.0 * (10.0 ** (mel_values / 2595.0) - 1.0)

    # Linear part of the scale.
    f_min = 0.0
    f_sp = 200.0 / 3
    hz = f_min + f_sp * mel_values

    # The logarithmic part starts at 1000 Hz.
    knee_hz = 1000.0
    knee_mel = (knee_hz - f_min) / f_sp
    log_step = np.log(6.4) / 27.0

    if mel_values.ndim:
        # Array input: overwrite only the entries in the log region.
        in_log_region = mel_values >= knee_mel
        hz[in_log_region] = knee_hz * np.exp(log_step * (mel_values[in_log_region] - knee_mel))
        return hz

    # Scalar input.
    if mel_values >= knee_mel:
        return knee_hz * np.exp(log_step * (mel_values - knee_mel))
    return hz

def hz_to_mel(frequencies, htk=False):
    """Convert frequencies in Hz to mel-scale values.

    With ``htk=True`` the HTK formula is used.  Otherwise the (Slaney-style)
    scale is piecewise: linear below 1000 Hz and logarithmic from 1000 Hz
    upwards.  Accepts a scalar or an array-like.
    """
    hz = np.asanyarray(frequencies)
    if htk:
        return 2595.0 * np.log10(1.0 + hz / 700.0)

    # Linear part of the scale.
    f_min = 0.0
    f_sp = 200.0 / 3
    mel_values = (hz - f_min) / f_sp

    # The logarithmic part starts at 1000 Hz.
    knee_hz = 1000.0
    knee_mel = (knee_hz - f_min) / f_sp
    log_step = np.log(6.4) / 27.0

    if hz.ndim:
        # Array input: overwrite only the entries in the log region.
        in_log_region = hz >= knee_hz
        mel_values[in_log_region] = knee_mel + np.log(hz[in_log_region] / knee_hz) / log_step
        return mel_values

    # Scalar input.
    if hz >= knee_hz:
        return knee_mel + np.log(hz / knee_hz) / log_step
    return mel_values

def compute_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, bins_per_octave=12, tuning=0.0):
    """Compute the magnitude of a Constant-Q Transform (CQT) spectrogram.

    The transform is evaluated one octave at a time, starting with the highest
    ``bins_per_octave`` bins.  After each octave, if the current hop length is
    even, the hop and sample rate are halved and the signal is downsampled by 2
    before the next (lower) octave is processed.

    Args:
        y: Audio samples as a NumPy array (assumed 1-D - the framing below uses
            ``len()`` and ``strides[0]`` only).
        sr: Sample rate of ``y``.
        hop_length: Hop between frames, in samples at the original rate.
        fmin: Frequency of the lowest bin; defaults to C1 (~32.70 Hz).
        n_bins: Total number of CQT bins.
        bins_per_octave: Number of bins per octave.
        tuning: Offset applied to ``fmin``, in units of bins.

    Returns:
        float32 array of shape ``(n_bins, n_frames)`` holding ``abs(CQT)``,
        where ``n_frames`` is the smallest frame count over all octaves.
    """

    def _relative_bandwidth(freqs):
        # Estimate the local bins-per-octave from the log2 spacing of
        # neighbouring frequencies (one-sided at both ends, centred elsewhere)
        # and turn it into a relative bandwidth per bin.
        bpo = np.empty_like(freqs)
        logf = np.log2(freqs)
        bpo[0] = 1.0 / (logf[1] - logf[0])
        bpo[-1] = 1.0 / (logf[-1] - logf[-2])
        bpo[1:-1] = 2.0 / (logf[2:] - logf[:-2])
        return (2.0 ** (2.0 / bpo) - 1.0) / (2.0 ** (2.0 / bpo) + 1.0)

    def _wavelet_lengths(freqs, sr, filter_scale, alpha):
        # Filter length in (fractional) samples: Q * sr / f, with Q = filter_scale / alpha.
        Q = float(filter_scale) / alpha
        return Q * sr / freqs  # shape (n_bins,) floats

    def _build_wavelet(freqs_oct, sr, filter_scale, alpha_oct):
        # Build one Hann-windowed, L1-normalised complex exponential per frequency,
        # then centre each inside a zero-padded row whose length is the next
        # power of two at or above the longest filter.
        lengths = _wavelet_lengths(freqs_oct, sr, filter_scale, alpha_oct)
        filters = []
        for ilen, freq in zip(lengths, freqs_oct):
            t = np.arange(int(-ilen // 2), int(ilen // 2), dtype=float)
            sig = (np.cos(t * 2 * np.pi * freq / sr)
                   + 1j * np.sin(t * 2 * np.pi * freq / sr)).astype(np.complex64)
            sig *= scipy.signal.get_window('hann', len(sig), fftbins=True)
            l1 = np.sum(np.abs(sig))
            # Guard against dividing by zero for a degenerate (all-zero) filter.
            tiny = np.finfo(np.float32).tiny
            sig /= max(l1, tiny)
            filters.append(sig)
        max_len = max(lengths)
        n_fft = int(2.0 ** np.ceil(np.log2(max_len)))
        out = np.zeros((len(filters), n_fft), dtype=np.complex64)
        for k, f in enumerate(filters):
            lpad = int((n_fft - len(f)) // 2)
            out[k, lpad: lpad + len(f)] = f
        return out, lengths

    def _resample_half(y):
        # Downsample by a factor of 2 and force the output length to ceil(len(y) / 2).
        ratio = 0.5
        n_samples = int(np.ceil(len(y) * ratio))
        # Kaiser-windowed FIR matches librosa/soxr more closely than scipy's default Hamming filter
        L = 2
        h = scipy.signal.firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5))
        y_hat = scipy.signal.resample_poly(y.astype(np.float32), 1, 2, window=h)
        if len(y_hat) > n_samples:
            y_hat = y_hat[:n_samples]
        elif len(y_hat) < n_samples:
            y_hat = np.pad(y_hat, (0, n_samples - len(y_hat)))
        # Rescale by 1 / sqrt(ratio) after halving the number of samples.
        y_hat /= np.sqrt(ratio)
        return y_hat.astype(np.float32)

    def _sparsify_rows(x, quantile=0.01):
        # For every row, drop the smallest-magnitude entries: sort magnitudes
        # ascending, find the first position where the cumulative share of the
        # row's total magnitude reaches `quantile`, and keep only entries at
        # least as large as the magnitude at that position.  Returns CSR.
        mags = np.abs(x)
        norms = np.sum(mags, axis=1, keepdims=True)
        # Avoid division by zero for all-zero rows.
        norms = np.where(norms == 0, 1.0, norms)
        mag_sort = np.sort(mags, axis=1)
        cumulative_mag = np.cumsum(mag_sort / norms, axis=1)
        threshold_idx = np.argmin(cumulative_mag < quantile, axis=1)
        x_sparse = scipy.sparse.lil_matrix(x.shape, dtype=x.dtype)
        for i, j in enumerate(threshold_idx):
            idx = np.where(mags[i] >= mag_sort[i, j])
            x_sparse[i, idx] = x[i, idx]
        return x_sparse.tocsr()

    if fmin is None:
        fmin = 32.70319566257483  # C1 note frequency

    # Apply the tuning offset (in bins) to the base frequency, then lay out the
    # geometrically spaced centre frequencies.
    fmin = fmin * (2.0 ** (tuning / bins_per_octave))
    freqs = fmin * (2.0 ** (np.arange(n_bins) / bins_per_octave))

    alpha = _relative_bandwidth(freqs)
    # Filter lengths at the ORIGINAL sample rate; used for the final scaling.
    lengths = _wavelet_lengths(freqs, float(sr), 1, alpha)

    n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
    n_filters = min(bins_per_octave, n_bins)

    cqt_resp = []
    # Working copies that are updated as the signal is progressively downsampled.
    my_y = y.astype(np.float32)
    my_sr = float(sr)
    my_hop = int(hop_length)

    for i in range(n_octaves):
        # Select this octave's bins, counting down from the top of `freqs`.
        if i == 0:
            sl = slice(-n_filters, None)
        else:
            sl = slice(-n_filters * (i + 1), -n_filters * i)

        freqs_oct = freqs[sl]
        alpha_oct = alpha[sl]

        basis, basis_lengths = _build_wavelet(freqs_oct, my_sr, 1, alpha_oct)
        n_fft_oct = basis.shape[1]

        # Frequency-domain normalisation
        basis = basis.astype(np.complex64)
        basis *= basis_lengths[:, np.newaxis] / float(n_fft_oct)
        # Keep only the non-negative frequency half so it lines up with rfft below.
        fft_basis = scipy.fft.fft(basis, n=n_fft_oct, axis=1)[:, :(n_fft_oct // 2) + 1]
        fft_basis = _sparsify_rows(fft_basis, quantile=0.01)
        # Scale by sqrt(original rate / current rate) for the downsampling applied so far.
        fft_basis = fft_basis * np.sqrt(sr / my_sr)

        # Centre-pad the signal and build a (n_fft_oct, n_frames) view of
        # overlapping frames without copying.
        y_pad = np.pad(my_y, int(n_fft_oct // 2), mode='constant')
        n_frames = 1 + (len(y_pad) - n_fft_oct) // my_hop
        frames = np.lib.stride_tricks.as_strided(
            y_pad,
            shape=(n_fft_oct, n_frames),
            strides=(y_pad.strides[0], y_pad.strides[0] * my_hop),
        )
        stft_result = scipy.fft.rfft(frames, axis=0)
        # Project the frame spectra onto the filter basis for this octave.
        cqt_resp.append(fft_basis.dot(stft_result))

        # Halve the rate for the next octave, but only while the hop stays an integer.
        if my_hop % 2 == 0:
            my_hop //= 2
            my_sr /= 2.0
            my_y = _resample_half(my_y)

    # Octaves may differ in frame count; trim all of them to the shortest one.
    max_col = min(c.shape[-1] for c in cqt_resp)
    cqt_out = np.empty((n_bins, max_col), dtype=np.complex64)
    # Fill the output from the top bins downwards (cqt_resp is ordered high -> low octave).
    end = n_bins
    for c_i in cqt_resp:
        n_oct = c_i.shape[0]
        if end < n_oct:
            # Fewer rows remain than this octave provides: keep its highest rows.
            cqt_out[:end, :] = c_i[-end:, :max_col]
        else:
            cqt_out[end - n_oct:end, :] = c_i[:, :max_col]
        end -= n_oct

    # Scale each bin by 1 / sqrt(filter length) and return magnitudes only.
    cqt_out /= np.sqrt(lengths)[:, np.newaxis]
    return np.abs(cqt_out).astype(np.float32)


def cq_to_chroma_mapping(n_input, bins_per_octave=12, n_chroma=12, fmin=None):
    """Build the matrix that folds CQT bins into chroma bins.

    Returns a float32 array of shape ``(n_chroma, n_input)``; multiplying it
    with a ``(n_input, n_frames)`` CQT sums the bins that belong to each
    chroma.  ``fmin`` is the frequency of CQT bin 0 (C1 when omitted) and
    determines how the rows are rotated.
    """
    base_freq = 32.70319566257483 if fmin is None else fmin  # C1 note frequency

    # Single-octave pattern: each chroma owns a run of consecutive CQT bins,
    # shifted left by half a run.
    bins_per_chroma = bins_per_octave / n_chroma
    octave_pattern = np.repeat(np.eye(n_chroma), int(bins_per_chroma), axis=1)
    octave_pattern = np.roll(octave_pattern, -int(bins_per_chroma // 2), axis=1)

    # Repeat the pattern over enough octaves and cut to the requested width.
    octave_count = int(np.ceil(n_input / bins_per_octave))
    mapping = np.tile(octave_pattern, octave_count)[:, :n_input]

    # Rotate the rows according to the pitch class (MIDI number mod 12) of fmin.
    pitch_class = np.mod(12 * np.log2(base_freq / 440.0) + 69, 12)
    row_shift = int(np.round(pitch_class * (n_chroma / 12.0)))
    return np.roll(mapping, row_shift, axis=0).astype(np.float32)


def _parabolic_interpolation(S, axis=-2):
    """Compute parabolic interpolation shift for peak refinement."""
    S_next = np.roll(S, -1, axis=axis)
    S_prev = np.roll(S, 1, axis=axis)

    a = S_next + S_prev - 2 * S
    b = (S_next - S_prev) / 2.0

    shifts = np.zeros_like(S)
    valid = np.abs(b) < np.abs(a)
    shifts[valid] = -b[valid] / a[valid]

    if axis == -2 or axis == S.ndim - 2:
        shifts[0, :] = 0
        shifts[-1, :] = 0
    elif axis == 0:
        shifts[0, ...] = 0
        shifts[-1, ...] = 0

    return shifts


def _localmax(S, axis=-2):
    """Find local maxima along an axis."""

    S_prev = np.roll(S, 1, axis=axis)
    S_next = np.roll(S, -1, axis=axis)

    local_max = (S > S_prev) & (S >= S_next)

    if axis == -2 or axis == S.ndim - 2:
        local_max[-1, :] = S[-1, :] > S[-2, :]
        # First element is never a local max (strict inequality with previous)
        local_max[0, :] = False
    elif axis == 0:
        local_max[-1, ...] = S[-1, ...] > S[-2, ...]
        local_max[0, ...] = False

    return local_max


def piptrack(y=None, sr=22050, S=None, n_fft=2048, hop_length=512,
             fmin=150.0, fmax=4000.0, threshold=0.1):
    """Track pitches on a thresholded, parabolically-interpolated STFT.

    Args:
        y: Audio samples; only used when ``S`` is not given.
        sr: Sample rate.
        S: Optional precomputed STFT with frequency along axis 0; its
            magnitude is used.
        n_fft: FFT size used when the STFT is computed here.
        hop_length: Hop between frames used when the STFT is computed here.
        fmin: Lower frequency bound (inclusive) for reported pitches.
        fmax: Upper frequency bound (exclusive); capped at ``sr / 2``.
        threshold: Per-frame fraction of the maximum magnitude a bin must
            exceed to be considered.

    Returns:
        ``(pitches, mags)``: two arrays shaped like the magnitude spectrogram
        that are zero everywhere except at detected peaks, where they hold
        the interpolated frequency and magnitude.

    Raises:
        ValueError: If neither ``y`` nor ``S`` is provided.
    """
    if S is None:
        if y is None:
            raise ValueError("Either y or S must be provided")

        # Hann analysis window as a column so it broadcasts over frames.
        window = scipy.signal.get_window('hann', n_fft, fftbins=True)
        if len(window) < n_fft:
            left = int((n_fft - len(window)) // 2)
            right = int(n_fft - len(window) - left)
            window = np.pad(window, (left, right), mode='constant')
        window = window.reshape((-1, 1))

        # Centre-pad and build an (n_fft, n_frames) strided view of the frames.
        padded = np.pad(y, int(n_fft // 2), mode='constant')
        frame_count = 1 + (len(padded) - n_fft) // hop_length
        sample_stride = padded.strides[0]
        frames = np.lib.stride_tricks.as_strided(
            padded,
            shape=(n_fft, frame_count),
            strides=(sample_stride, sample_stride * hop_length)
        )

        S = scipy.fft.rfft((window * frames).astype(np.float32), axis=0)

    S = np.abs(S)
    n_freq = S.shape[0]
    implied_n_fft = n_freq * 2 - 2

    fmin = max(fmin, 0)
    fmax = min(fmax, float(sr) / 2)

    bin_freqs = np.fft.rfftfreq(implied_n_fft, 1.0 / sr)
    if len(bin_freqs) > n_freq:
        bin_freqs = bin_freqs[:n_freq]

    # Sub-bin peak offsets and the matching magnitude correction.
    shift = _parabolic_interpolation(S, axis=0)
    slope = np.gradient(S, axis=0)
    dskew = 0.5 * slope * shift

    pitches = np.zeros_like(S)
    mags = np.zeros_like(S)

    in_band = ((fmin <= bin_freqs) & (bin_freqs < fmax)).reshape(-1, 1)

    # Only bins above the per-frame relative threshold may become peaks.
    floor = threshold * np.max(S, axis=0, keepdims=True)
    peaks = _localmax(S * (S > floor), axis=0)
    idx = np.nonzero(in_band & peaks)

    pitches[idx] = (idx[0] + shift[idx]) * float(sr) / implied_n_fft
    mags[idx] = S[idx] + dskew[idx]

    return pitches, mags


def hz_to_octs(frequencies, tuning=0.0, bins_per_octave=12):
    """Convert frequencies (Hz) to octave numbers.

    Octave 0 starts four octaves below A440, where A440 is first shifted by
    ``tuning`` (expressed in fractions of a bin, ``bins_per_octave`` bins per
    octave).
    """
    tuned_a4 = 440.0 * 2.0 ** (tuning / bins_per_octave)
    octave_zero_hz = float(tuned_a4) / 16
    return np.log2(np.asanyarray(frequencies) / octave_zero_hz)


def pitch_tuning(frequencies, resolution=0.01, bins_per_octave=12):
    """Estimate the tuning offset from a collection of pitches.

    Non-positive frequencies are ignored.  The deviation of every remaining
    pitch from its nearest bin (in fractions of a bin, within [-0.5, 0.5)) is
    histogrammed with bin width ``resolution``; the left edge of the most
    populated histogram bin is returned.  Returns 0.0 when no positive
    frequency is present.
    """
    positive = np.atleast_1d(frequencies)
    positive = positive[positive > 0]

    if not np.any(positive):
        return 0.0

    # Fractional bin position of every pitch, wrapped into [-0.5, 0.5).
    bin_position = bins_per_octave * hz_to_octs(positive, tuning=0.0,
                                                bins_per_octave=bins_per_octave)
    deviation = np.mod(bin_position, 1.0)
    deviation[deviation >= 0.5] -= 1.0

    edges = np.linspace(-0.5, 0.5, int(np.ceil(1.0 / resolution)) + 1)
    counts, bin_edges = np.histogram(deviation, edges)
    return bin_edges[np.argmax(counts)]


def estimate_tuning(y, sr=22050, bins_per_octave=12):
    """Estimate the global tuning deviation of a signal.

    Pitches are tracked with ``piptrack`` (150-4000 Hz); those whose magnitude
    is at least the median magnitude of all detected pitches are passed to
    ``pitch_tuning``.  Returns the deviation as a float in fractions of a
    bin, or 0.0 when the signal is shorter than one FFT frame or no pitch is
    found.
    """
    n_fft = 2048
    hop_length = 512

    # Too short to analyse even a single frame.
    if len(y) < n_fft:
        return 0.0

    pitch, mag = piptrack(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length,
                          fmin=150.0, fmax=4000.0, threshold=0.1)

    detected = pitch > 0
    if not detected.any():
        return 0.0

    # Keep only the stronger half of the detections.
    magnitude_floor = np.median(mag[detected])
    strong_pitches = pitch[detected & (mag >= magnitude_floor)]
    if len(strong_pitches) == 0:
        return 0.0

    return float(pitch_tuning(strong_pitches, resolution=0.01, bins_per_octave=bins_per_octave))


def compute_chroma_cens(y, sr=22050, hop_length=512, n_chroma=12,
                       n_octaves=7, bins_per_octave=36,
                       win_len_smooth=41, norm=2):
    """Compute Chroma Energy Normalized Statistics (CENS) features.

    Pipeline: estimate tuning -> CQT magnitude starting at C1 -> fold into
    chroma -> L1-normalise each frame -> quantise against fixed thresholds ->
    smooth over time with a Hann window -> normalise each frame.

    Args:
        y: Audio samples.
        sr: Sample rate.
        hop_length: Hop between frames in samples.
        n_chroma: Number of chroma bins.
        n_octaves: Number of octaves covered by the CQT.
        bins_per_octave: CQT bins per octave.
        win_len_smooth: Smoothing window length in frames; ``None`` or a
            non-positive value disables smoothing.
        norm: ``2`` for L2 or ``np.inf`` for max normalisation of each
            frame; any other value skips the final normalisation.

    Returns:
        Array of shape ``(n_chroma, n_frames)``.
    """
    c1_hz = 32.70319566257483  # C1 note frequency
    total_bins = n_octaves * bins_per_octave

    tuning_offset = estimate_tuning(y, sr, bins_per_octave=bins_per_octave)
    magnitudes = compute_cqt(y, sr=sr, hop_length=hop_length,
                             fmin=c1_hz, n_bins=total_bins,
                             bins_per_octave=bins_per_octave,
                             tuning=tuning_offset)

    fold = cq_to_chroma_mapping(total_bins, bins_per_octave=bins_per_octave,
                                n_chroma=n_chroma, fmin=c1_hz)
    chroma = np.dot(fold, magnitudes)

    # L1-normalise every frame; the floor avoids dividing by zero on silence.
    frame_mass = np.sum(np.abs(chroma), axis=0, keepdims=True)
    chroma = chroma / np.maximum(frame_mass, np.finfo(chroma.dtype).tiny)

    # Quantise: each threshold that is exceeded contributes a quarter.
    quantized = np.zeros_like(chroma)
    for level in (0.4, 0.2, 0.1, 0.05):
        quantized += (chroma > level) * 0.25

    if win_len_smooth is not None and win_len_smooth > 0:
        kernel = scipy.signal.get_window('hann', win_len_smooth + 2, fftbins=False)
        kernel /= np.sum(kernel)
        # A (1, width) kernel smooths along the time axis only.
        smoothed = scipy.ndimage.convolve(quantized, kernel.reshape(1, -1), mode='constant')
    else:
        smoothed = quantized

    if norm == 2:
        frame_norm = np.sqrt(np.sum(smoothed ** 2, axis=0, keepdims=True))
    elif norm == np.inf:
        frame_norm = np.max(np.abs(smoothed), axis=0, keepdims=True)
    else:
        return smoothed

    return smoothed / np.maximum(frame_norm, np.finfo(smoothed.dtype).tiny)


def _create_mel_filterbank(sr, n_fft, n_mels=128, fmin=0.0, fmax=None):
    """Create a mel-scale filterbank matrix.

    Returns a float32 array of shape ``(n_mels, 1 + n_fft // 2)`` whose rows
    are triangular filters with corner frequencies evenly spaced on the mel
    scale between ``fmin`` and ``fmax`` (``sr / 2`` when omitted).  Each row
    is scaled by ``2 / (upper corner - lower corner)``.
    """
    if fmax is None:
        fmax = sr / 2.0

    weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=np.float32)
    bin_freqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

    # n_mels + 2 corner frequencies: every filter spans three consecutive corners.
    corners_mel = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    corners_hz = mel_to_hz(corners_mel)
    widths = np.diff(corners_hz)

    # Signed distance from every corner to every FFT bin.
    ramps = np.subtract.outer(corners_hz, bin_freqs)

    # Rising and falling edges of all triangles at once.
    rising = -ramps[:n_mels] / widths[:n_mels, np.newaxis]
    falling = ramps[2:n_mels + 2] / widths[1:n_mels + 1, np.newaxis]
    weights[:] = np.maximum(0, np.minimum(rising, falling))

    scale = 2.0 / (corners_hz[2:n_mels + 2] - corners_hz[:n_mels])
    weights *= scale[:, np.newaxis]
    return weights


def _compute_mel_spectrogram(data, sr, n_fft=2048, hop_length=512, n_mels=128):
    """Compute a mel power spectrogram from an audio signal.

    The signal is zero-padded by ``n_fft // 2`` on both sides, framed with a
    Hann window, transformed with a real FFT and squared; the power spectrum
    is then projected onto a mel filterbank covering 0 Hz to ``sr / 2``.

    Returns:
        float32 array of shape ``(n_mels, n_frames)``.
    """
    # Hann analysis window as a column so it broadcasts over frames.
    window = scipy.signal.get_window('hann', n_fft, fftbins=True)
    if len(window) < n_fft:
        left = int((n_fft - len(window)) // 2)
        right = int(n_fft - len(window) - left)
        window = np.pad(window, (left, right), mode='constant')
    window = window.reshape((-1, 1))

    # Centre-pad and build an (n_fft, n_frames) strided view of the frames.
    padded = np.pad(data, int(n_fft // 2), mode='constant')
    frame_count = 1 + (len(padded) - n_fft) // hop_length
    sample_stride = padded.strides[0]
    frames = np.lib.stride_tricks.as_strided(
        padded,
        shape=(n_fft, frame_count),
        strides=(sample_stride, sample_stride * hop_length),
    )

    spectrum = scipy.fft.rfft(window * frames, axis=0).astype(np.complex64)
    power = np.abs(spectrum) ** 2

    filterbank = _create_mel_filterbank(sr, n_fft, n_mels=n_mels, fmin=0.0, fmax=sr / 2.0)
    return np.dot(filterbank, power).astype(np.float32)


def quick_tempo_estimate(audio_np, sr, start_bpm=120.0, std_bpm=1.0, hop_length=512):
    """Estimate the tempo (BPM) of an audio signal.

    A log-mel spectrogram is turned into an onset-strength envelope (mean
    positive frame-to-frame difference), which is handed to
    ``estimate_tempo_from_onset`` with a maximum tempo of 320 BPM.  Signals
    shorter than ten hops return the default of 120.0 and log a warning.
    """
    if len(audio_np) < hop_length * 10:
        logging.warning("Audio too short for tempo estimation, returning default BPM of 120.0")
        return 120.0

    n_fft = 2048
    mel_power = _compute_mel_spectrogram(audio_np, sr, n_fft=n_fft, hop_length=hop_length, n_mels=128)
    mel_db = 10.0 * np.log10(np.maximum(1e-10, mel_power))

    # Onset strength: positive part of the frame-to-frame change, averaged over bands.
    lag = 1
    increase = np.maximum(0.0, mel_db[:, lag:] - mel_db[:, :-lag])
    strength = np.mean(increase, axis=0)

    # Shift the envelope right by the lag plus the centring delay of the STFT,
    # then cut it back to the number of spectrogram frames.
    lead_in = lag + n_fft // (2 * hop_length)
    onset_env = np.pad(strength, (lead_in, 0), mode='constant')[:mel_power.shape[1]]

    return estimate_tempo_from_onset(onset_env, sr, hop_length, start_bpm, std_bpm, max_tempo=320.0)


def estimate_tempo_from_onset(onset_env, sr, hop_length, start_bpm=120.0, std_bpm=1.0, max_tempo=320.0):
    """Estimate tempo (BPM) from an onset-strength envelope.

    An autocorrelation tempogram is computed over Hann-windowed frames of up
    to 8 seconds, normalised per frame and averaged over time.  The averaged
    autocorrelation is combined with a log-normal prior centred on
    ``start_bpm`` (width ``std_bpm`` in octaves), lags faster than
    ``max_tempo`` are excluded, and the BPM of the best non-zero lag is
    returned.  Envelopes shorter than 20 frames return 120.0.
    """
    if len(onset_env) < 20:
        return 120.0

    n_frames = len(onset_env)
    window_seconds = 8.0
    win_length = min(int(np.round(window_seconds * sr / hop_length)), n_frames)

    # Pad both sides (ramping to zero) so one window can be centred on every frame.
    half_window = win_length // 2
    padded = np.pad(onset_env, (half_window, half_window), mode='linear_ramp', end_values=(0, 0))

    # (win_length, n_frames) view: column i is the window starting at padded[i].
    step = padded.strides[0]
    frames = np.lib.stride_tricks.as_strided(padded, shape=(win_length, n_frames), strides=(step, step))

    taper = scipy.signal.get_window('hann', win_length, fftbins=True)
    windowed = frames * taper[:, np.newaxis]

    # Autocorrelation of every column via the power spectrum; zero-padding to
    # at least 2 * win_length - 1 avoids circular wrap-around.
    fft_size = scipy.fft.next_fast_len(2 * win_length - 1)
    tempogram = np.zeros((win_length, n_frames))
    for col in range(n_frames):
        spectrum = scipy.fft.rfft(windowed[:, col], n=fft_size)
        autocorr = scipy.fft.irfft(np.abs(spectrum) ** 2, n=fft_size)
        tempogram[:, col] = autocorr[:win_length]

    # Normalise every frame by its peak (skipping all-zero frames).
    peak = np.max(np.abs(tempogram), axis=0)
    nonzero = peak > 0
    tempogram[:, nonzero] /= peak[nonzero]

    mean_autocorr = np.maximum(np.mean(tempogram, axis=1), 0)

    # BPM represented by each lag; lag 0 corresponds to infinite tempo.
    bpms = np.zeros(win_length, dtype=np.float64)
    bpms[0] = np.inf
    bpms[1:] = 60.0 * sr / (hop_length * np.arange(1.0, win_length))

    logprior = -0.5 * ((np.log2(bpms) - np.log2(start_bpm)) / std_bpm) ** 2

    if max_tempo is not None:
        # Rule out every lag before the first one that is slower than max_tempo.
        first_allowed = int(np.argmax(bpms < max_tempo))
        if first_allowed > 0:
            logprior[:first_allowed] = -np.inf

    score = np.log1p(1e6 * mean_autocorr) + logprior
    # Lag 0 is never a candidate.
    best_lag = int(np.argmax(score[1:])) + 1
    return bpms[best_lag]


def detect_onset_peaks(onset_env, sr=22050, hop_length=512, pre_max=0.03, post_max=0.0,
                      pre_avg=0.10, post_avg=0.10, wait=0.03, delta=0.07):
    """Pick onset peaks from an onset-strength envelope.

    The envelope is first rescaled to [0, 1].  A frame is a peak when it
    equals the maximum of its surrounding max-window and exceeds the mean of
    its surrounding average-window by at least ``delta``.  After a peak, the
    next ``wait`` seconds are skipped.  All window parameters are given in
    seconds and converted to frames with ``sr / hop_length``.

    Returns:
        int32 array with the frame indices of the detected peaks.
    """
    env = onset_env - np.min(onset_env)
    env_peak = np.max(env)
    if env_peak > 0:
        env = env / env_peak

    # Seconds -> frames.  The "post" windows are exclusive upper bounds, hence +1.
    pre_max_frames = int(pre_max * sr / hop_length)
    post_max_frames = int(post_max * sr / hop_length) + 1
    pre_avg_frames = int(pre_avg * sr / hop_length)
    post_avg_frames = int(post_avg * sr / hop_length) + 1
    wait_frames = int(wait * sr / hop_length)

    total = len(env)
    is_peak = np.zeros(total, dtype=bool)

    # Frame 0 has no history, so only the look-ahead parts of the windows apply.
    is_peak[0] = (env[0] >= np.max(env[:min(post_max_frames, total)]))
    is_peak[0] &= (env[0] >= np.mean(env[:min(post_avg_frames, total)]) + delta)

    pos = wait_frames + 1 if is_peak[0] else 1
    while pos < total:
        window_max = np.max(env[max(0, pos - pre_max_frames):min(pos + post_max_frames, total)])
        if env[pos] == window_max:
            window_mean = np.mean(env[max(0, pos - pre_avg_frames):min(pos + post_avg_frames, total)])
            if env[pos] >= window_mean + delta:
                is_peak[pos] = True
                # Enforce the minimum spacing between consecutive peaks.
                pos += wait_frames + 1
                continue
        pos += 1

    return np.flatnonzero(is_peak).astype(np.int32)


def track_beats(onset_env, tempo, sr, hop_length, tightness=100, trim=True):
    """Track beats using dynamic programming.

    Args:
        onset_env: Onset-strength envelope, one value per frame.
        tempo: Tempo estimate in beats per minute.
        sr: Sample rate used to compute the envelope.
        hop_length: Hop between envelope frames in samples.
        tightness: Weight of the penalty for deviating from the expected
            beat period (larger values enforce a steadier tempo).
        trim: When true, beats at the start and end whose local score is at
            or below a threshold are removed.

    Returns:
        int32 array of beat positions as frame indices (empty when the beat
        period is non-positive or the envelope has fewer than 2 frames).
    """

    frame_rate = sr / hop_length
    # Expected beat period in frames.
    frames_per_beat = np.round(frame_rate * 60.0 / tempo)

    if frames_per_beat <= 0 or len(onset_env) < 2:
        return np.array([], dtype=np.int32)

    # Normalise the envelope by its sample standard deviation (when non-zero).
    onset_std = np.std(onset_env, ddof=1)
    if onset_std > 0:
        onset_normalized = onset_env / onset_std
    else:
        onset_normalized = onset_env

    # Smooth the envelope with a Gaussian spanning one beat period on each side.
    window_range = np.arange(-frames_per_beat, frames_per_beat + 1)
    window = np.exp(-0.5 * (window_range * 32.0 / frames_per_beat) ** 2)

    localscore = scipy.signal.convolve(onset_normalized, window, mode='same')

    # backlink[i]: frame of the preceding beat on the best path ending at i (-1 = none).
    # cumscore[i]: score of the best path ending at i.
    backlink = np.full(len(localscore), -1, dtype=np.int32)
    cumscore = np.zeros(len(localscore), dtype=np.float64)

    # Frames scoring below this do not start a beat sequence.
    score_thresh = 0.01 * localscore.max()
    first_beat = True

    backlink[0] = -1
    cumscore[0] = localscore[0]

    fpb = int(frames_per_beat)

    for i in range(1, len(localscore)):
        score_i = localscore[i]
        best_score = -np.inf
        beat_location = -1

        # Candidate predecessors lie between half a beat and two beats back,
        # scanned from the nearest to the farthest.
        search_start = int(i - np.round(fpb / 2.0))
        search_end = int(i - 2 * fpb - 1)

        for loc in range(search_start, search_end, -1):
            if loc < 0:
                break

            # Penalise the squared log-ratio between the actual gap and the beat period.
            score = cumscore[loc] - tightness * (np.log(i - loc) - np.log(fpb)) ** 2

            if score > best_score:
                best_score = score
                beat_location = loc

        if beat_location >= 0:
            cumscore[i] = score_i + best_score
        else:
            cumscore[i] = score_i

        # Until the first sufficiently strong frame is seen, do not link backwards.
        if first_beat and score_i < score_thresh:
            backlink[i] = -1
        else:
            backlink[i] = beat_location
            first_beat = False

    # Local maxima of the cumulative score: strictly above the previous value
    # and at least the next one; the first frame never counts.
    local_max_mask = np.zeros(len(cumscore), dtype=bool)

    local_max_mask[0] = False

    for i in range(1, len(cumscore) - 1):
        local_max_mask[i] = (cumscore[i] > cumscore[i-1]) and (cumscore[i] >= cumscore[i+1])

    if len(cumscore) > 1:
        local_max_mask[-1] = cumscore[-1] > cumscore[-2]

    # The last beat is the latest local maximum reaching half the median of all
    # local maxima.  tail stays -1 (no beats) if none qualifies.
    if np.any(local_max_mask):
        median_max = np.median(cumscore[local_max_mask])
        threshold = 0.5 * median_max

        tail = -1
        for i in range(len(cumscore) - 1, -1, -1):
            if local_max_mask[i] and cumscore[i] >= threshold:
                tail = i
                break
    else:
        tail = len(cumscore) - 1

    # Follow the back-links from the last beat; `visited` guards against cycles.
    beats = np.zeros(len(localscore), dtype=bool)
    n = tail
    visited = set()
    while n >= 0 and n not in visited:
        beats[n] = True
        visited.add(n)
        n = backlink[n]

    if trim and np.any(beats):
        beat_positions = np.flatnonzero(beats)

        beat_localscores = localscore[beat_positions]

        # Threshold: half the RMS of the Hann-smoothed local scores at the beats.
        w = np.hanning(5)
        smooth_boe_full = np.convolve(beat_localscores, w)
        # Drop the leading half-window of the full convolution.  The upper bound
        # is derived from len(localscore), not from the number of beats.
        smooth_boe = smooth_boe_full[len(w)//2 : len(localscore) + len(w)//2]

        threshold = 0.5 * np.sqrt(np.mean(smooth_boe ** 2))

        # Clear beats in the leading run of frames at or below the threshold ...
        start_frame = 0
        while start_frame < len(localscore) and localscore[start_frame] <= threshold:
            beats[start_frame] = False
            start_frame += 1

        # ... and in the trailing run.
        end_frame = len(localscore) - 1
        while end_frame >= 0 and localscore[end_frame] <= threshold:
            beats[end_frame] = False
            end_frame -= 1

    return np.flatnonzero(beats).astype(np.int32)

def compute_onset_envelope(mel_spec_db, n_fft=2048, hop_length=512):
    """Compute an onset-strength envelope from a log-mel spectrogram (dB).

    The envelope is the mean over mel bands of the positive frame-to-frame
    increase.  It is delayed by ``1 + n_fft // (2 * hop_length)`` frames
    (zero-filled at the start) and cut to the number of input frames.

    Args:
        mel_spec_db: Array of shape ``(n_mels, n_frames)``.
        n_fft: FFT size the spectrogram was computed with.
        hop_length: Hop length the spectrogram was computed with.

    Returns:
        1-D array with ``n_frames`` values.
    """
    lag = 1
    n_frames = mel_spec_db.shape[1]

    # Keep only increases in energy between consecutive frames.
    increase = np.maximum(0.0, mel_spec_db[:, lag:] - mel_spec_db[:, :-lag])
    strength = np.mean(increase, axis=0)

    # Compensate for the lag and the centring delay of the STFT frames.
    lead_in = lag + n_fft // (2 * hop_length)
    return np.pad(strength, (lead_in, 0), mode='constant')[:n_frames]

def compute_mfcc(mel_spec_db, n_mfcc=20):
    """Compute MFCC features from a log-mel spectrogram (dB).

    Applies an orthonormal type-II DCT along the mel axis (axis 0), keeps the
    first ``n_mfcc`` coefficients and returns them time-major.

    Returns:
        float32 array of shape ``(n_frames, n_mfcc)``.
    """
    cepstrum = scipy.fft.dct(mel_spec_db, axis=0, type=2, norm='ortho')
    leading = cepstrum[:n_mfcc]
    return leading.T.astype(np.float32)


def power_to_db(S, amin=1e-10, top_db=80.0, ref=1.0):
    """Convert a power spectrogram (amplitude squared) to decibel (dB) units.

    Args:
        S: Power values.
        amin: Floor applied to ``S`` and ``ref`` before taking the logarithm.
        top_db: When not ``None``, values more than ``top_db`` below the
            maximum are raised to ``max - top_db``.
        ref: Power that maps to 0 dB.

    Returns:
        ``10 * log10(max(S, amin) / max(ref, amin))``, optionally limited in
        dynamic range.
    """
    power = np.asarray(S)
    reference_db = 10.0 * np.log10(np.maximum(amin, ref))

    decibels = 10.0 * np.log10(np.maximum(amin, power))
    # In-place so the result keeps the dtype of the log spectrum.
    decibels -= reference_db

    if top_db is None:
        return decibels
    return np.maximum(decibels, decibels.max() - top_db)


class WanDancerEncodeAudio(io.ComfyNode):
    # Node that turns an audio clip into the per-frame feature tensor consumed
    # by the WanDancer video model: onset envelope, MFCCs, CENS chroma, onset
    # peaks and beats, concatenated along the feature axis.

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanDancerEncodeAudio",
            category="model/conditioning/video_models",
            inputs=[
                io.Audio.Input("audio"),
                io.Int.Input("video_frames", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Float.Input("audio_inject_scale", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="The scale for the audio features when injected into the video model."),
            ],
            outputs=[
                io.AudioEncoderOutput.Output(display_name="audio_encoder_output"),
                io.String.Output(display_name="fps_string", tooltip="The calculated fps based on the audio length and the number of video frames. Used in the prompt."),
            ],
        )

    @classmethod
    def execute(cls, video_frames, audio_inject_scale, audio) -> io.NodeOutput:
        # Returns (audio_encoder_output, fps_string).  audio_encoder_output is a
        # dict with "audio_feature" (tensor shaped (1, n_feature_frames, n_features)),
        # "fps" and "audio_inject_scale".
        waveform = audio["waveform"][0]  # first item of the batch only
        sample_rate = audio["sample_rate"]
        base_fps = 30
        hop_length = 512
        model_sr = 22050
        n_fft = 2048

        # start tempo from original audio (not the resampled one) to match the reference pipeline
        if waveform.shape[0] > 1:
            # downmix multi-channel audio to mono
            waveform = waveform.mean(dim=0, keepdim=False)

        start_bpm = quick_tempo_estimate(waveform.squeeze().cpu().numpy(), sample_rate, hop_length=hop_length)

        # resample to the sample rate used for feature extraction
        # (base_fps * hop_length samples per second, i.e. base_fps feature frames per second)
        resample_sr = base_fps * hop_length
        waveform = torchaudio.functional.resample(waveform, sample_rate, resample_sr)

        waveform_np = waveform.cpu().numpy().squeeze()
        # NOTE(review): the feature functions below are given model_sr (22050), not
        # resample_sr - presumably to mirror the reference pipeline; confirm before changing.
        mel_spec = _compute_mel_spectrogram(waveform_np, model_sr, n_fft, hop_length, n_mels=128)
        mel_spec_db = power_to_db(mel_spec, amin=1e-10, top_db=80.0, ref=1.0)
        envelope = compute_onset_envelope(mel_spec_db, n_fft, hop_length)
        mfcc = compute_mfcc(mel_spec_db, n_mfcc=20)
        chroma = compute_chroma_cens(y=waveform_np, sr=model_sr, hop_length=hop_length).T
        # detect peaks
        peak_idxs = detect_onset_peaks(envelope, sr=model_sr, hop_length=hop_length)
        peak_onehot = np.zeros_like(envelope, dtype=np.float32)
        peak_onehot[peak_idxs] = 1.0
        # detect beats
        beat_tracking_tempo = estimate_tempo_from_onset(envelope, sr=model_sr, hop_length=hop_length, start_bpm=start_bpm)
        beat_idxs = track_beats(envelope, beat_tracking_tempo, model_sr, hop_length, tightness=100, trim=True)
        beat_onehot = np.zeros_like(envelope, dtype=np.float32)
        beat_onehot[beat_idxs] = 1.0

        # one row per feature frame: [envelope | mfcc | chroma | peak flag | beat flag]
        audio_feature = np.concatenate(
            [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
            axis=-1,
        )
        audio_feature = torch.from_numpy(audio_feature).unsqueeze(0).to(comfy.model_management.intermediate_device())

        # Feature frames per video frame, rounded to the nearest integer.  Clamp to
        # at least 1: when the audio yields fewer than half as many feature frames
        # as video_frames the rounded value is 0 and the division below would
        # raise ZeroDivisionError.
        frames_per_video_frame = max(1, int(audio_feature.shape[1] / video_frames + 0.5))
        fps = float(base_fps / frames_per_video_frame)

        audio_encoder_output = {
            "audio_feature": audio_feature,
            "fps": fps,
            "audio_inject_scale": audio_inject_scale,
        }

        if int(fps + 0.5) != 30:
            fps_string = " 帧率是{:.4f}".format(fps) # "frame rate is" in Chinese, as it was in the original pipeline
        else:
            fps_string = ", 帧率是30fps。" # to match the reference pipeline when the fps is 30

        return io.NodeOutput(audio_encoder_output, fps_string)


class WanDancerVideo(io.ComfyNode):
    # Node that creates an empty video latent and attaches the optional
    # conditioning inputs (start image(s) + mask, CLIP vision embeds, audio
    # features) to both the positive and the negative conditioning.

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanDancerVideo",
            category="model/conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="The number of frames in the generated video. Should stay 149 for WanDancer."),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="The CLIP vision embeds for the first frame."),
                io.ClipVisionOutput.Input("clip_vision_output_ref", optional=True, tooltip="The CLIP vision embeds for the reference image."),
                io.Image.Input("start_image", optional=True, tooltip="The initial image(s) to be encoded, can be any number of frames."),
                io.Mask.Input("mask", optional=True, tooltip="Image conditioning mask for the start image(s). White is kept, black is generated. Used for the local generations."),
                io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent", tooltip="Empty latent."),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, start_image=None, mask=None, clip_vision_output=None, clip_vision_output_ref=None, audio_encoder_output=None) -> io.NodeOutput:
        # Empty latent: 16 channels, one latent frame per 4 video frames (plus the
        # first frame), spatial size reduced by a factor of 8.
        latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        if start_image is not None:
            # Resize at most `length` start frames to the target resolution
            # (common_upscale is called channels-first, hence the movedim round trip).
            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            # Full-length frame buffer: the start frames followed by zeros.
            image = torch.zeros((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
            image[:start_image.shape[0]] = start_image

            # Only the first three channels are encoded.
            concat_latent_image = vae.encode(image[:, :, :, :3])
            if mask is None:
                # Default mask: 0 for the latent frames covered by the start image(s), 1 elsewhere.
                concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
                concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
            else:
                # Invert the user mask and resize it to the latent resolution.
                concat_mask = 1 - mask[:length].unsqueeze(0)
                concat_mask = comfy.utils.common_upscale(concat_mask, concat_latent_image.shape[-2], concat_latent_image.shape[-1], "nearest-exact", "disabled")
                # Repeat the first frame 4 times so the frame count can be grouped in
                # fours, then fold every group of 4 frames into its own dimension.
                # NOTE(review): the view assumes the resulting frame count is divisible by 4 - confirm for lengths other than 4k+1.
                concat_mask = torch.cat([torch.repeat_interleave(concat_mask[:, 0:1], repeats=4, dim=1), concat_mask[:, 1:]], dim=1)
                concat_mask = concat_mask.view(1, concat_mask.shape[1] // 4, 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]).transpose(1, 2)

            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})

        if clip_vision_output is not None:
            # clip_vision_output_ref is stored as given, even when it is None.
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output, "clip_vision_output_ref": clip_vision_output_ref})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output, "clip_vision_output_ref": clip_vision_output_ref})

        if audio_encoder_output is not None:
            # audio_inject_scale falls back to 1.0 when the encoder output does not carry it.
            positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_encoder_output["audio_feature"], "fps": audio_encoder_output["fps"], "audio_inject_scale": audio_encoder_output.get("audio_inject_scale", 1.0)})
            negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_encoder_output["audio_feature"], "fps": audio_encoder_output["fps"], "audio_inject_scale": audio_encoder_output.get("audio_inject_scale", 1.0)})

        out_latent = {}
        out_latent["samples"] = latent
        return io.NodeOutput(positive, negative, out_latent)


class WanDancerPadKeyframes(io.ComfyNode):
    """Spread keyframe images over one fixed-length video segment.

    The audio duration determines how many segments the whole video has. The
    input images are distributed evenly over all frames of those segments, and
    this node emits the zero-padded frames, the validity mask and the audio
    slice that belong to a single segment.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanDancerPadKeyframes",
            category="image/video",
            inputs=[
                io.Image.Input("images",),
                io.Int.Input("segment_length", default=149, min=1, max=10000, tooltip="Length of this segment (usually 149 frames)"),
                io.Int.Input("segment_index", default=0, min=0, max=100, tooltip="Which segment this is (0 for first, 1 for second, etc.)"),
                io.Audio.Input("audio", tooltip="Audio to calculate total output frames from and extract segment audio."),
            ],
            outputs=[
                io.Image.Output(display_name="keyframes_sequence", tooltip="Padded keyframe sequence"),
                io.Mask.Output(display_name="keyframes_mask", tooltip="Mask indicating valid frames"),
                io.Audio.Output(display_name="audio_segment", tooltip="Audio segment for this video segment"),
            ],
        )

    @classmethod
    def do_execute(cls, images, segment_length, segment_index, audio):
        """Build the padded keyframes, mask and audio slice for one segment.

        Args:
            images: Keyframe tensor whose shape unpacks as (B, H, W, C).
            segment_length: Number of frames in a segment.
            segment_index: Zero-based index of the segment to build.
            audio: Mapping with a "waveform" tensor (indexed on three
                dimensions, samples last) and a "sample_rate".

        Returns:
            A ``(keyframes, mask, audio_segment)`` tuple. ``keyframes`` has
            shape (segment_length, H, W, C) and is zero except at keyframe
            positions; ``mask`` has shape (segment_length, H, W) and is 1 at
            those positions; ``audio_segment`` is a dict with the sliced
            "waveform" and the original "sample_rate".

        Raises:
            ValueError: If ``audio`` is None; the segment layout is derived
                from the audio duration and cannot be computed without it.
        """
        if audio is None:
            raise ValueError("WanDancerPadKeyframes requires an audio input to compute the keyframe layout.")

        B, H, W, C = images.shape
        fps = 30

        waveform = audio["waveform"]
        sample_rate = audio["sample_rate"]

        # calculate total frames: the video is a whole number of segments
        audio_duration = waveform.shape[-1] / sample_rate
        segment_duration = segment_length / fps
        buffer = 0.2
        if audio_duration > buffer:
            num_segments = int((audio_duration - buffer) / segment_duration) + 1
        else:
            num_segments = 0
        total_frames = num_segments * segment_length

        mask = torch.zeros((segment_length, H, W), device=images.device, dtype=images.dtype)
        keyframes = torch.zeros((segment_length, H, W, C), dtype=images.dtype, device=images.device)

        # guard: with too little audio or no images, nothing to place — leave keyframes/mask zeroed
        if total_frames > 0 and B > 0:
            frame_interval = float(total_frames) / B

            # total_frames is an exact multiple of segment_length, so every
            # segment (the last one included) closes on frame segment_length - 1.
            end_index = segment_length - 1

            # Number of keyframes placed before the closing keyframe of a
            # segment. The value is identical for every segment, so it is
            # computed once and reused for the preceding segments.
            leading = 0
            while leading * frame_interval < end_index - frame_interval:
                leading += 1
            images_before_this_segment = segment_index * leading

            positions = [
                (int(math.ceil(frame_interval * k)), images_before_this_segment + k)
                for k in range(leading)
            ]
            positions.append((end_index, images_before_this_segment + leading))

            # Keep one image per frame (the last candidate wins). Indexed
            # assignment with duplicate indices is not well defined in torch,
            # and duplicates occur when there are more images than frames.
            placement = {
                pos: idx
                for pos, idx in positions
                if idx < B and pos < segment_length
            }

            if placement:
                seg_positions = torch.tensor(list(placement.keys()), dtype=torch.long, device=images.device)
                img_indices = torch.tensor(list(placement.values()), dtype=torch.long, device=images.device)
                mask[seg_positions] = 1
                keyframes[seg_positions] = images[img_indices]

        # extract audio segment; the end is clamped to the audio duration, so a
        # segment starting past the end of the audio yields an empty slice
        start_time = segment_index * segment_duration
        end_time = min(start_time + segment_duration, audio_duration)
        start_sample = int(start_time * sample_rate)
        end_sample = int(end_time * sample_rate)

        audio_segment = {
            "waveform": waveform[:, :, start_sample:end_sample],
            "sample_rate": sample_rate
        }

        return keyframes, mask, audio_segment

    @classmethod
    def execute(cls, images, segment_length, segment_index, audio=None) -> io.NodeOutput:
        """Node entry point: wrap the ``do_execute`` results in a NodeOutput."""
        return io.NodeOutput(*cls.do_execute(images, segment_length, segment_index, audio))

class WanDancerPadKeyframesList(io.ComfyNode):
    """List-producing variant of the keyframe padding node.

    Runs ``WanDancerPadKeyframes.do_execute`` for segment indices
    ``0 .. num_segments - 1`` and returns the per-segment results as lists.
    """

    @classmethod
    def define_schema(cls):
        node_inputs = [
            io.Image.Input("images"),
            io.Int.Input("segment_length", default=149, min=1, max=10000, tooltip="Length of each segment (usually 149 frames)"),
            io.Int.Input("num_segments", default=1, min=1, max=100, tooltip="How many padded segments to emit as lists."),
            io.Audio.Input("audio", tooltip="Audio to slice for each emitted segment."),
        ]
        node_outputs = [
            io.Image.Output(display_name="keyframes_sequence", tooltip="Padded keyframe sequences", is_output_list=True),
            io.Mask.Output(display_name="keyframes_mask", tooltip="Masks indicating valid frames", is_output_list=True),
            io.Audio.Output(display_name="audio_segment", tooltip="Audio segment for each video segment", is_output_list=True),
        ]
        return io.Schema(
            node_id="WanDancerPadKeyframesList",
            category="image/video",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, images, segment_length, num_segments, audio=None) -> io.NodeOutput:
        """Build every requested segment and regroup the results by output."""
        per_segment = []
        for index in range(num_segments):
            per_segment.append(WanDancerPadKeyframes.do_execute(images, segment_length, index, audio))

        # Transpose [(keyframes, mask, audio), ...] into three parallel lists.
        frame_batches, mask_batches, clip_batches = map(list, zip(*per_segment))
        return io.NodeOutput(frame_batches, mask_batches, clip_batches)

class WanDancerExtension(ComfyExtension):
    """ComfyExtension subclass that exposes the WanDancer node classes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [
            WanDancerVideo,
            WanDancerEncodeAudio,
            WanDancerPadKeyframes,
            WanDancerPadKeyframesList,
        ]
        return node_classes

async def comfy_entrypoint() -> WanDancerExtension:
    """Create and return the extension object for this module."""
    extension = WanDancerExtension()
    return extension
