Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/支持多集群 #5

Merged
merged 4 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
1.2.16
1.支持多集群客户端,服务端
2.支持统计用户在系统精准的使用时间
  • Loading branch information
wangzihao committed Jun 3, 2024
commit 92a928648f1133836467ac355f09b601e3e01dae
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ sse协议的后端API, 比websocket轻量的实时通信, 支持集群,qos,
<dependency>
<groupId>com.github.wangzihaogithub</groupId>
<artifactId>sse-server</artifactId>
<version>1.2.15</version>
<version>1.2.16</version>
</dependency>

2. 配置业务逻辑 (后端)
Expand Down
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.github.wangzihaogithub</groupId>
<artifactId>sse-server</artifactId>
<version>1.2.15</version>
<version>1.2.16</version>
<name>sse-server</name>
<description>Sse server for Spring Boot</description>
<url>https://github.com/wangzihaogithub/sse-server.git</url>
Expand Down Expand Up @@ -117,7 +117,7 @@
<connection>scm:git:https://github.com/wangzihaogithub/sse-server.git</connection>
<developerConnection>scm:git:[email protected]:wangzihaogithub/sse-server.git</developerConnection>
<url>[email protected]:wangzihaogithub/sse-server.git</url>
<tag>v1.2.15</tag>
<tag>v1.2.16</tag>
</scm>

<!-- 开发者信息 -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public interface LocalConnectionService extends DistributedConnectionService, Co
<ACCESS_USER> List<SseEmitter<ACCESS_USER>> disconnectByAccessToken(String accessToken);

<ACCESS_USER> SseEmitter<ACCESS_USER> disconnectByConnectionId(Long connectionId);
<ACCESS_USER> SseEmitter<ACCESS_USER> disconnectByConnectionId(Long connectionId, Long duration, Long sessionDuration);

<ACCESS_USER> List<SseEmitter<ACCESS_USER>> disconnectByConnectionIds(Collection<Long> connectionIds);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,26 @@ public <ACCESS_USER> SseEmitter<ACCESS_USER> disconnectByConnectionId(Long conne
}
}

@Override
public <ACCESS_USER> SseEmitter<ACCESS_USER> disconnectByConnectionId(Long connectionId, Long duration, Long sessionDuration) {
SseEmitter<ACCESS_USER> sseEmitter = getConnectionById(connectionId);
if (sseEmitter != null) {
if (duration != null || sessionDuration != null) {
if (duration == null) {
duration = 0L;
}
if (sessionDuration == null) {
sessionDuration = 0L;
}
sseEmitter.setSessionDuration(duration + sessionDuration);
}
if (sseEmitter.disconnect()) {
return sseEmitter;
}
}
return null;
}

@Override
public <ACCESS_USER> List<SseEmitter<ACCESS_USER>> disconnectByConnectionIds(Collection<Long> connectionIds) {
if (connectionIds == null) {
Expand Down Expand Up @@ -806,4 +826,10 @@ public void setServerPort(@Value("${server.port:8080}") Integer serverPort) {
this.serverPort = serverPort;
}

@Override
public String toString() {
return "LocalConnectionServiceImpl{" +
beanName + "[" + connectionMap.size() + "]" +
'}';
}
}
132 changes: 71 additions & 61 deletions src/main/java/com/github/sseserver/local/LocalController.java

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions src/main/java/com/github/sseserver/local/SseEmitter.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class SseEmitter<ACCESS_USER> extends org.springframework.web.servlet.mvc
private String requestIp;
private String requestDomain;
private String userAgent;
private Long sessionDuration;
private Cookie[] httpCookies;
/**
* 前端已正在监听的钩子, 值是 {@link SseEventBuilder#name(String)}
Expand Down Expand Up @@ -221,6 +222,14 @@ public void setUserAgent(String userAgent) {
this.userAgent = userAgent;
}

public void setSessionDuration(Long sessionDuration) {
this.sessionDuration = sessionDuration;
}

public Long getSessionDuration() {
return sessionDuration;
}

public int getCount() {
return count;
}
Expand Down
151 changes: 79 additions & 72 deletions src/main/java/com/github/sseserver/local/SseWebController.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
Expand All @@ -26,9 +27,8 @@
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.Part;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.io.*;
import java.nio.charset.Charset;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -66,21 +66,16 @@ public class SseWebController<ACCESS_USER> {
public static final String API_REPOSITORY_MESSAGES_JSON = "/connect/repositoryMessages.json";
public static final String API_USER_JSON = "/connect/users.json";

/**
* @deprecated v1.2.8
*/
@Deprecated
public static final String API_CONNECTIONS_JSON_V1 = "/connections";
public static final String API_CONNECTIONS_JSON = "/connect/connections.json";

private static final byte[] SSE_APPEND_BYTES = "\nexport default Sse".getBytes(Charset.forName("UTF-8"));
private final Logger logger = LoggerFactory.getLogger(getClass());
@Autowired(required = false)
protected HttpServletRequest request;
protected LocalConnectionService localConnectionService;
private final ClusterBatchDisconnectRunnable batchDisconnectRunnable = new ClusterBatchDisconnectRunnable(() -> localConnectionService != null ? localConnectionService.getCluster() : null);
@Value("${server.port:8080}")
private Integer serverPort;
private String sseServerIdHeaderName = "Sse-Server-Id";
private String sseServerIdHeaderName = "X-Sse-Server-Id";
private Integer clientIdMaxConnections = 3;
private Long keepaliveTime;
private boolean enableGetJson = false;
Expand Down Expand Up @@ -125,25 +120,34 @@ public void setSseServerIdHeaderName(String sseServerIdHeaderName) {
* 前端文件
*/
@RequestMapping("")
public Object index() {
return ssejs();
public Object index(@RequestParam(required = false, name = "script-type", defaultValue = "module") String type) throws IOException {
return ssejs(type);
}

/**
* 前端文件
*/
@RequestMapping("/sse.js")
public Object ssejs() {
public Object ssejs(@RequestParam(required = false, name = "script-type", defaultValue = "module") String type) throws IOException {
HttpHeaders headers = new HttpHeaders();

settingResponseHeader(headers);
headers.set("Content-Type", "application/javascript;charset=utf-8");
Resource body = readSseJs();
Resource body = readSseJs(type);
return new ResponseEntity<>(body, headers, HttpStatus.OK);
}

protected Resource readSseJs() {
protected Resource readSseJs(String type) throws IOException {
InputStream stream = SseWebController.class.getResourceAsStream("/sse.js");
return new InputStreamResource(stream);
if ("module".equalsIgnoreCase(type)) {
int bufferSize = Math.max(stream.available(), 4096);
ByteArrayOutputStream out = new ByteArrayOutputStream(bufferSize + SSE_APPEND_BYTES.length);
copy(stream, out, bufferSize);
out.write(SSE_APPEND_BYTES);
return new ByteArrayResource(out.toByteArray());
} else {
return new InputStreamResource(stream);
}
}

public void setLocalConnectionService(LocalConnectionService localConnectionService) {
Expand All @@ -152,9 +156,20 @@ public void setLocalConnectionService(LocalConnectionService localConnectionServ

@Autowired(required = false)
public void setLocalConnectionServiceMap(Map<String, LocalConnectionService> localConnectionServiceMap) {
if (this.localConnectionService == null && localConnectionServiceMap != null && localConnectionServiceMap.size() > 0) {
this.localConnectionService = localConnectionServiceMap.values().iterator().next();
if (localConnectionServiceMap == null || localConnectionServiceMap.isEmpty()) {
return;
}
this.localConnectionService = choseLocalConnectionService(localConnectionServiceMap);
}

/**
* 选择一个给当前SseWebController用的链接服务
*
* @return 给当前SseWebController用的链接服务 LocalConnectionService
* @since 1.2.16
*/
protected LocalConnectionService choseLocalConnectionService(Map<String, LocalConnectionService> localConnectionServiceMap) {
return localConnectionServiceMap.values().iterator().next();
}

/**
Expand Down Expand Up @@ -229,10 +244,6 @@ protected void onConnect(SseEmitter<ACCESS_USER> conncet, Map<String, Object> qu
disconnectClientIdMaxConnections(conncet, getClientIdMaxConnections());
}

protected void onDisconnect(List<SseEmitter<ACCESS_USER>> disconnectList, ACCESS_USER accessUser, Map query) {

}

protected ResponseEntity buildIfLoginVerifyErrorResponse(ACCESS_USER accessUser,
Map query, Map body,
Long keepaliveTime) {
Expand Down Expand Up @@ -266,7 +277,7 @@ protected Long choseKeepaliveTime(Long clientKeepaliveTime, Long serverKeepalive
*/
@RequestMapping(value = API_CONNECT_STREAM, method = {RequestMethod.GET, RequestMethod.POST})
public Object connect(@RequestParam Map query, @RequestBody(required = false) Map body,
Long keepaliveTime) {
Long keepaliveTime, Long sessionDuration) {
// args
Map<String, Object> attributeMap = new LinkedHashMap<>(query);
if (body != null) {
Expand Down Expand Up @@ -294,6 +305,7 @@ public Object connect(@RequestParam Map query, @RequestBody(required = false) Ma
String channel = Objects.toString(attributeMap.get("channel"), null);
emitter.setChannel(channel == null || channel.isEmpty() ? null : channel);
emitter.setUserAgent(request.getHeader("User-Agent"));
emitter.setSessionDuration(sessionDuration);
emitter.setRequestIp(getRequestIpAddr(request));
emitter.setRequestDomain(getRequestDomain(request));
emitter.setHttpCookies(request.getCookies());
Expand Down Expand Up @@ -433,22 +445,20 @@ public ResponseEntity upload(@PathVariable String path, HttpServletRequest reque
@PostMapping(API_DISCONNECT_DO)
public Object disconnect(Long connectionId, @RequestParam Map query,
Boolean cluster,
@RequestParam(required = false, defaultValue = "5000") Long timeout) {
@RequestParam(required = false, defaultValue = "5000") Long timeout,
Long duration,
Long sessionDuration) {
if (connectionId == null) {
return responseEntity(buildDisconnectResult(0, false));
}
SseEmitter<ACCESS_USER> disconnect = localConnectionService.disconnectByConnectionId(connectionId);
SseEmitter<ACCESS_USER> disconnect = localConnectionService.disconnectByConnectionId(connectionId, duration, sessionDuration);
int localCount = disconnect != null ? 1 : 0;
if (disconnect != null) {
ACCESS_USER currentUser = getAccessUser(API_DISCONNECT_DO);
onDisconnect(Collections.singletonList(disconnect), currentUser, query);
}
if (cluster == null || cluster) {
cluster = localConnectionService.isEnableCluster();
}
if (cluster && localCount == 0) {
DeferredResult<ResponseEntity> result = new DeferredResult<>(timeout, responseEntity(buildDisconnectResult(localCount, true)));
localConnectionService.getCluster().disconnectByConnectionId(connectionId)
localConnectionService.getCluster().disconnectByConnectionId(connectionId, duration, sessionDuration)
.whenComplete((remoteCount, throwable) -> {
if (throwable != null) {
logger.warn("disconnectConnection exception = {}", throwable, throwable);
Expand Down Expand Up @@ -568,21 +578,6 @@ public Object users(@RequestParam(required = false, defaultValue = "1") Integer
return result;
}

/**
* @deprecated v1.2.8
*/
@Deprecated
@GetMapping(API_CONNECTIONS_JSON_V1)
public Object connectionsV1(@RequestParam(required = false, defaultValue = "1") Integer pageNum,
@RequestParam(required = false, defaultValue = "100") Integer pageSize,
String name,
String clientId,
Long id,
Boolean cluster,
@RequestParam(required = false, defaultValue = "5000") Long timeout) {
return connections(pageNum, pageSize, name, clientId, id, cluster, timeout);
}

@GetMapping(API_CONNECTIONS_JSON)
public Object connections(@RequestParam(required = false, defaultValue = "1") Integer pageNum,
@RequestParam(required = false, defaultValue = "100") Integer pageSize,
Expand Down Expand Up @@ -652,11 +647,11 @@ protected ResponseEntity responseEntity(Object responseBody) {
}

protected void settingResponseHeader(HttpHeaders responseHeaders) {
String sseServerIdHeaderName = this.sseServerIdHeaderName;
String sseServerIdHeaderName = getSseServerIdHeaderName();
if (sseServerIdHeaderName != null && sseServerIdHeaderName.length() > 0) {
responseHeaders.set(sseServerIdHeaderName, getSseServerId());
}
responseHeaders.set("Sse-Server-Version", PlatformDependentUtil.SSE_SERVER_VERSION);
responseHeaders.set("X-Sse-Version", PlatformDependentUtil.SSE_SERVER_VERSION);
}

protected String getSseServerId() {
Expand Down Expand Up @@ -755,6 +750,31 @@ protected Object mapToConnectionVO(ConnectionDTO<ACCESS_USER> connectionDTO) {
return connectionDTO;
}

protected String getRequestIpAddr(HttpServletRequest request) {
String ip = request.getHeader("x-forwarded-for");
if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
// 如果是多级代理,那么取第一个ip为客户ip
if (ip != null) {
ip = ip.split(",", 2)[0].trim();
}
return ip;
}

protected String getRequestDomain(HttpServletRequest request) {
StringBuffer url = request.getRequestURL();
StringBuffer sb = url.delete(url.length() - request.getRequestURI().length(), url.length());

if (sb.toString().startsWith("http:https://localhost")) {
String host = request.getHeader("host");
if (host != null && !host.isEmpty()) {
sb = new StringBuffer("http:https://" + host);
}
}
return WebUtil.rewriteHttpToHttpsIfSecure(sb.toString(), request.isSecure());
}

public static class ListenerReq {
private List<String> listener;
private Long connectionId;
Expand Down Expand Up @@ -949,31 +969,6 @@ public void setErrorMessage(String errorMessage) {
}
}

protected String getRequestIpAddr(HttpServletRequest request) {
String ip = request.getHeader("x-forwarded-for");
if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
// 如果是多级代理,那么取第一个ip为客户ip
if (ip != null) {
ip = ip.split(",", 2)[0].trim();
}
return ip;
}

protected String getRequestDomain(HttpServletRequest request) {
StringBuffer url = request.getRequestURL();
StringBuffer sb = url.delete(url.length() - request.getRequestURI().length(), url.length());

if (sb.toString().startsWith("http:https://localhost")) {
String host = request.getHeader("host");
if (host != null && !host.isEmpty()) {
sb = new StringBuffer("http:https://" + host);
}
}
return WebUtil.rewriteHttpToHttpsIfSecure(sb.toString(), request.isSecure());
}

private static class ClusterBatchDisconnectRunnable implements Runnable {
private final Collection<Long> batchDisconnectIdList = Collections.newSetFromMap(new ConcurrentHashMap<>());
private final Supplier<ClusterConnectionService> serviceSupplier;
Expand Down Expand Up @@ -1010,4 +1005,16 @@ public void run() {
}
}

public static long copy(InputStream source, OutputStream sink, int bufferSize)
throws IOException {
long nread = 0L;
byte[] buf = new byte[bufferSize];
int n;
while ((n = source.read(buf)) > 0) {
sink.write(buf, 0, n);
nread += n;
}
return nread;
}

}
Loading