"""GPT-OSS text encoder for Lens."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

import comfy.ops
from comfy import sd1_clip
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
from comfy.text_encoders.llama import RMSNorm, apply_rope


@dataclass
class GptOss20BConfig:
    """Hyper-parameters of the GPT-OSS-20B decoder used as the Lens text encoder.

    ``layer_types`` holds one attention kind per decoder layer. When it is
    left as ``None`` it is filled in after construction with an alternating
    pattern that starts with ``"sliding_attention"`` at layer 0.
    """

    # Token embedding table size and model width.
    vocab_size: int = 201088
    hidden_size: int = 2880
    # Width of each expert's hidden projection.
    intermediate_size: int = 2880
    num_hidden_layers: int = 24
    # Query heads / key-value heads and the per-head width.
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    head_dim: int = 64
    # Mixture-of-experts: total experts and experts selected per token.
    num_local_experts: int = 32
    num_experts_per_tok: int = 4
    # Window length used by "sliding_attention" layers.
    sliding_window: int = 128
    # Rotary-embedding (YaRN) parameters fed to ``_yarn_inv_freq``.
    original_max_position_embeddings: int = 4096
    rope_theta: float = 150000.0
    rope_factor: float = 32.0
    rope_beta_fast: float = 32.0
    rope_beta_slow: float = 1.0
    rope_truncate: bool = False
    rms_norm_eps: float = 1e-5
    # Whether the q/k/v/o projections carry a bias term.
    attention_bias: bool = True
    layer_types: Optional[List[str]] = None
    # Gate sharpness and clamp limit used by the expert activation.
    moe_alpha: float = 1.702
    moe_limit: float = 7.0

    def __post_init__(self):
        # An explicit per-layer list always wins over the default pattern.
        if self.layer_types is not None:
            return
        # (idx + 1) % 2 is 1 for even layers and 0 for odd layers.
        kinds = ("full_attention", "sliding_attention")
        self.layer_types = [kinds[(idx + 1) % 2] for idx in range(self.num_hidden_layers)]


def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
    original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
    """Compute YaRN rotary inverse frequencies and the attention scaling factor.

    Each of the ``head_dim // 2`` frequencies is a blend of the plain RoPE
    value (``1 / base**(2i/head_dim)``) and the same value divided by
    ``factor``. The blend weight is a clamped linear ramp over the frequency
    index whose end points are derived from ``beta_fast`` and ``beta_slow``.

    Returns:
        ``(inv_freq, attention_scaling)`` where ``inv_freq`` is a float32
        tensor of length ``head_dim // 2`` on ``device`` and
        ``attention_scaling`` is ``1.0`` for ``factor <= 1`` and
        ``0.1 * ln(factor) + 1.0`` otherwise.
    """
    log_base_x2 = 2 * math.log(base)

    def _correction_dim(rotations: float) -> float:
        # Frequency index at which the given number of rotations is reached
        # within the original context length.
        return (head_dim * math.log(original_max_position_embeddings / (rotations * 2 * math.pi))) / log_base_x2

    # End points of the ramp, optionally snapped outward to whole indices.
    ramp_start = _correction_dim(beta_fast)
    ramp_end = _correction_dim(beta_slow)
    if truncate:
        ramp_start, ramp_end = math.floor(ramp_start), math.ceil(ramp_end)
    ramp_start = max(ramp_start, 0)
    ramp_end = min(ramp_end, head_dim - 1)
    if ramp_start == ramp_end:
        # Avoid a zero-width ramp (division by zero below).
        ramp_end += 0.001

    steps = torch.arange(head_dim // 2, dtype=torch.float32, device=device)
    ramp = torch.clamp((steps - ramp_start) / (ramp_end - ramp_start), 0, 1)
    keep_original = 1 - ramp

    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim
    wavelengths = base ** exponents
    extrapolated = 1.0 / wavelengths
    interpolated = 1.0 / (factor * wavelengths)
    inv_freq = interpolated * (1 - keep_original) + extrapolated * keep_original

    scaling = 1.0 if factor <= 1 else 0.1 * math.log(factor) + 1.0
    return inv_freq, scaling


def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the scaled rotary cos/sin tables for the given positions.

    ``position_ids`` is indexed as ``(batch, seq)`` and ``inv_freq`` as a 1-D
    vector of frequencies. The angle table is duplicated along the last axis
    before cos/sin are taken, both are multiplied by ``attention_scaling``,
    cast to ``dtype`` and given a singleton head axis at dim 1.

    Returns:
        ``(cos, sin_first_half, -sin_second_half)`` where the two sin pieces
        split the last axis of the sin table in the middle.
    """
    batch = position_ids.shape[0]
    # (batch, n_freq, 1) @ (batch, 1, seq) -> (batch, n_freq, seq), in float32.
    freq_col = inv_freq[None, :, None].float().expand(batch, -1, 1)
    pos_row = position_ids[:, None, :].float()
    angles = torch.matmul(freq_col, pos_row).transpose(1, 2)
    angles = torch.cat((angles, angles), dim=-1)

    cos_table = (angles.cos() * attention_scaling).to(dtype).unsqueeze(1)
    sin_table = (angles.sin() * attention_scaling).to(dtype).unsqueeze(1)

    half = sin_table.shape[-1] // 2
    return cos_table, sin_table[..., :half], -sin_table[..., half:]


def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
    attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
    """Run masked attention in which every head also owns a learned sink logit.

    A sink only enlarges the softmax denominator of each query row; it adds
    nothing to the attention output. That is emulated here by giving ``k`` and
    ``v`` one extra all-zero position and writing the sink logit into the
    additive mask at that extra column.

    ``q``/``k``/``v`` are indexed as ``(batch, heads, seq, head_dim)``;
    ``sinks`` holds one value per query head. When ``attention_mask`` is
    ``None`` the real key columns get a zero bias.
    """
    # Without native GQA support the key/value heads are repeated up front.
    if num_kv_groups > 1 and not TORCH_HAS_GQA:
        k, v = (t.repeat_interleave(num_kv_groups, dim=1) for t in (k, v))

    batch, _, q_len, head_dim = q.shape
    kv_heads, kv_len = k.shape[1], k.shape[-2]

    # One extra, all-zero key/value slot that stands in for the sink.
    sink_slot = (batch, kv_heads, 1, head_dim)
    k = torch.cat((k, k.new_zeros(sink_slot)), dim=-2)
    v = torch.cat((v, v.new_zeros(sink_slot)), dim=-2)

    score_shape = (batch, num_heads, q_len, kv_len)
    if attention_mask is None:
        token_bias = q.new_zeros(score_shape)
    else:
        token_bias = attention_mask[..., :kv_len].expand(*score_shape)
    # The per-head sink logit becomes the bias of the extra column.
    sink_bias = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(batch, num_heads, q_len, 1)
    bias = torch.cat((token_bias, sink_bias), dim=-1)

    attend = optimized_attention_for_device(q.device, mask=True, small_input=True)
    return attend(q, k, v, num_heads, mask=bias, skip_reshape=True, enable_gqa=True)


class GptOssAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and per-head sinks."""

    def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_type = config.layer_types[layer_idx]
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        is_sliding = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if is_sliding else None

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        linear_kwargs = dict(bias=config.attention_bias, device=device, dtype=dtype)
        self.q_proj = ops.Linear(config.hidden_size, q_width, **linear_kwargs)
        self.k_proj = ops.Linear(config.hidden_size, kv_width, **linear_kwargs)
        self.v_proj = ops.Linear(config.hidden_size, kv_width, **linear_kwargs)
        self.o_proj = ops.Linear(q_width, config.hidden_size, **linear_kwargs)
        # One learned sink logit per query head; values come from the checkpoint.
        self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
        """Project to heads, apply RoPE to q/k, attend with sinks, project back."""
        batch, seq_len, _ = hidden_states.shape

        def to_heads(projected: torch.Tensor, n_heads: int) -> torch.Tensor:
            # (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
            return projected.view(batch, seq_len, n_heads, self.head_dim).transpose(1, 2)

        q = to_heads(self.q_proj(hidden_states), self.num_heads)
        k = to_heads(self.k_proj(hidden_states), self.num_kv_heads)
        v = to_heads(self.v_proj(hidden_states), self.num_kv_heads)

        q, k = apply_rope(q, k, freqs_cis)

        attended = _attention_with_sinks(
            q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups
        )
        return self.o_proj(attended)


# Mixture of Experts

class GptOssTopKRouter(nn.Module):
    """Linear router that picks the top-k experts for every token."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        weight_shape = (config.num_local_experts, config.hidden_size)
        self.weight = nn.Parameter(torch.empty(*weight_shape, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(scores, expert_indices)`` for the ``top_k`` experts per token.

        The scores are a softmax taken over the selected logits only, so they
        sum to one across the ``top_k`` axis.
        """
        router_w = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
        router_b = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
        selected = torch.topk(F.linear(hidden_states, router_w, router_b), self.top_k, dim=-1)
        # Normalise inside the chosen slice rather than over all experts.
        weights = F.softmax(selected.values, dim=-1, dtype=selected.values.dtype)
        return weights, selected.indices


class GptOssExperts(nn.Module):
    """Bank of expert MLPs evaluated only for the tokens routed to them."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.alpha = config.moe_alpha
        self.limit = config.moe_limit

        shared = dict(num_experts=self.num_experts, bias=True, device=device, dtype=dtype)
        # The fused projection emits gate and up values interleaved on the last axis.
        self.gate_up_proj = ops.MoEExperts(
            in_features=self.hidden_size, out_features=2 * self.intermediate_size, **shared
        )
        self.down_proj = ops.MoEExperts(
            in_features=self.intermediate_size, out_features=self.hidden_size, **shared
        )

    def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
        """Clamped gated activation: ``(up + 1) * gate * sigmoid(alpha * gate)``.

        Even positions of the last axis are gate values (clamped from above at
        ``limit``); odd positions are up values (clamped to ``[-limit, limit]``).
        """
        gate_part = gate_up[..., ::2].clamp(max=self.limit)
        up_part = gate_up[..., 1::2].clamp(min=-self.limit, max=self.limit)
        activated = gate_part * torch.sigmoid(gate_part * self.alpha)
        # activated + up_part * activated
        return activated.addcmul(up_part, activated)

    def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
        """Mix the selected experts' outputs for every token.

        ``hidden_states`` is ``(tokens, hidden)``; ``router_indices`` and
        ``routing_weights`` are ``(tokens, top_k)``. Returns ``(tokens, hidden)``.
        """
        num_tokens, width = hidden_states.shape[0], hidden_states.shape[-1]
        k = router_indices.shape[-1]

        # One output row per (token, top-k slot) pair, summed over slots at the end.
        pair_out = hidden_states.new_zeros((num_tokens * k, width))

        # assignment[e, slot, token] is 1 where that token's slot picked expert e.
        assignment = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
        active_experts = (assignment.sum(dim=(-1, -2)) > 0).nonzero()

        # NOTE(review): bank_resident presumably keeps the expert weights
        # available for the duration of the loop — confirm against ops.MoEExperts.
        with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
             self.down_proj.bank_resident(hidden_states) as down_bank:
            for expert_idx in active_experts.flatten().tolist():
                slot_ids, token_ids = torch.where(assignment[expert_idx])
                expert_in = hidden_states[token_ids]

                hidden = self._apply_gate(gate_up_bank.expert_linear(expert_in, expert_idx))
                expert_out = down_bank.expert_linear(hidden, expert_idx)

                scaled = expert_out * routing_weights[token_ids, slot_ids, None]

                # Each (token, slot) pair belongs to exactly one expert.
                rows = token_ids * k + slot_ids
                pair_out[rows] = scaled.to(pair_out.dtype)

        return pair_out.view(num_tokens, k, width).sum(dim=1)


class GptOssMLP(nn.Module):
    """Mixture-of-experts feed-forward block: router followed by the experts."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
        self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token of a ``(batch, seq, hidden)`` input through its experts."""
        batch, seq_len, width = hidden_states.shape
        # Routing works per token, so batch and sequence are flattened together.
        tokens = hidden_states.reshape(-1, width)
        routing_weights, expert_ids = self.router(tokens)
        mixed = self.experts(tokens, expert_ids, routing_weights)
        return mixed.reshape(batch, seq_len, width)


# Decoder layer + model

class GptOssDecoderLayer(nn.Module):
    """Pre-norm decoder layer: attention and MoE MLP, each with a residual add."""

    def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
        super().__init__()
        norm_kwargs = dict(eps=config.rms_norm_eps, device=device, dtype=dtype)
        self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
        self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        # Selects which entry of the mask dictionary this layer uses.
        self.layer_type = config.layer_types[layer_idx]

    def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
        """Apply the layer; ``attention_masks`` is keyed by layer type."""
        layer_mask = attention_masks[self.layer_type]
        attn_out = self.self_attn(self.input_layernorm(x), layer_mask, freqs_cis)
        x = x + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return x + mlp_out


def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
    """Build an additive causal attention mask of shape ``(B, 1, S, S)``.

    Allowed positions (key index <= query index, key not padded) hold ``0``;
    every other position holds ``torch.finfo(dtype).min``.

    Args:
        B: batch size.
        S: sequence length.
        key_padding_mask: optional ``(B, S)`` mask with ``1`` for real tokens
            and ``0`` for padding; ``None`` means no padding.
        dtype: floating dtype of the returned mask.
        device: device the mask is created on.
    """
    neg = torch.finfo(dtype).min
    # Strict upper triangle (future keys) is blocked, the rest stays zero.
    mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
    mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
    if key_padding_mask is not None:
        kp = key_padding_mask.to(dtype=dtype)
        kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
        # A cell that is both "future" and "padding" would sum to
        # finfo.min + finfo.min, which overflows to -inf. Clamp so the mask
        # stays finite and a fully blocked row cannot turn softmax into NaN.
        mask = torch.clamp(mask + kp, min=neg)
    return mask


def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
    """Build an additive sliding-window causal mask of shape ``(B, 1, S, S)``.

    Query ``i`` may attend key ``j`` only when ``i - window < j <= i`` and the
    key is not padding. Allowed positions hold ``0``; every other position
    holds ``torch.finfo(dtype).min``.

    Args:
        B: batch size.
        S: sequence length.
        window: number of most recent keys (including the query's own
            position) each query can see.
        key_padding_mask: optional ``(B, S)`` mask with ``1`` for real tokens
            and ``0`` for padding; ``None`` means no padding.
        dtype: floating dtype of the returned mask.
        device: device the mask is created on.
    """
    neg = torch.finfo(dtype).min
    i = torch.arange(S, device=device).view(-1, 1)
    j = torch.arange(S, device=device).view(1, -1)
    keep = (j <= i) & (j > i - window)
    mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
    mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
    if key_padding_mask is not None:
        kp = key_padding_mask.to(dtype=dtype)
        kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
        # A cell blocked by both the window and the padding would sum to
        # finfo.min + finfo.min, which overflows to -inf. Clamp so the mask
        # stays finite and a fully blocked row cannot turn softmax into NaN.
        mask = torch.clamp(mask + kp, min=neg)
    return mask


class GptOssModel(nn.Module):
    """GPT-OSS decoder with multi-layer hidden-state capture + early exit.

    When ``capture_layers`` is given to :meth:`forward`, the outputs of those
    decoder layers are collected and the loop stops after the highest one, so
    later layers and the final norm are never run.
    """

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            [
                GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)

        # Always build on CPU so the buffer survives meta-device construction.
        inv_freq, attn_scaling = _yarn_inv_freq(
            head_dim=config.head_dim,
            base=config.rope_theta,
            factor=config.rope_factor,
            beta_fast=config.rope_beta_fast,
            beta_slow=config.rope_beta_slow,
            original_max_position_embeddings=config.original_max_position_embeddings,
            truncate=config.rope_truncate,
            device=torch.device("cpu"),
        )
        # Non-persistent: recomputed from the config, never read from a checkpoint.
        self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
        self.rope_attention_scaling = float(attn_scaling)

    @property
    def num_layers(self) -> int:
        """Number of decoder layers declared in the config."""
        return self.config.num_hidden_layers

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
    ) -> dict[str, torch.Tensor]:
        """Build one additive mask per attention kind used by the layers.

        The sliding-window mask is only built when at least one layer is of
        type ``"sliding_attention"``.
        """
        full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
        masks = {"full_attention": full}
        if any(t == "sliding_attention" for t in self.config.layer_types):
            masks["sliding_attention"] = _make_sliding_causal_mask(
                B, S, self.config.sliding_window, attention_mask, dtype, device
            )
        return masks

    def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
                capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
        """Run the decoder.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: optional ``(batch, seq)`` padding mask, ``1`` for
                real tokens and ``0`` for padding.
            capture_layers: optional 0-based layer indices whose outputs are
                wanted. Duplicates are allowed; each requested position
                receives the output of its layer.

        Returns:
            ``{"hidden_states": [...]}`` with one tensor per entry of
            ``capture_layers`` (in the requested order, taken before the final
            norm) when ``capture_layers`` is non-empty, otherwise
            ``{"last_hidden_state": ...}`` with the normed output of the last
            layer.

        Raises:
            ValueError: if any requested layer index is outside
                ``[0, num_hidden_layers)``.
        """
        B, S = input_ids.shape
        device = input_ids.device
        dtype = self.dtype
        num_layers = self.config.num_hidden_layers

        # Validate the capture request first so a bad index fails fast, before
        # any layer is run, instead of leaving ``None`` holes in the result.
        requested = list(capture_layers) if capture_layers else None
        if requested:
            out_of_range = [idx for idx in requested if not 0 <= idx < num_layers]
            if out_of_range:
                raise ValueError(
                    f"capture_layers {out_of_range} out of range for a model with {num_layers} layers"
                )
            # A layer may be requested more than once, so map each layer to
            # every output position that wants it.
            slots = {}
            for pos, idx in enumerate(requested):
                slots.setdefault(idx, []).append(pos)
            max_layer = max(slots)
            captured: Optional[List[Optional[torch.Tensor]]] = [None] * len(requested)
        else:
            slots = None
            captured = None
            max_layer = num_layers - 1

        hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)

        position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
        freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)

        attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)

        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, attn_masks, freqs_cis)
            if slots is not None:
                for pos in slots.get(i, ()):
                    captured[pos] = hidden_states
            # Early exit: nothing after the highest requested layer is needed.
            if i >= max_layer:
                break

        if captured is not None:
            return {"hidden_states": captured}
        return {"last_hidden_state": self.norm(hidden_states)}


# Lens chat-template constants (verbatim from the reference pipeline).
# Instruction text placed in the "developer" message of the rendered prompt.
_LENS_CHAT_SYSTEM = (
    "Describe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background."
)
# Fixed body of the assistant "analysis" message that precedes the final channel.
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
# Default number of leading token positions dropped from the encoder output
# (default ``txt_offset`` of LensGptOssClipModel).
LENS_TXT_OFFSET = 97
# Default 0-based decoder layers whose hidden states are captured and stacked.
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
# Maximum number of token ids the tokenizer keeps; longer prompts are truncated.
LENS_MAX_TOKENS = 512


# The reference GPT-OSS Harmony template injects today's date here; this module
# uses a fixed date string instead.
_LENS_CHAT_DATE = "2026-05-23"


def _lens_render_chat(prompt: str) -> str:
    """Render the Lens prompt in GPT-OSS Harmony format.

    The result is the concatenation of a system turn, a developer turn with
    the Lens instructions, the user turn carrying ``prompt``, a fixed
    assistant analysis turn, and the opening header of the assistant's final
    channel (left open, with no closing ``<|end|>``).
    """
    system_turn = (
        "<|start|>system<|message|>"
        "You are ChatGPT, a large language model trained by OpenAI.\n"
        "Knowledge cutoff: 2024-06\n"
        f"Current date: {_LENS_CHAT_DATE}\n\n"
        "Reasoning: medium\n\n"
        "# Valid channels: analysis, commentary, final. "
        "Channel must be included for every message.<|end|>"
    )
    developer_turn = f"<|start|>developer<|message|># Instructions\n\n{_LENS_CHAT_SYSTEM}\n\n<|end|>"
    user_turn = f"<|start|>user<|message|>{prompt}<|end|>"
    analysis_turn = (
        f"<|start|>assistant<|channel|>analysis<|message|>{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
    )
    final_header = "<|start|>assistant<|channel|>final<|message|>"
    return "".join((system_turn, developer_turn, user_turn, analysis_turn, final_header))


# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
# Used as the padding id when token sequences of different lengths are batched.
_LENS_PAD_TOKEN_ID = 199999  # <|endoftext|>


class _GptOssRawTokenizer:
    """Raw ``tokenizers.Tokenizer`` wrapper.

    The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
    (``tokenizer_json`` key) rather than as a committed file. ``sd.py``
    extracts it and passes it here via ``tokenizer_data``.
    """

    def __init__(self, tokenizer_json_bytes=None, **kwargs):
        """Build the tokenizer from UTF-8 JSON given as bytes or a byte tensor.

        Raises:
            ValueError: if no tokenizer JSON was supplied.
        """
        from tokenizers import Tokenizer
        if tokenizer_json_bytes is None:
            raise ValueError(
                "Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
                "encoder state dict. Re-bundle the encoder via bundle_te.py so it "
                "embeds the tokenizer."
            )
        payload = tokenizer_json_bytes
        if isinstance(payload, torch.Tensor):
            # The tensor holds one byte value per element.
            payload = bytes(payload.tolist())
        self.tokenizer = Tokenizer.from_str(payload.decode("utf-8"))

    @classmethod
    def from_pretrained(cls, tokenizer_data, **kwargs):
        """Alternate constructor taking the tokenizer JSON as first argument."""
        return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)

    def __call__(self, text):
        """Encode ``text`` without special tokens; returns ``{"input_ids": [...]}``."""
        encoding = self.tokenizer.encode(text, add_special_tokens=False)
        return {"input_ids": encoding.ids}

    def get_vocab(self):
        """Return the wrapped tokenizer's vocabulary mapping."""
        return self.tokenizer.get_vocab()

    def convert_tokens_to_ids(self, tokens):
        """Map each token string to its id using the wrapped tokenizer."""
        lookup = self.tokenizer.token_to_id
        return [lookup(token) for token in tokens]

    def decode(self, ids, **kwargs):
        """Decode ids to text; special tokens are kept unless asked otherwise."""
        skip_special = kwargs.get("skip_special_tokens", False)
        return self.tokenizer.decode(ids, skip_special_tokens=skip_special)


class LensGptOssTokenizer(sd1_clip.SDTokenizer):
    """SDTokenizer that renders prompts with the Lens chat template.

    Prompt weighting is disabled: every produced token carries weight ``1.0``.
    """

    tokenizer_json_data = None

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        raw_json = tokenizer_data.get("tokenizer_json", None)
        # Kept so state_dict() can hand the tokenizer JSON back out.
        self.tokenizer_json_data = raw_json
        base_options = dict(
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=2880,
            embedding_key="gpt_oss",
            tokenizer_class=_GptOssRawTokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_left=False,
            disable_weights=True,
            tokenizer_data=tokenizer_data,
        )
        super().__init__(raw_json, **base_options)
        self.pad_token_id = _LENS_PAD_TOKEN_ID

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Tokenize ``text`` into a single batch of ``(token_id, 1.0)`` pairs.

        An empty or whitespace-only prompt yields ``[[]]``, which
        ``encode_token_weights`` turns into zero-length features (uncond).
        Otherwise the prompt is wrapped in the chat template and at most
        ``LENS_MAX_TOKENS`` ids are kept.
        """
        if not (text and text.strip()):
            return [[]]
        token_ids = self.tokenizer(_lens_render_chat(text))["input_ids"]
        return [[(int(tok), 1.0) for tok in token_ids[:LENS_MAX_TOKENS]]]

    def state_dict(self):
        """Return the tokenizer JSON under ``"tokenizer_json"`` when available."""
        data = self.tokenizer_json_data
        return {} if data is None else {"tokenizer_json": data}


class LensTokenizer(sd1_clip.SD1Tokenizer):
    """SD1Tokenizer wrapper exposing the Lens tokenizer under the name ``gpt_oss``."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        wrapper_options = {
            "embedding_directory": embedding_directory,
            "tokenizer_data": tokenizer_data,
            "name": "gpt_oss",
            "tokenizer": LensGptOssTokenizer,
        }
        super().__init__(**wrapper_options)


class LensGptOssClipModel(nn.Module):
    """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""

    def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
        super().__init__()
        opts = dict(model_options or {})

        # Caller-supplied operations win; otherwise build mixed-precision ops
        # from the (possibly empty) quantization metadata.
        ops_cls = opts.get("custom_operations")
        if ops_cls is None:
            ops_cls = comfy.ops.mixed_precision_ops(
                opts.get("quantization_metadata") or {}, dtype, full_precision_mm=True
            )
        self.operations = ops_cls

        self.config = GptOss20BConfig(**opts.get("gpt_oss_config", {}))
        self.selected_layers = tuple(opts.get("selected_layers", LENS_SELECTED_LAYERS))
        self.txt_offset = int(opts.get("txt_offset", LENS_TXT_OFFSET))

        self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=ops_cls)
        self.num_layers = self.config.num_hidden_layers
        self.dtype = dtype
        self.execution_device = None
        self._pad_token_id = _LENS_PAD_TOKEN_ID

    def set_clip_options(self, options):
        """Update the execution device if ``options`` provides one."""
        self.execution_device = options.get("execution_device", self.execution_device)

    def reset_clip_options(self):
        """Forget any previously set execution device."""
        self.execution_device = None

    def _gather_tokens(self, token_weight_pairs):
        """Right-pad the token ids into a batch and build the matching mask.

        Returns ``(ids, mask)``, both ``(batch, max_len)`` long tensors; the
        mask is ``1`` on real tokens and ``0`` on padding.
        """
        rows = [[int(pair[0]) for pair in seq] for seq in token_weight_pairs]
        width = max(len(row) for row in rows)
        target = self.execution_device
        batch_shape = (len(rows), width)
        ids = torch.full(batch_shape, self._pad_token_id, dtype=torch.long, device=target)
        mask = torch.zeros(batch_shape, dtype=torch.long, device=target)
        for row_idx, row in enumerate(rows):
            n = len(row)
            ids[row_idx, :n] = torch.tensor(row, dtype=torch.long, device=target)
            mask[row_idx, :n] = 1
        return ids, mask

    def encode_token_weights(self, token_weight_pairs):
        """Encode token batches into stacked multi-layer features.

        Returns ``(features, None, extra)`` where ``features`` is
        ``(batch, seq, layers * hidden)`` with the first ``txt_offset`` token
        positions removed, and ``extra`` carries the trimmed
        ``"attention_mask"`` and ``"num_layers_stacked"``.
        """
        # Empty negative: emit zero-length features + zero mask
        if not any(len(seq) for seq in token_weight_pairs):
            target = self.execution_device
            batch = len(token_weight_pairs)
            n_layers = len(self.selected_layers)
            feature_width = n_layers * self.config.hidden_size
            empty_features = torch.zeros(batch, 0, feature_width, dtype=self.dtype, device=target)
            empty_mask = torch.zeros(batch, 0, dtype=torch.long, device=target)
            return empty_features, None, {"attention_mask": empty_mask, "num_layers_stacked": n_layers}

        input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
        result = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
        # list of L x [B, S, H] -> [B, S, L, H]
        stacked = torch.stack(result["hidden_states"], dim=2)

        # Drop the leading template positions; nothing is left when the
        # sequence is not longer than the offset.
        offset = self.txt_offset
        if stacked.shape[1] <= offset:
            stacked, kept_mask = stacked[:, :0], attn_mask[:, :0]
        else:
            stacked, kept_mask = stacked[:, offset:].contiguous(), attn_mask[:, offset:]

        batch, seq_len, n_layers, width = stacked.shape
        features = stacked.reshape(batch, seq_len, n_layers * width)
        return features, None, {"attention_mask": kept_mask, "num_layers_stacked": n_layers}

    def load_sd(self, sd):
        """Load ``sd`` into the transformer non-strictly, assigning tensors in place."""
        return self.transformer.load_state_dict(sd, strict=False, assign=True)


class LensTEModel(sd1_clip.SD1ClipModel):
    """SD1ClipModel wrapper that hosts the Lens GPT-OSS encoder as ``gpt_oss``."""

    def __init__(self, device="cpu", dtype=None, model_options=None):
        super().__init__(
            device=device,
            dtype=dtype,
            name="gpt_oss",
            clip_model=LensGptOssClipModel,
            model_options=model_options or {},
        )


def lens_te(dtype_llama=None, llama_quantization_metadata=None):
    """Create a ``LensTEModel`` subclass with a default dtype / quantization baked in.

    ``llama_quantization_metadata``, when given, overrides the
    ``"quantization_metadata"`` model option. ``dtype_llama`` is used only
    when the constructor is called without an explicit dtype.
    """

    class LensTEModel_(LensTEModel):
        def __init__(self, device="cpu", dtype=None, model_options=None):
            # Copy so the caller's options dictionary is never mutated.
            options = dict(model_options or {})
            if llama_quantization_metadata is not None:
                options["quantization_metadata"] = llama_quantization_metadata
            chosen_dtype = dtype if dtype is not None else dtype_llama
            super().__init__(device=device, dtype=chosen_dtype, model_options=options)

    return LensTEModel_
