feat: 暂存代码,还有很多路

This commit is contained in:
wangyu 2024-10-12 00:32:05 +08:00
parent 4a2cc01993
commit a1cd8a2ba2
3 changed files with 266 additions and 213 deletions

View File

@ -0,0 +1,111 @@
package dev.flyfish.boot.cas.filter;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.ConcurrentReferenceHashMap;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.session.WebSessionManager;
import org.springframework.web.server.session.WebSessionStore;
import reactor.core.publisher.Mono;
import java.util.Map;
/**
* cas 上下文
*/
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
class CASContext {
@Getter
private String ticket;
@Setter
@Getter
private WebSession session;
private final ServerWebExchange exchange;
private final WebFilterChain chain;
private final Map<String, Mono<String>> parameters = new ConcurrentReferenceHashMap<>();
static Mono<CASContext> create(ServerWebExchange exchange, WebFilterChain chain) {
return new CASContext(exchange, chain).init();
}
private Mono<CASContext> init() {
Mono<String> ticketMono = getParameter("ticket")
.filter(StringUtils::hasText)
.doOnNext(ticket -> this.ticket = ticket);
// 此处必须保证session不为空
Mono<WebSession> sessionMono = exchange.getSession()
.doOnNext(session -> this.session = session);
return Mono.zipDelayError(ticketMono, sessionMono)
.thenReturn(this);
}
Mono<Void> filter() {
return chain.filter(exchange);
}
/**
* 获取参数
*
* @param key
* @return 异步结果
*/
Mono<String> getParameter(String key) {
return parameters.computeIfAbsent(key, this::computeParameter);
}
ServerHttpRequest getRequest() {
return exchange.getRequest();
}
ServerHttpResponse getResponse() {
return exchange.getResponse();
}
String getPath() {
return exchange.getRequest().getPath().value();
}
HttpMethod getMethod() {
return exchange.getRequest().getMethod();
}
String getQuery(String key) {
ServerHttpRequest request = exchange.getRequest();
return request.getQueryParams().getFirst(key);
}
Mono<String> getFormData(String key) {
return exchange.getFormData()
.mapNotNull(formData -> formData.getFirst(key));
}
private Mono<String> computeParameter(String key) {
return this.readParameter(key).cache();
}
private Mono<String> readParameter(String key) {
String query = getQuery(key);
if (StringUtils.hasText(query)) {
return Mono.just(query);
}
MediaType mediaType = exchange.getRequest().getHeaders().getContentType();
if (null != mediaType && mediaType.isCompatibleWith(MediaType.APPLICATION_FORM_URLENCODED)) {
return exchange.getFormData().mapNotNull(formData -> formData.getFirst(key));
}
return Mono.empty();
}
}

View File

@ -2,15 +2,12 @@ package dev.flyfish.boot.cas.filter;
import edu.yale.its.tp.cas.client.*; import edu.yale.its.tp.cas.client.*;
import edu.yale.its.tp.cas.util.XmlUtils; import edu.yale.its.tp.cas.util.XmlUtils;
import org.apache.commons.logging.Log; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.NonNull; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
@ -22,8 +19,8 @@ import java.io.IOException;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.net.URLEncoder; import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.List; import java.util.List;
import java.util.Objects;
/** /**
* cas filter的webflux实现 * cas filter的webflux实现
@ -31,10 +28,9 @@ import java.util.Objects;
* @author wangyu * @author wangyu
* 实现相关核心逻辑完成鉴权信息抽取 * 实现相关核心逻辑完成鉴权信息抽取
*/ */
@Slf4j
public class CASFilter implements WebFilter { public class CASFilter implements WebFilter {
private static final Log log = LogFactory.getLog(CASFilter.class);
public static final String LOGIN_INIT_PARAM = "edu.yale.its.tp.cas.client.filter.loginUrl"; public static final String LOGIN_INIT_PARAM = "edu.yale.its.tp.cas.client.filter.loginUrl";
public static final String VALIDATE_INIT_PARAM = "edu.yale.its.tp.cas.client.filter.validateUrl"; public static final String VALIDATE_INIT_PARAM = "edu.yale.its.tp.cas.client.filter.validateUrl";
public static final String SERVICE_INIT_PARAM = "edu.yale.its.tp.cas.client.filter.serviceUrl"; public static final String SERVICE_INIT_PARAM = "edu.yale.its.tp.cas.client.filter.serviceUrl";
@ -54,7 +50,7 @@ public class CASFilter implements WebFilter {
private final CASParameter parameter; private final CASParameter parameter;
private static SessionMappingStorage SESSION_MAPPING_STORAGE = new HashMapBackedSessionMappingStorage(); private static SessionMappingStorage SESSION_MAPPING_STORAGE = new SessionMappingStorage.HashMapBackedSessionStorage();
private List<HttpMessageReader<?>> messageReaders; private List<HttpMessageReader<?>> messageReaders;
@ -74,57 +70,61 @@ public class CASFilter implements WebFilter {
} }
} }
private CASReceipt getAuthenticatedUser(ServerWebExchange exchange, String ticket) throws CASAuthenticationException { private CASReceipt getAuthenticatedUser(ServerHttpRequest request, String ticket) throws CASAuthenticationException {
log.trace("entering getAuthenticatedUser()"); log.trace("entering getAuthenticatedUser()");
ProxyTicketValidator pv = new ProxyTicketValidator(); ProxyTicketValidator pv = new ProxyTicketValidator();
pv.setCasValidateUrl(parameter.casValidate); pv.setCasValidateUrl(parameter.casValidate);
pv.setServiceTicket(ticket); pv.setServiceTicket(ticket);
pv.setService(this.getService(exchange.getRequest())); pv.setService(this.getService(request));
pv.setRenew(parameter.casRenew); pv.setRenew(parameter.casRenew);
if (parameter.casProxyCallbackUrl != null) { if (parameter.casProxyCallbackUrl != null) {
pv.setProxyCallbackUrl(parameter.casProxyCallbackUrl); pv.setProxyCallbackUrl(parameter.casProxyCallbackUrl);
} }
if (log.isDebugEnabled()) { log.debug("about to validate ProxyTicketValidator: [{}]", pv);
log.debug("about to validate ProxyTicketValidator: [" + pv + "]");
}
return CASReceipt.getReceipt(pv); return CASReceipt.getReceipt(pv);
} }
private String getService(ServerHttpRequest request) throws ServletException { private String getService(ServerHttpRequest request) {
log.trace("entering getService()"); log.trace("entering getService()");
if (this.casServerName == null && this.casServiceUrl == null) { if (parameter.casServerName == null && parameter.casServiceUrl == null) {
throw new ServletException("need one of the following configuration parameters: edu.yale.its.tp.cas.client.filter.serviceUrl or edu.yale.its.tp.cas.client.filter.serverName"); throw new IllegalArgumentException("need one of the following configuration parameters: edu.yale.its.tp.cas.client.filter.serviceUrl or edu.yale.its.tp.cas.client.filter.serverName");
} else { } else {
String serviceString; String serviceString;
if (this.casServiceUrl != null) { if (parameter.casServiceUrl != null) {
serviceString = URLEncoder.encode(this.casServiceUrl); serviceString = URLEncoder.encode(parameter.casServiceUrl, StandardCharsets.UTF_8);
} else { } else {
serviceString = Util.getService(request, this.casServerName); serviceString = Util.getService(request, parameter.casServerName);
} }
if (log.isTraceEnabled()) { if (log.isTraceEnabled()) {
log.trace("returning from getService() with service [" + serviceString + "]"); log.trace("returning from getService() with service [{}]", serviceString);
} }
return serviceString; return serviceString;
} }
} }
private void redirectToCAS(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { /**
if (log.isTraceEnabled()) { * 核心跳转cas服务器鉴权
log.trace("entering redirectToCAS()"); *
} * @param context 上下文
* @return 结果
* @throws IOException 异常
*/
private Mono<Void> redirectToCAS(CASContext context) {
ServerHttpRequest request = context.getRequest();
ServerHttpResponse response = context.getResponse();
String sessionId = context.getSession().getId();
log.trace("entering redirectToCAS()");
String casLoginString = this.casLogin + "?service=" + this.getService(request) + (this.casRenew ? "&renew=true" : "") + (this.casGateway ? "&gateway=true" : ""); String casLoginString = parameter.casLogin + "?service=" + this.getService(request) +
(parameter.casRenew ? "&renew=true" : "") + (parameter.casGateway ? "&gateway=true" : "");
String sCookie; String sCookie;
if (request.getAttribute("sessionId") != null) { sCookie = parameter.casServerName + request.getPath().contextPath().value();
sCookie = this.casServerName + request.getContextPath(); casLoginString = casLoginString + "&appId=" + sCookie + "&sessionId=" + sessionId;
casLoginString = casLoginString + "&appId=" + sCookie + "&sessionId=" + request.getAttribute("sessionId");
}
sCookie = request.getHeader("Cookie"); sCookie = request.getHeaders().getFirst("Cookie");
String cookie = null; String cookie = null;
if (sCookie != null) { if (sCookie != null) {
String[] sCookies = sCookie.split(";"); String[] sCookies = sCookie.split(";");
@ -139,7 +139,7 @@ public class CASFilter implements WebFilter {
if (cookie != null && !cookie.equals("null") && !cookie.equals(request.getSession().getId())) { if (cookie != null && !cookie.equals("null") && !cookie.equals(request.getSession().getId())) {
casLoginString = casLoginString + "&timeOut=" + cookie; casLoginString = casLoginString + "&timeOut=" + cookie;
if (log.isDebugEnabled()) { if (log.isDebugEnabled()) {
log.debug("Session is timeout. The timeout session is " + cookie); log.debug("Session is timeout. The timeout session is {}", cookie);
} }
} }
@ -155,9 +155,7 @@ public class CASFilter implements WebFilter {
} }
private void redirectToInitFailure(HttpServletRequest request, HttpServletResponse response, String cause) throws IOException, ServletException { private void redirectToInitFailure(HttpServletRequest request, HttpServletResponse response, String cause) throws IOException, ServletException {
if (log.isTraceEnabled()) { log.trace("entering redirectToInitFailure()");
log.trace("entering redirectToInitFailure()");
}
String casLoginString = this.casLogin + "?action=initFailure"; String casLoginString = this.casLogin + "?action=initFailure";
if (cause != null && cause.equals("Illegal user")) { if (cause != null && cause.equals("Illegal user")) {
@ -169,40 +167,29 @@ public class CASFilter implements WebFilter {
casLoginString = casLoginString + "&locale=" + locale; casLoginString = casLoginString + "&locale=" + locale;
} }
if (log.isDebugEnabled()) { log.debug("Redirecting browser to [" + casLoginString + ")");
log.debug("Redirecting browser to [" + casLoginString + ")");
}
response.sendRedirect(casLoginString); response.sendRedirect(casLoginString);
if (log.isTraceEnabled()) { log.trace("returning from redirectToInitFailure()");
log.trace("returning from redirectToInitFailure()");
}
} }
public static SessionMappingStorage getSessionMappingStorage() { public static SessionMappingStorage getSessionMappingStorage() {
return SESSION_MAPPING_STORAGE; return SESSION_MAPPING_STORAGE;
} }
private boolean isExclusion(ServerHttpRequest request) { private boolean isExclusion(String url) {
if (parameter.exclusions == null) { if (parameter.exclusions == null) {
return false; return false;
} else { } else {
String url = request.getPath().value();
return parameter.exclusions.contains(url); return parameter.exclusions.contains(url);
} }
} }
private Mono<Void> translate(ServerWebExchange exchange, WebFilterChain chain, WebSession session) { private Mono<Void> translate(CASContext context) {
ServerHttpRequest request = exchange.getRequest(); if (parameter.casProxyCallbackUrl != null && parameter.casProxyCallbackUrl.endsWith(context.getPath())
MultiValueMap<String, String> params = request.getQueryParams(); && context.getQuery("pgtId") != null && context.getQuery("pgtIou") != null) {
String userName;
String artifact;
if (parameter.casProxyCallbackUrl != null && parameter.casProxyCallbackUrl.endsWith(request.getPath().value())
&& params.getFirst("pgtId") != null && params.getFirst("pgtIou") != null) {
log.trace("passing through what we hope is CAS's request for proxy ticket receptor."); log.trace("passing through what we hope is CAS's request for proxy ticket receptor.");
return chain.filter(exchange); return context.filter();
} else { } else {
if (parameter.wrapRequest) { if (parameter.wrapRequest) {
log.trace("Wrapping request with CASFilterRequestWrapper."); log.trace("Wrapping request with CASFilterRequestWrapper.");
@ -210,122 +197,104 @@ public class CASFilter implements WebFilter {
// request = new CASFilterRequestWrapper((HttpServletRequest) request); // request = new CASFilterRequestWrapper((HttpServletRequest) request);
} }
WebSession session = context.getSession();
// 使用了用户标记快速跳过 // 使用了用户标记快速跳过
if (parameter.userLoginMark != null && session.getAttribute(parameter.userLoginMark) != null) { if (parameter.userLoginMark != null && session.getAttribute(parameter.userLoginMark) != null) {
return chain.filter(exchange); return context.filter();
} }
// 获取receipt // 获取receipt
CASReceipt receipt = session.getAttribute(CAS_FILTER_RECEIPT); CASReceipt receipt = session.getAttribute(CAS_FILTER_RECEIPT);
if (receipt != null && this.isReceiptAcceptable(receipt)) { if (receipt != null && this.isReceiptAcceptable(receipt)) {
log.trace("CAS_FILTER_RECEIPT attribute was present and acceptable - passing request through filter.."); log.trace("CAS_FILTER_RECEIPT attribute was present and acceptable - passing request through filter..");
return chain.filter(exchange); return context.filter();
} }
// 跳过请求 // 跳过请求
if (this.isExclusion(request)) { if (this.isExclusion(context.getPath())) {
return chain.filter(exchange); return context.filter();
} }
return getParameter(exchange, "ticket") // 判断票据
.flatMap(ticket -> { String ticket = context.getTicket();
if (StringUtils.hasText(ticket)) { if (StringUtils.hasText(ticket)) {
try { try {
receipt = this.getAuthenticatedUser(exchange, ticket); receipt = this.getAuthenticatedUser(context.getRequest(), ticket);
} catch (CASAuthenticationException var22) { } catch (CASAuthenticationException var22) {
((HttpServletRequest) request).setAttribute("sessionId", session.getId()); return this.redirectToCAS(context);
this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); }
return;
} if (!this.isReceiptAcceptable(receipt)) {
throw new IllegalStateException("Authentication was technically successful but rejected as a matter of policy. [" + receipt + "]");
if (!this.isReceiptAcceptable(receipt)) { } else {
throw new ServletException("Authentication was technically successful but rejected as a matter of policy. [" + receipt + "]"); String pt = context.getQuery("pt");
} else { if (StringUtils.hasText(pt)) {
if (pt != null && pt != "") { session.getAttributes().put(pt, receipt);
session.setAttribute(pt, receipt); }
}
String userName = receipt.getUserName();
if (session != null) { if (StringUtils.hasText(parameter.casInitContextClass)) {
userName = receipt.getUserName(); try {
if (this.casInitContextClass != null && !"".equals(this.casInitContextClass)) { Class<?> cls = Class.forName(parameter.casInitContextClass);
try { Object obj = cls.getConstructor().newInstance();
Class cls = Class.forName(this.casInitContextClass); if (obj instanceof IContextInit) {
Object obj = cls.newInstance(); Method translatorMethod = cls.getMethod("getTranslatorUser", String.class);
if (obj instanceof IContextInit) { userName = (String) translatorMethod.invoke(obj, userName);
Method translatorMethod = cls.getMethod("getTranslatorUser", String.class); Method initContextMethod = cls.getMethod("initContext", ServletRequest.class, ServletResponse.class, FilterChain.class, String.class);
userName = (String) translatorMethod.invoke(obj, userName); initContextMethod.invoke(obj, request, response, fc, userName);
Method initContextMethod = cls.getMethod("initContext", ServletRequest.class, ServletResponse.class, FilterChain.class, String.class);
initContextMethod.invoke(obj, request, response, fc, userName);
}
} catch (ClassNotFoundException var15) {
ClassNotFoundException e = var15;
e.printStackTrace();
} catch (InstantiationException var16) {
InstantiationException e = var16;
e.printStackTrace();
} catch (IllegalAccessException var17) {
IllegalAccessException e = var17;
e.printStackTrace();
} catch (IllegalArgumentException var18) {
IllegalArgumentException e = var18;
e.printStackTrace();
} catch (InvocationTargetException var19) {
InvocationTargetException e = var19;
String cause = e.getCause().getMessage();
session.setAttribute("initFailure", cause);
this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause);
e.printStackTrace();
return;
} catch (SecurityException var20) {
SecurityException e = var20;
e.printStackTrace();
} catch (NoSuchMethodException var21) {
NoSuchMethodException e = var21;
e.printStackTrace();
}
}
session.setAttribute("edu.yale.its.tp.cas.client.filter.user", userName);
session.setAttribute("edu.yale.its.tp.cas.client.filter.receipt", receipt);
session.removeAttribute("edu.yale.its.tp.cas.client.filter.didGateway");
}
if (log.isTraceEnabled()) {
log.trace("validated ticket to get authenticated receipt [" + receipt + "], now passing request along filter chain.");
}
fc.doFilter((ServletRequest) request, response);
log.trace("returning from doFilter()");
}
} else {
log.trace("CAS ticket was not present on request.");
boolean didGateway = Boolean.valueOf((String) session.getAttribute("edu.yale.its.tp.cas.client.filter.didGateway"));
if (this.casLogin == null) {
log.fatal("casLogin was not set, so filter cannot redirect request for authentication.");
throw new ServletException("When CASFilter protects pages that do not receive a 'ticket' parameter, it needs a edu.yale.its.tp.cas.client.filter.loginUrl filter parameter");
} else if (!didGateway) {
log.trace("Did not previously gateway. Setting session attribute to true.");
((HttpServletRequest) request).setAttribute("sessionId", session.getId());
session.setAttribute("edu.yale.its.tp.cas.client.filter.didGateway", "true");
this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response);
} else {
log.trace("Previously gatewayed.");
if (!this.casGateway && session.getAttribute("edu.yale.its.tp.cas.client.filter.user") == null) {
if (session.getAttribute("initFailure") != null) {
String cause = (String) session.getAttribute("initFailure");
this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause);
} else {
((HttpServletRequest) request).setAttribute("sessionId", session.getId());
session.setAttribute("edu.yale.its.tp.cas.client.filter.didGateway", "true");
this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response);
}
} else {
log.trace("casGateway was true and CAS_FILTER_USER set: passing request along filter chain.");
fc.doFilter((ServletRequest) request, response);
}
} }
} catch (ClassNotFoundException | IllegalArgumentException | IllegalAccessException |
InstantiationException | SecurityException | NoSuchMethodException e) {
e.printStackTrace();
} catch (InvocationTargetException var19) {
InvocationTargetException e = var19;
String cause = e.getCause().getMessage();
session.setAttribute("initFailure", cause);
this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause);
e.printStackTrace();
return Mono.empty();
} }
}) }
session.setAttribute("edu.yale.its.tp.cas.client.filter.user", userName);
session.setAttribute("edu.yale.its.tp.cas.client.filter.receipt", receipt);
session.removeAttribute(CAS_FILTER_GATEWAYED);
log.trace("validated ticket to get authenticated receipt [{}], now passing request along filter chain.", receipt);
log.trace("returning from doFilter()");
return context.filter();
}
return Mono.empty();
} else {
log.trace("CAS ticket was not present on request.");
boolean didGateway = Boolean.valueOf((String) session.getAttribute(CAS_FILTER_GATEWAYED));
if (parameter.casLogin == null) {
log.error("casLogin was not set, so filter cannot redirect request for authentication.");
throw new IllegalArgumentException("When CASFilter protects pages that do not receive a 'ticket' parameter, it needs a edu.yale.its.tp.cas.client.filter.loginUrl filter parameter");
} else if (!didGateway) {
log.trace("Did not previously gateway. Setting session attribute to true.");
((HttpServletRequest) request).setAttribute("sessionId", session.getId());
session.setAttribute(CAS_FILTER_GATEWAYED, "true");
this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response);
} else {
log.trace("Previously gatewayed.");
if (!this.casGateway && session.getAttribute("edu.yale.its.tp.cas.client.filter.user") == null) {
if (session.getAttribute("initFailure") != null) {
String cause = (String) session.getAttribute("initFailure");
this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause);
} else {
((HttpServletRequest) request).setAttribute("sessionId", session.getId());
session.setAttribute(CAS_FILTER_GATEWAYED, "true");
this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response);
}
} else {
log.trace("casGateway was true and CAS_FILTER_USER set: passing request along filter chain.");
fc.doFilter((ServletRequest) request, response);
}
}
}
} }
} }
@ -333,94 +302,67 @@ public class CASFilter implements WebFilter {
/** /**
* 二阶段处理预处理特殊情况提前中断请求 * 二阶段处理预处理特殊情况提前中断请求
* *
* @param exchange 交换信息 * @param context 上下文工具
* @param chain 过滤器链
* @param session 会话
* @return 结果 * @return 结果
*/ */
private Mono<Void> handle(ServerWebExchange exchange, WebFilterChain chain, @NonNull WebSession session) { private Mono<Void> handle(CASContext context) {
// 下一步处理信号提前生成
Mono<Void> translate = this.translate(exchange, chain, session);
// post请求需要特殊处理 // post请求需要特殊处理
if (exchange.getRequest().getMethod() == HttpMethod.POST) { if (context.getMethod() == HttpMethod.POST) {
// 此处可能要求安全的获取参数单独针对退出请求 // 此处可能要求安全的获取参数单独针对退出请求
return exchange.getFormData() return context.getFormData("logoutRequest")
.flatMap(formData -> { .doOnNext(payload -> log.trace("Logout request=[{}]", payload))
String payload = formData.getFirst("logoutRequest"); .defaultIfEmpty("")
.flatMap(payload -> {
if (StringUtils.hasText(payload)) { if (StringUtils.hasText(payload)) {
if (log.isTraceEnabled()) {
log.trace("Logout request=[" + payload + "]");
}
String sessionIdentifier = XmlUtils.getTextForElement(payload, "SessionIndex"); String sessionIdentifier = XmlUtils.getTextForElement(payload, "SessionIndex");
if (StringUtils.hasText(sessionIdentifier)) { if (StringUtils.hasText(sessionIdentifier)) {
// 命中该请求中断执行 // 满足条件时断路
return SESSION_MAPPING_STORAGE.removeSessionByMappingId(sessionIdentifier) return SESSION_MAPPING_STORAGE.removeSessionByMappingId(sessionIdentifier)
.filter(Objects::nonNull) .doOnNext(session -> log.debug("Invalidating session [{}] for ST [{}]", session.getId(), sessionIdentifier))
.flatMap(savedSession -> { .flatMap(WebSession::invalidate)
String sessionId = savedSession.getId(); .doOnError(IllegalStateException.class, e -> log.debug(e.getMessage(), e))
if (log.isDebugEnabled()) { .onErrorComplete();
log.debug("Invalidating session [" + sessionId + "] for ST [" + sessionIdentifier + "]");
}
try {
return savedSession.invalidate();
} catch (IllegalStateException e) {
log.debug(e, e);
}
// 中断处理
return Mono.empty();
});
} }
} }
return translate; // 继续执行
return translate(context);
}); });
} else { } else {
String ticket = exchange.getRequest().getQueryParams().getFirst("ticket"); String ticket = context.getTicket();
if (log.isDebugEnabled()) { String sessionId = context.getSession().getId();
log.debug("Storing session identifier for " + session.getId()); log.debug("Storing session identifier for {}", sessionId);
}
// 包括ticket尝试重新替换session // 包括ticket尝试重新替换session
if (StringUtils.hasText(ticket)) { if (StringUtils.hasText(ticket)) {
return SESSION_MAPPING_STORAGE.removeBySessionById(session.getId()) return SESSION_MAPPING_STORAGE.removeBySessionById(sessionId)
.onErrorResume(e -> SESSION_MAPPING_STORAGE.addSessionById(ticket, session)) .onErrorResume(e -> SESSION_MAPPING_STORAGE.addSessionById(ticket, context.getSession()))
.then(translate); .then(Mono.defer(() -> translate(context)));
} }
return translate; return translate(context);
} }
} }
private Mono<String> getParameter(ServerWebExchange exchange, String key) {
ServerHttpRequest request = exchange.getRequest();
String query = request.getQueryParams().getFirst(key);
if (StringUtils.hasText(query)) {
return Mono.just(query);
}
MediaType mediaType = request.getHeaders().getContentType();
if (null != mediaType && mediaType.isCompatibleWith(MediaType.APPLICATION_FORM_URLENCODED)) {
return exchange.getFormData().mapNotNull(formData -> formData.getFirst(key));
}
return Mono.empty();
}
@Override @Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
// 拦截器需要基于session判定故提前使用 // 拦截器需要基于session判定故提前使用
return exchange.getSession() return CASContext.create(exchange, chain)
.flatMap(session -> { .flatMap(context -> {
WebSession session = context.getSession();
if (log.isTraceEnabled()) { if (log.isTraceEnabled()) {
log.trace("entering doFilter()"); log.trace("entering doFilter()");
} }
// 执行中断策略 // 执行中断策略
String pt = exchange.getRequest().getQueryParams().getFirst("pt"); String pt = context.getQuery("pt");
if (StringUtils.hasText(pt)) { if (StringUtils.hasText(pt)) {
if (session.getAttribute(pt) != null) { if (session.getAttribute(pt) != null) {
return chain.filter(exchange); return context.filter();
} }
} }
return handle(exchange, chain, session); return handle(context);
}); });
} }
} }

View File

@ -30,9 +30,9 @@ public interface SessionMappingStorage {
public Mono<WebSession> removeSessionByMappingId(String mappingId) { public Mono<WebSession> removeSessionByMappingId(String mappingId) {
WebSession session = this.MANAGED_SESSIONS.get(mappingId); WebSession session = this.MANAGED_SESSIONS.get(mappingId);
if (session != null) { if (session != null) {
this.removeBySessionById(session.getId()); return this.removeBySessionById(session.getId()).thenReturn(session);
} }
return Mono.just(session); return Mono.empty();
} }
@Override @Override