import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device


class Dino2AttentionOutput(torch.nn.Module):
    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)

    def forward(self, x):
        return self.dense(x)


class Dino2AttentionBlock(torch.nn.Module):
    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
        self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        return self.output(self.attention(x, mask, optimized_attention))


class LayerScale(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def forward(self, x):
        return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)

class Dinov2MLP(torch.nn.Module):
    def __init__(self, hidden_size: int, dtype, device, operations):
        super().__init__()

        mlp_ratio = 4
        hidden_features = int(hidden_size * mlp_ratio)
        self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
        self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = torch.nn.functional.gelu(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state

class SwiGLUFFN(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        in_features = out_features = dim
        hidden_features = int(dim * 4)
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8

        self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
        self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        x = self.weights_in(x)
        x1, x2 = x.chunk(2, dim=-1)
        x = torch.nn.functional.silu(x1) * x2
        return self.weights_out(x)


class Dino2Block(torch.nn.Module):
    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
        super().__init__()
        self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
        self.layer_scale1 = LayerScale(dim, dtype, device, operations)
        self.layer_scale2 = LayerScale(dim, dtype, device, operations)
        if use_swiglu_ffn:
            self.mlp = SwiGLUFFN(dim, dtype, device, operations)
        else:
            self.mlp = Dinov2MLP(dim, dtype, device, operations)
        self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, x, optimized_attention):
        x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
        x = x + self.layer_scale2(self.mlp(self.norm2(x)))
        return x


class Dino2Encoder(torch.nn.Module):
    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
        super().__init__()
        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
                                          for _ in range(num_layers)])

    def forward(self, x, intermediate_output=None):
        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)

        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for i, layer in enumerate(self.layer):
            x = layer(x, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class Dino2PatchEmbeddings(torch.nn.Module):
    def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
        super().__init__()
        self.patch_size = patch_size
        self.projection = operations.Conv2d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
            dtype=dtype,
            device=device
        )

    def forward(self, pixel_values):
        return self.projection(pixel_values).flatten(2).transpose(1, 2)


class Dino2Embeddings(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        patch_size = 14
        image_size = 518
        self.patch_size = patch_size

        self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
        self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
        self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
        self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))

    def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
        pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)

        class_pos = pos_embed[:, 0:1]
        patch_pos = pos_embed[:, 1:]
        N = patch_pos.shape[1]
        M = int(N ** 0.5)
        h0 = h_pixels // self.patch_size
        w0 = w_pixels // self.patch_size
        scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)  # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).

        patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
        patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
        patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
        return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)

    def forward(self, pixel_values):
        x = self.patch_embeddings(pixel_values)
        x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
        if x.shape[1] - 1 == self.position_embeddings.shape[1] - 1:
            x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
        else:
            h, w = pixel_values.shape[-2:]
            x = x + self.interpolate_pos_encoding(x, h, w)
        return x


class Dinov2Model(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        num_layers = config_dict["num_hidden_layers"]
        dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        layer_norm_eps = config_dict["layer_norm_eps"]
        use_swiglu_ffn = config_dict["use_swiglu_ffn"]

        self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
        self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
        x = self.embeddings(pixel_values)
        x, i = self.encoder(x, intermediate_output=intermediate_output)
        x = self.layernorm(x)
        pooled_output = x[:, 0, :]
        return x, i, pooled_output, None

    def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
        x = self.embeddings(pixel_values)
        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
        n_layers = len(self.encoder.layer)
        resolved = [(i if i >= 0 else n_layers + i) for i in indices]
        target = set(resolved)
        max_idx = max(resolved)
        n_skip = 1  # skip cls token
        cache = {}
        for i, layer in enumerate(self.encoder.layer):
            x = layer(x, optimized_attention)
            if i in target:
                normed = self.layernorm(x) if apply_norm else x
                cache[i] = (normed[:, n_skip:], normed[:, 0])
            if i >= max_idx:
                break
        return [cache[i] for i in resolved]
