class GTConv2d(nn.Module):
    """Graph Transformer convolution with optional edge features and gating.

    Every destination node attends over its incoming edges
    (``edge_index[0] -> edge_index[1]``) with multi-head scaled dot-product
    attention. The attention-weighted values are aggregated per destination
    node with every aggregator in ``aggregators`` (outputs concatenated),
    projected back to ``node_in_dim`` and passed through residual + norm and
    FFN + residual + norm sub-layers. When ``edge_in_dim`` is given, projected
    edge features modulate the attention scores, and the modulated scores are
    turned into updated edge features through a parallel residual/norm/FFN path.

    Args:
        node_in_dim: Number of input (and output) node feature columns.
        hidden_dim: Total attention width; must be divisible by ``num_heads``.
        edge_in_dim: Number of input (and output) edge feature columns, or
            ``None`` to build the layer without edge features.
        num_heads: Number of attention heads.
        gate: If ``True``, value vectors are gated by a sigmoid of a learned
            projection of the node features and, when edge features are used,
            the attention scores are gated by a projection of the edge features.
        qkv_bias: Whether the Q/K/V projections have a bias term.
        dropout: Dropout probability, applied to the aggregated node messages
            and forwarded to the project's ``MLP`` blocks.
        norm: ``"bn"`` selects ``nn.BatchNorm1d``; any other value selects
            ``nn.LayerNorm``.
        act: Activation name forwarded to the project's ``MLP`` blocks.
        aggregators: Neighbourhood aggregators whose outputs are concatenated.
            Must contain ``"sum"``; supported names are ``"sum"``, ``"mean"``,
            ``"min"`` and ``"max"``. Defaults to ``["sum"]``.
    """

    _SUPPORTED_AGGREGATORS = ("sum", "mean", "min", "max")

    def __init__(
        self,
        node_in_dim: int,
        hidden_dim: int,
        edge_in_dim: Optional[int] = None,
        num_heads: int = 8,
        gate=False,
        qkv_bias=False,
        dropout: float = 0.0,
        norm: str = "bn",
        act: str = "relu",
        aggregators: Optional[List[str]] = None,
    ):
        super().__init__()
        # A None default avoids one list object being shared by every
        # instance; copying protects us from later mutation by the caller.
        aggregators = ["sum"] if aggregators is None else list(aggregators)
        assert "sum" in aggregators, "aggregators must include 'sum'"
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        unsupported = [a for a in aggregators if a not in self._SUPPORTED_AGGREGATORS]
        if unsupported:
            raise ValueError(
                f"Unsupported aggregators {unsupported}; "
                f"expected names from {self._SUPPORTED_AGGREGATORS}"
            )

        # Basic configuration.
        self.aggregators = aggregators
        self.num_aggrs = len(aggregators)
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.edge_in_dim = edge_in_dim
        self.gate = gate
        norm_class = nn.BatchNorm1d if norm == "bn" else nn.LayerNorm

        # NOTE: submodules are created in the same order as before so that
        # parameter ordering (optimizer state) and seeded initialisation match.

        # Node projections: Q/K/V into hidden_dim; WO maps the concatenated
        # aggregator outputs back to node_in_dim for the residual connection.
        self.WQ = nn.Linear(node_in_dim, hidden_dim, bias=qkv_bias)
        self.WK = nn.Linear(node_in_dim, hidden_dim, bias=qkv_bias)
        self.WV = nn.Linear(node_in_dim, hidden_dim, bias=qkv_bias)
        self.WO = nn.Linear(hidden_dim * self.num_aggrs, node_in_dim, bias=True)

        # Edge branch: only built when edge features are used.
        if edge_in_dim is not None:
            self.WE = nn.Linear(edge_in_dim, hidden_dim, bias=True)
            self.WOe = nn.Linear(hidden_dim, edge_in_dim, bias=True)
            self.ffn_e = MLP(
                edge_in_dim, edge_in_dim, hidden_dim, 1, dropout, act
            )
            self.norm1e = norm_class(edge_in_dim)
            self.norm2e = norm_class(edge_in_dim)
            self.e_gate = (
                nn.Linear(edge_in_dim, hidden_dim, bias=True) if gate else None
            )
        else:
            self.WE = self.WOe = self.e_gate = None

        # Node normalisation layers.
        self.norm1 = norm_class(node_in_dim)
        self.norm2 = norm_class(node_in_dim)

        # Optional node gate applied to the value vectors.
        self.n_gate = nn.Linear(node_in_dim, hidden_dim, bias=True) if gate else None

        # Remaining components.
        self.dropout = nn.Dropout(dropout)
        self.ffn = MLP(
            node_in_dim, node_in_dim, hidden_dim, 1, dropout, act
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the attention and edge projection weights.

        Only the weight matrices of WQ/WK/WV/WO (and WE/WOe when edge features
        are enabled) are re-initialised; biases, gates, norms and FFNs keep
        their current values.
        """
        nn.init.xavier_uniform_(self.WQ.weight)
        nn.init.xavier_uniform_(self.WK.weight)
        nn.init.xavier_uniform_(self.WV.weight)
        nn.init.xavier_uniform_(self.WO.weight)
        if self.edge_in_dim is not None:
            nn.init.xavier_uniform_(self.WE.weight)
            nn.init.xavier_uniform_(self.WOe.weight)

    @staticmethod
    def _segment_softmax(logits, index, num_segments):
        """Softmax of ``logits`` over the rows that share the same ``index``.

        Args:
            logits: Tensor whose dim 0 enumerates rows (here: edges).
            index: 1-D integer tensor with one segment id per row
                (here: the destination node of each edge).
            num_segments: Number of segments (here: number of nodes).

        Returns:
            Tensor shaped like ``logits``. Within each segment the values along
            dim 0 sum to 1, independently for every trailing position.
        """
        # Broadcast each row's segment id across the trailing dimensions.
        expanded = index.reshape(-1, *([1] * (logits.dim() - 1))).expand_as(logits)
        detached = logits.detach()
        seg_max = detached.new_full(
            (num_segments,) + tuple(logits.shape[1:]), float("-inf")
        )
        seg_max = seg_max.scatter_reduce(0, expanded, detached, reduce="amax")
        # Subtracting the per-segment max keeps exp() from overflowing; it is
        # detached because the shift cancels out of the softmax mathematically.
        numer = torch.exp(logits - seg_max[index])
        # Each populated segment contains its max (exp(0) == 1), so denom >= 1.
        denom = numer.new_zeros(seg_max.shape).index_add(0, index, numer)
        return numer / denom[index]

    def forward(self, x, edge_index, edge_attr=None):
        """Apply one Graph Transformer layer.

        Args:
            x: Node features with ``node_in_dim`` columns; ``x.size(0)`` is
                taken as the number of nodes.
            edge_index: Tensor unpacked into ``(source, destination)`` node
                index rows; messages flow from source to destination.
            edge_attr: Edge features with ``edge_in_dim`` columns, one row per
                edge. Required when the layer was built with ``edge_in_dim``.

        Returns:
            ``(node_out, edge_out)``: updated node features with
            ``node_in_dim`` columns, and updated edge features with
            ``edge_in_dim`` columns (``None`` when edge features are disabled).

        Raises:
            ValueError: if the layer uses edge features but ``edge_attr`` is None.
        """
        if self.edge_in_dim is not None and edge_attr is None:
            raise ValueError(
                "edge_attr is required when the layer is built with edge_in_dim"
            )

        num_nodes = x.size(0)
        head_shape = (-1, self.num_heads, self.head_dim)

        # Keep the inputs for the residual connections.
        x_res = x
        e_res = edge_attr

        # 1. Per-head Q/K/V: (num_nodes, num_heads, head_dim).
        Q = self.WQ(x).view(*head_shape)
        K = self.WK(x).view(*head_shape)
        V = self.WV(x).view(*head_shape)

        # 2. Gather per-edge operands: query from the destination node,
        #    key/value from the source node.
        src, dst = edge_index
        Q_i = Q[dst]
        K_j = K[src]
        V_j = V[src]

        # 3. Per-dimension scaled dot-product terms: (num_edges, num_heads, head_dim).
        #    They are kept un-summed so edge features can modulate each dimension.
        score = Q_i * K_j / math.sqrt(self.head_dim)

        # 4. Edge-feature fusion (and optional edge gate) on the per-dimension score.
        edge_update = None
        if self.edge_in_dim is not None:
            E = self.WE(edge_attr).view(*head_shape)
            score = score * E
            # The fused (pre-gate) score becomes the raw edge update.
            edge_update = score
            if self.gate:
                e_gate = self.e_gate(edge_attr).view(*head_shape)
                score = score * torch.sigmoid(e_gate)

        # 5. One logit per edge and head, normalised over the incoming edges of
        #    each destination node (not over heads).
        attn_logits = score.sum(dim=-1)
        alpha = self._segment_softmax(attn_logits, dst, num_nodes)

        # 6. Optionally gate the values, then weight them by attention.
        if self.gate:
            G = self.n_gate(x).view(*head_shape)
            # NOTE(review): the gate is indexed by the destination node while
            # the value comes from the source node — confirm this is intended.
            V_j = V_j * torch.sigmoid(G[dst])
        messages = alpha.unsqueeze(-1) * V_j

        # 7. Aggregate per destination node with every configured aggregator.
        #    dim_size keeps one row per node even when the highest-numbered
        #    nodes have no incoming edges, so the residual below lines up.
        out = torch.cat(
            [
                scatter(messages, dst, dim=0, dim_size=num_nodes, reduce=aggr)
                for aggr in self.aggregators
            ],
            dim=-1,
        )

        # 8. Node update: output projection + residual + norm.
        out = out.view(-1, self.num_heads * self.head_dim * self.num_aggrs)
        out = self.dropout(out)
        out = self.WO(out) + x_res
        out = self.norm1(out)

        # 9. Node FFN + residual + norm.
        out = self.ffn(out) + out
        out = self.norm2(out)

        # 10. Edge update: merge heads back to hidden_dim columns per edge.
        if self.edge_in_dim is None:
            return out, None
        edge_out = self.WOe(
            edge_update.reshape(-1, self.num_heads * self.head_dim)
        ) + e_res
        edge_out = self.norm1e(edge_out)
        edge_out = self.ffn_e(edge_out) + edge_out
        edge_out = self.norm2e(edge_out)
        return out, edge_out