Commit 2f6fd7c7 authored by p x's avatar p x
Browse files

根据前端的type 传参推送不同的数据

parent 3d2834bf
...@@ -39,7 +39,7 @@ dependencies { ...@@ -39,7 +39,7 @@ dependencies {
// https://mvnrepository.com/artifact/com.alibaba.fastjson2/fastjson2 // https://mvnrepository.com/artifact/com.alibaba.fastjson2/fastjson2
implementation("com.alibaba.fastjson2:fastjson2:2.0.60") implementation("com.alibaba.fastjson2:fastjson2:2.0.60")
// https://mvnrepository.com/artifact/io.netty/netty-all // https://mvnrepository.com/artifact/io.netty/netty-all
implementation("io.netty:netty-all:4.2.9.Final") implementation("io.netty:netty-all:4.1.92.Final")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.1") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.1")
} }
......
This diff is collapsed.
package com.inzy.wsmock
import io.netty.channel.Channel
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame
import io.netty.util.AttributeKey
import java.util.concurrent.ConcurrentHashMap
/**
* Channel 全局管理类(单例)
* 统一管理在线客户端Channel,提供增删查、分组等操作
*/
class ChannelManager private constructor() {
// 存储在线客户端Channel(线程安全)
private val onlineChannels = ConcurrentHashMap<String, Channel>()
/**
* 添加Channel(客户端连接时调用)
*/
fun addChannel(channel: Channel) {
val clientId = channel.id().asShortText()
onlineChannels[clientId] = channel
println("Channel添加成功:$clientId,当前在线数:${onlineChannels.size}")
}
/**
* 移除Channel(客户端断开/异常时调用)
*/
fun removeChannel(channel: Channel) {
val clientId = channel.id().asShortText()
onlineChannels.remove(clientId)
println("Channel移除成功:$clientId,当前在线数:${onlineChannels.size}")
}
/**
* 根据客户端ID获取Channel
*/
fun getChannel(clientId: String): Channel? {
return onlineChannels[clientId]
}
/**
* 获取所有在线Channel
*/
fun getAllChannels(): ConcurrentHashMap<String, Channel> {
return onlineChannels // 返回原对象(ConcurrentHashMap线程安全),或返回副本:ConcurrentHashMap(onlineChannels)
}
/**
* 根据Channel的attr属性(CLIENT_TYPE_KEY)分组
*/
fun groupChannelByType(typeKey: AttributeKey<String>): Map<String, List<Channel>> {
val groupMap = mutableMapOf<String, MutableList<Channel>>()
onlineChannels.forEach { (_, channel) ->
// 获取Channel的type属性,默认值为"default"
val type = channel.attr(typeKey).get() ?: "default"
if (!groupMap.containsKey(type)) {
groupMap[type] = mutableListOf()
}
// 仅保留活跃的Channel
if (channel.isActive) {
groupMap[type]?.add(channel)
}
}
return groupMap
}
/***发送过滤好的通道**/
fun sendMsgFromType(typeChannels: Map<String, Channel>,msg: String) {
if (typeChannels.isEmpty()) {
// println("无在线客户端,跳过推送")
return
}
val frame = TextWebSocketFrame(msg)
typeChannels.forEach { (clientId, channel) ->
if (channel.isActive) {
channel.writeAndFlush(frame)
.addListener { future ->
if (!future.isSuccess) {
println("推送消息给客户端[$clientId]失败:${future.cause()?.message}")
removeChannel(channel) // 推送失败移除失效Channel
}
}
} else {
removeChannel(channel) // 移除非活跃Channel
}
}
}
/**
* 广播消息给所有在线客户端
*/
/* fun broadcastMsg(msg: String) {
if (onlineChannels.isEmpty()) {
println("无在线客户端,跳过推送")
return
}
val frame = TextWebSocketFrame(msg)
onlineChannels.forEach { (clientId, channel) ->
if (channel.isActive) {
channel.writeAndFlush(frame)
.addListener { future ->
if (!future.isSuccess) {
println("推送消息给客户端[$clientId]失败:${future.cause()?.message}")
removeChannel(channel) // 推送失败移除失效Channel
}
}
} else {
removeChannel(channel) // 移除非活跃Channel
}
}
println("已向${onlineChannels.size}个客户端推送消息:$msg")
}*/
/**
* 单例模式(饿汉式)
*/
companion object {
val instance: ChannelManager = ChannelManager()
}
}
\ No newline at end of file
...@@ -16,17 +16,25 @@ import jakarta.annotation.PreDestroy ...@@ -16,17 +16,25 @@ import jakarta.annotation.PreDestroy
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import lombok.extern.slf4j.Slf4j
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Value import org.springframework.beans.factory.annotation.Value
import org.springframework.stereotype.Component import org.springframework.stereotype.Component
import java.net.BindException
import java.net.InetSocketAddress
import java.net.ServerSocket
@Slf4j
@Component @Component
class NettyWebSocketServer( class NettyWebSocketServer(
@Value("\${netty.websocket.port:8089}") private val port: Int, @Value("\${netty.websocket.port:8089}") private val port: Int
private val webSocketHandler: WebSocketHandler // private val webSocketHandler: WebSocketHandler
) { ) {
private val logger = LoggerFactory.getLogger(javaClass) private val logger = LoggerFactory.getLogger(javaClass)
private val websocketPath="/gs-guide-websocket"
// Netty主从线程组 // Netty主从线程组
private val bossGroup = NioEventLoopGroup(1) private val bossGroup = NioEventLoopGroup(1)
private val workerGroup = NioEventLoopGroup() private val workerGroup = NioEventLoopGroup()
...@@ -53,10 +61,28 @@ class NettyWebSocketServer( ...@@ -53,10 +61,28 @@ class NettyWebSocketServer(
pipeline.addLast(HttpServerCodec()) pipeline.addLast(HttpServerCodec())
pipeline.addLast(ChunkedWriteHandler()) pipeline.addLast(ChunkedWriteHandler())
pipeline.addLast(HttpObjectAggregator(1024 * 1024)) pipeline.addLast(HttpObjectAggregator(1024 * 1024))
pipeline.addLast(WebSocketServerProtocolHandler("/gs-guide-websocket")) pipeline.addLast(RequestParamHandler()) // 添加HTTP请求拦截器
pipeline.addLast(webSocketHandler) pipeline.addLast(WebSocketServerProtocolHandler(websocketPath))
pipeline.addLast(WebSocketHandler())
// 调试:打印处理器链顺序,确认RequestParamHandler在正确位置
// logger.info("ChannelPipeline顺序:${pipeline.map { it.javaClass.simpleName }}")
} }
}) })
// 检测端口是否被占用
val isPortAvailable = InetSocketAddress(port).let {
ServerSocket().use { socket ->
try {
socket.bind(it)
true
} catch (e: BindException) {
false
}
}
}
if (!isPortAvailable) {
logger.error("端口$port 已被占用,启动失败!")
return@launch
}
// 绑定端口并启动(非阻塞?不,这里sync()只阻塞当前新线程,不阻塞Spring主线程) // 绑定端口并启动(非阻塞?不,这里sync()只阻塞当前新线程,不阻塞Spring主线程)
serverChannelFuture = bootstrap.bind(port).sync() serverChannelFuture = bootstrap.bind(port).sync()
logger.info("Netty WebSocket服务启动成功,端口:$port") logger.info("Netty WebSocket服务启动成功,端口:$port")
......
package com.inzy.wsmock
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.SimpleChannelInboundHandler
import io.netty.handler.codec.http.FullHttpRequest
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame
import io.netty.util.AttributeKey
class RequestParamHandler : SimpleChannelInboundHandler<TextWebSocketFrame?>() {
// 定义存储查询参数的AttributeKey(全局复用)
companion object {
val REQUEST_PARAMS_KEY = AttributeKey.valueOf<Map<String, String>>("REQUEST_PARAMS")
val REQUEST_PATH_KEY = AttributeKey.valueOf<String>("REQUEST_PATH") // 存储请求路径
val PARAM_TYPE_KEY = "type"
val PARAM_TYPE_VALUE_1 = "1"
val PARAM_TYPE_VALUE_2 = "2"
}
@Throws(Exception::class)
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
if (msg is FullHttpRequest) {
val request = msg
val uri = request.uri()
// 1. 解析请求路径和查询参数
val (path, params) = parseUri(uri)
// 2. 打印调试信息
// println("请求路径: $path")
// println("查询参数: $params")
// 3. 将路径和参数存入Channel属性(供后续处理器使用)
ctx.channel().attr(REQUEST_PATH_KEY).set(path)
ctx.channel().attr(REQUEST_PARAMS_KEY).set(params)
// 4. 可选:去除查询参数后更新请求URI(不影响后续处理器解析路径)
request.uri = path
// 去除查询参数后更新URI(可选)
/* if (uri.contains("?")) {
val newUri = uri.substring(0, uri.indexOf("?"))
request.setUri(newUri)
}*/
}
// 处理WebSocket消息
super.channelRead(ctx, msg)
}
@Throws(Exception::class)
override fun channelRead0(ctx: ChannelHandlerContext?, msg: TextWebSocketFrame?) {
}
// 核心方法:解析URI,返回 Pair(路径, 参数字典)
private fun parseUri(uri: String): Pair<String, Map<String, String>> {
// 分割路径和参数(? 是分隔符)
val splitIndex = uri.indexOf("?")
return if (splitIndex == -1) {
// 无查询参数的情况
Pair(uri, emptyMap())
} else {
// 有查询参数的情况:解析路径 + 解析参数
val path = uri.substring(0, splitIndex)
val paramStr = uri.substring(splitIndex + 1)
val params = parseParams(paramStr)
Pair(path, params)
}
}
// 解析查询参数字符串(a=1&b=2 形式)
private fun parseParams(paramStr: String): Map<String, String> {
val params = mutableMapOf<String, String>()
if (paramStr.isBlank()) return params
// 按 & 分割多个参数
val paramArray = paramStr.split("&").filter { it.isNotBlank() }
for (param in paramArray) {
// 按 = 分割参数名和值
val keyValue = param.split("=", limit = 2) // limit=2 避免值包含=的情况
if (keyValue.size == 2) {
val key = keyValue[0].trim()
val value = keyValue[1].trim()
params[key] = value
}
}
return params
}
// 自定义方法:解析URL参数
/* private fun getUrlParams(url: String): MutableMap<String?, String?> {
var url = url
val map: MutableMap<String?, String?> = HashMap<String?, String?>()
url = url.replace("?", ";")
if (!url.contains(";")) return map
val arr = url.split(";".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()[1].split("&".toRegex())
.dropLastWhile { it.isEmpty() }.toTypedArray()
for (s in arr) {
val pair = s.split("=".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
if (pair.size == 2) {
map.put(pair[0], pair[1])
}
}
return map
}*/
}
package com.inzy.wsmock package com.inzy.wsmock
import com.alibaba.fastjson2.JSONObject import com.alibaba.fastjson2.JSONObject
import com.inzy.wsmock.RequestParamHandler.Companion.PARAM_TYPE_KEY
import com.inzy.wsmock.RequestParamHandler.Companion.PARAM_TYPE_VALUE_1
import com.inzy.wsmock.RequestParamHandler.Companion.PARAM_TYPE_VALUE_2
import com.inzy.wsmock.RequestParamHandler.Companion.REQUEST_PARAMS_KEY
import io.netty.channel.Channel
import lombok.extern.slf4j.Slf4j import lombok.extern.slf4j.Slf4j
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import org.springframework.scheduling.annotation.Scheduled import org.springframework.scheduling.annotation.Scheduled
import org.springframework.stereotype.Component import org.springframework.stereotype.Component
import java.time.LocalDateTime import java.time.LocalDateTime
@Slf4j
@Component @Component
class ScheduledPushTask( class ScheduledPushTask(
private val webSocketHandler: WebSocketHandler, // private val webSocketHandler: WebSocketHandler,
private val pushConfig: PushConfig // private val pushConfig: PushConfig
) { ) {
private val logger = LoggerFactory.getLogger(javaClass) private val logger = LoggerFactory.getLogger(javaClass)
// 注入ChannelManager单例
private val channelManager = ChannelManager.instance
/** /**
* 定时推送任务(固定延迟,避免任务叠加 * 定时推送任务(type=1
*/ */
@Scheduled(fixedDelayString = "#{@pushConfig.pushInterval}") @Scheduled(fixedDelayString = "#{@pushConfig.pushInterval}")
fun pushMsgToType1() {
// val onlineChannels = channelManager.getAllChannels()
//得到设置了type属性的channel
val typeChannels=filterTypeChannels(PARAM_TYPE_VALUE_1)
// println("onlineChannels.size = ${onlineChannels.size} typeChannels.size = ${typeChannels.size}")
val msgObj = JSONObject()
msgObj["content"] = "定时推送消息 type=1 ${LocalDateTime.now()}"
channelManager.sendMsgFromType(typeChannels,msgObj.toJSONString())
}
@Scheduled(fixedDelayString = "200")
fun pushMsgToType2() {
//得到设置了type属性的channel
val typeChannels=filterTypeChannels(PARAM_TYPE_VALUE_2)
// println("onlineChannels.size = ${onlineChannels.size} typeChannels.size = ${typeChannels.size}")
val msgObj = JSONObject()
msgObj["content"] = "定时推送消息 type=2 ${LocalDateTime.now()}"
channelManager.sendMsgFromType(typeChannels,msgObj.toJSONString())
}
/**
* @param type 前端的查询参数
* **/
private fun filterTypeChannels(type: String): Map<String, Channel> {
val onlineChannels = channelManager.getAllChannels()
//得到设置了type属性的channel
val typeChannels =
onlineChannels.filter { (id, channel) ->
if (channel.hasAttr(REQUEST_PARAMS_KEY)) {
val params = channel.attr(REQUEST_PARAMS_KEY).get()
return@filter params.get(PARAM_TYPE_KEY) == type
}
false
}
return typeChannels
}
/**
* 定时推送任务(固定延迟,避免任务叠加)
*/
// @Scheduled(fixedDelayString = "#{@pushConfig.pushInterval}")
fun pushMsgToClients() { fun pushMsgToClients() {
// 增加日志,确认函数是否执行 // 增加日志,确认函数是否执行
// logger.info("定时推送任务开始执行 - ${LocalDateTime.now()}") // logger.info("定时推送任务开始执行 - ${LocalDateTime.now()}")
...@@ -28,11 +80,22 @@ class ScheduledPushTask( ...@@ -28,11 +80,22 @@ class ScheduledPushTask(
// if (!pushConfig.pushEnabled.get()) { // if (!pushConfig.pushEnabled.get()) {
// return // return
// } // }
val onlineChannels = channelManager.getAllChannels()
// if (onlineChannels.isEmpty()) {
//// println("定时推送:无在线客户端,跳过")
// return
// }
//得到设置了type属性的channel
// val typeChannels =
// onlineChannels.filter { (id, channel) -> channel.hasAttr(WebSocketHandler.WS_QUERY_PARAMS_KEY) }
// println("onlineChannels.size = ${onlineChannels.size} typeChannels.size = ${typeChannels.size}")
// 构造推送消息(适配前端格式) // 构造推送消息(适配前端格式)
val msgObj = JSONObject() val msgObj = JSONObject()
msgObj["content"] = "定时推送消息 - ${LocalDateTime.now()}" msgObj["content"] = "定时推送消息 - ${LocalDateTime.now()}"
logger.debug("msgObj = ${msgObj}") logger.debug("msgObj = ${msgObj}")
// 广播给所有客户端 // 广播给所有客户端
webSocketHandler.broadcastMsg(msgObj.toJSONString()) // webSocketHandler.broadcastMsg(msgObj.toJSONString())
} }
} }
\ No newline at end of file
package com.inzy.wsmock package com.inzy.wsmock
import io.netty.channel.Channel
import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelHandlerContext
import io.netty.channel.SimpleChannelInboundHandler import io.netty.channel.SimpleChannelInboundHandler
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame import io.netty.handler.codec.http.websocketx.TextWebSocketFrame
import io.netty.handler.codec.http.websocketx.WebSocketFrame import io.netty.handler.codec.http.websocketx.WebSocketFrame
import org.springframework.stereotype.Component import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler
import java.util.concurrent.ConcurrentHashMap import org.slf4j.LoggerFactory
/** /**
* 自定义WebSocket处理器(支持多客户端 * 自定义WebSocket处理器(专注处理消息交互,Channel管理交给ChannelManager
*/ */
@Sharable @Sharable
@Component //@Component
class WebSocketHandler : SimpleChannelInboundHandler<WebSocketFrame>() { class WebSocketHandler : SimpleChannelInboundHandler<WebSocketFrame>() {
// 存储在线客户端Channel(线程安全)
private val onlineChannels = ConcurrentHashMap<String, Channel>()
// private val logger = LoggerFactory.getLogger(javaClass)
// 注入ChannelManager单例
private val channelManager = ChannelManager.instance
// 监听握手成功事件(此时RequestParamHandler已解析参数)
override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
if (evt is WebSocketServerProtocolHandler.HandshakeComplete) {
val channel = ctx.channel()
// 交给ChannelManager管理
channelManager.addChannel(channel)
// 握手成功后读取参数(此时参数已存储)
// val params = ctx.channel().attr(RequestParamHandler.REQUEST_PARAMS_KEY).get()
// val path = ctx.channel().attr(RequestParamHandler.REQUEST_PATH_KEY).get()
// logger.info("WS握手成功,客户端查询参数:$params 请求路径:$path")
ctx.writeAndFlush(TextWebSocketFrame("Welcome!"))
// ctx.writeAndFlush(TextWebSocketFrame("Welcome! Client ID: $clientId, params: $params"))
}
super.userEventTriggered(ctx, evt)
}
/** /**
* 客户端连接成功 * 客户端连接成功
*/ */
override fun channelActive(ctx: ChannelHandlerContext) { /* override fun channelActive(ctx: ChannelHandlerContext) {
val channel = ctx.channel() val channel = ctx.channel()
val clientId = channel.id().asShortText() val clientId = channel.id().asShortText()
onlineChannels[clientId] = channel // 交给ChannelManager管理
println("客户端连接成功:$clientId,当前在线数:${onlineChannels.size}") channelManager.addChannel(channel)
// val type = ctx.channel().attr(RequestParamHandler.REQUEST_PARAMS_KEY).get()
// println("当前连接的客户端type:$type")
// 欢迎消息 // 欢迎消息
channel.writeAndFlush(TextWebSocketFrame("Welcome! Client ID: $clientId")) channel.writeAndFlush(TextWebSocketFrame("Welcome! Client ID: $clientId"))
} }*/
/** /**
* 客户端断开连接 * 客户端断开连接
*/ */
override fun channelInactive(ctx: ChannelHandlerContext) { override fun channelInactive(ctx: ChannelHandlerContext) {
val channel = ctx.channel() val channel = ctx.channel()
val clientId = channel.id().asShortText() // 交给ChannelManager管理
onlineChannels.remove(clientId) channelManager.removeChannel(channel)
println("客户端断开连接:$clientId,当前在线数:${onlineChannels.size}")
} }
/** /**
* 处理客户端消息 * 处理客户端消息
*/ */
override fun channelRead0(ctx: ChannelHandlerContext, frame: WebSocketFrame) { override fun channelRead0(ctx: ChannelHandlerContext, frame: WebSocketFrame) {
if (frame is TextWebSocketFrame) { /* if (frame is TextWebSocketFrame) {
val msg = frame.text() val msg = frame.text()
val clientId = ctx.channel().id().asShortText() val clientId = ctx.channel().id().asShortText()
println("收到客户端[$clientId]消息:$msg") println("收到客户端[$clientId]消息:$msg")
// 回复客户端 // 回复客户端
ctx.writeAndFlush(TextWebSocketFrame("Server received: $msg")) ctx.writeAndFlush(TextWebSocketFrame("Server received: $msg"))
} }*/
} }
/** /**
* 异常处理 * 异常处理
*/ */
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
val clientId = ctx.channel().id().asShortText() val channel = ctx.channel()
val clientId = channel.id().asShortText()
println("客户端[$clientId]发生异常:${cause.message}") println("客户端[$clientId]发生异常:${cause.message}")
ctx.close() ctx.close()
onlineChannels.remove(clientId) // 交给ChannelManager移除失效Channel
channelManager.removeChannel(channel)
} }
/** /**
* 广播消息给所有在线客户端 * 广播消息(复用ChannelManager的实现)
*/ */
fun broadcastMsg(msg: String) { /* fun broadcastMsg(msg: String) {
if (onlineChannels.isEmpty()) { channelManager.broadcastMsg(msg)
println("无在线客户端,跳过推送") }*/
return
}
val frame = TextWebSocketFrame(msg)
onlineChannels.forEach { (clientId, channel) ->
if (channel.isActive) {
channel.writeAndFlush(frame)
.addListener { future ->
if (!future.isSuccess) {
println("推送消息给客户端[$clientId]失败:${future.cause()?.message}")
}
}
} else {
onlineChannels.remove(clientId)
}
}
println("已向${onlineChannels.size}个客户端推送消息:$msg")
}
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment