package com.inzy.wsmock


import io.netty.channel.Channel
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.SimpleChannelInboundHandler
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame
import io.netty.handler.codec.http.websocketx.WebSocketFrame
import org.springframework.stereotype.Component
import java.util.concurrent.ConcurrentHashMap

/**
 * Netty WebSocket handler that tracks connected clients (multi-client support) and can
 * broadcast text messages to all of them.
 *
 * A single instance is shared by every channel (it is [Sharable] and a Spring [Component]
 * singleton), which is what allows [broadcastMsg] to reach all connected clients through the
 * shared [onlineChannels] registry.
 *
 * NOTE(review): if this handler is installed after Netty's `WebSocketServerProtocolHandler`,
 * [channelActive] fires on TCP connect, i.e. before the WebSocket handshake completes, so the
 * welcome frame may be written before the connection is upgraded and the client may receive
 * broadcasts too early. Registering on the handshake-complete user event may be more
 * appropriate — confirm against the pipeline setup (not visible here).
 */
@Sharable
@Component
class WebSocketHandler : SimpleChannelInboundHandler<WebSocketFrame>() {
    /**
     * Online client channels, keyed by the globally unique long form of the channel id.
     *
     * `asShortText()` is not guaranteed to be unique, so it is used only for display; keying
     * by it could let two connections overwrite (or deregister) each other. A concurrent map
     * is used because the shared handler instance is accessed by many channels, and
     * [broadcastMsg] may be invoked by arbitrary callers.
     */
    private val onlineChannels = ConcurrentHashMap<String, Channel>()

    /** Registry key for [channel]: the globally unique long text of its id. */
    private fun registryKey(channel: Channel): String = channel.id().asLongText()

    /** Short human-readable id for [channel], used in logs and client messages (may collide). */
    private fun displayId(channel: Channel): String = channel.id().asShortText()

    /**
     * Called when a client connects: registers its channel and sends a welcome frame.
     */
    override fun channelActive(ctx: ChannelHandlerContext) {
        val channel = ctx.channel()
        val clientId = displayId(channel)
        onlineChannels[registryKey(channel)] = channel
        println("客户端连接成功：$clientId，当前在线数：${onlineChannels.size}")
        // Greet the newly connected client with its (short) id.
        channel.writeAndFlush(TextWebSocketFrame("Welcome! Client ID: $clientId"))
    }

    /**
     * Called when a client disconnects: removes its channel from the registry.
     */
    override fun channelInactive(ctx: ChannelHandlerContext) {
        val channel = ctx.channel()
        onlineChannels.remove(registryKey(channel))
        println("客户端断开连接：${displayId(channel)}，当前在线数：${onlineChannels.size}")
    }

    /**
     * Handles an inbound frame. Only text frames are processed: the text is logged and
     * echoed back to the sender prefixed with "Server received: ". Other frame types are
     * ignored.
     */
    override fun channelRead0(ctx: ChannelHandlerContext, frame: WebSocketFrame) {
        if (frame !is TextWebSocketFrame) return
        val msg = frame.text()
        println("收到客户端[${displayId(ctx.channel())}]消息：$msg")
        // Reply to the sending client.
        ctx.writeAndFlush(TextWebSocketFrame("Server received: $msg"))
    }

    /**
     * Logs the exception, closes the offending channel and deregisters it.
     *
     * Removing here is harmless even if [channelInactive] later removes the same key again.
     */
    override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
        val channel = ctx.channel()
        println("客户端[${displayId(channel)}]发生异常：${cause.message}")
        ctx.close()
        onlineChannels.remove(registryKey(channel))
    }

    /**
     * Broadcasts [msg] as a text frame to every registered client whose channel is active.
     * Channels found inactive are pruned from the registry. Write failures are logged
     * asynchronously per client.
     *
     * Frames are reference-counted and each `writeAndFlush` consumes one reference, so every
     * channel receives its own `retainedDuplicate()` of a single frame (sharing the
     * underlying buffer). The original reference is released once all writes are issued.
     *
     * @param msg text to push to all online clients.
     */
    fun broadcastMsg(msg: String) {
        if (onlineChannels.isEmpty()) {
            println("无在线客户端，跳过推送")
            return
        }
        val frame = TextWebSocketFrame(msg)
        var pushed = 0
        try {
            // ConcurrentHashMap iteration is weakly consistent, so removing during the loop is safe.
            for ((key, channel) in onlineChannels) {
                if (!channel.isActive) {
                    onlineChannels.remove(key)
                    continue
                }
                val clientId = displayId(channel)
                channel.writeAndFlush(frame.retainedDuplicate())
                    .addListener { future ->
                        if (!future.isSuccess) {
                            println("推送消息给客户端[$clientId]失败：${future.cause()?.message}")
                        }
                    }
                pushed++
            }
        } finally {
            // Drop our own reference; each channel holds its retained duplicate.
            frame.release()
        }
        println("已向${pushed}个客户端推送消息：$msg")
    }
}