feat: 初步完成实现

This commit is contained in:
wangyu 2024-10-12 15:35:08 +08:00
parent a1cd8a2ba2
commit b978c776b5
11 changed files with 645 additions and 210 deletions

View File

@ -0,0 +1,91 @@
package dev.flyfish.boot.cas.config;
import dev.flyfish.boot.cas.config.session.WebSessionDecorator;
import dev.flyfish.boot.cas.config.session.WebSessionListener;
import dev.flyfish.boot.cas.filter.CASFilter;
import dev.flyfish.boot.cas.filter.CASParameter;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.session.DefaultWebSessionManager;
import org.springframework.web.server.session.InMemoryWebSessionStore;
import org.springframework.web.server.session.WebSessionManager;
import org.springframework.web.server.session.WebSessionStore;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.util.List;
/**
* cas核心配置
*
* @author wangyu
*/
@Configuration
public class CASConfig {
@Bean
@ConfigurationProperties("cas.filter")
public CASParameter casParameter() {
return new CASParameter();
}
@Bean
public CASFilter casFilter(CASParameter casParameter) {
return new CASFilter(casParameter);
}
@Bean
public WebSessionStore webSessionStore(WebSessionManager webSessionManager, ServerProperties serverProperties,
ObjectProvider<WebSessionListener> listeners) {
if (webSessionManager instanceof DefaultWebSessionManager defaultWebSessionManager) {
Duration timeout = serverProperties.getReactive().getSession().getTimeout();
int maxSessions = serverProperties.getReactive().getSession().getMaxSessions();
ListenableWebSessionStore sessionStore = new ListenableWebSessionStore(timeout, listeners);
sessionStore.setMaxSessions(maxSessions);
defaultWebSessionManager.setSessionStore(sessionStore);
return sessionStore;
}
throw new IllegalStateException("Cannot find web session manager");
}
/**
* 处理session销毁保证正确退出
*
* @param casFilter 过滤器
* @return 结果
*/
@Bean
public WebSessionListener singleSignOutSessionListener(CASFilter casFilter) {
return new WebSessionListener() {
@Override
public Mono<Void> onSessionInvalidated(WebSession session) {
return casFilter.getSessionMappingStorage().removeBySessionById(session.getId());
}
};
}
static final class ListenableWebSessionStore extends InMemoryWebSessionStore {
private final Duration timeout;
private final List<WebSessionListener> listeners;
private ListenableWebSessionStore(Duration timeout, ObjectProvider<WebSessionListener> listeners) {
this.timeout = timeout;
this.listeners = listeners.stream().toList();
}
public Mono<WebSession> createWebSession() {
return super.createWebSession()
.map(session -> (WebSession) new WebSessionDecorator(session, listeners))
.doOnSuccess(this::setMaxIdleTime);
}
private void setMaxIdleTime(WebSession session) {
session.setMaxIdleTime(this.timeout);
}
}
}

View File

@ -0,0 +1,194 @@
package dev.flyfish.boot.cas.config.session;
import lombok.RequiredArgsConstructor;
import org.springframework.lang.Nullable;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Map;
@RequiredArgsConstructor
public class WebSessionDecorator implements WebSession {
private final WebSession decorated;
private final List<WebSessionListener> listeners;
/**
* Return a unique session identifier.
*/
@Override
public String getId() {
return decorated.getId();
}
/**
* Return a map that holds session attributes.
*/
@Override
public Map<String, Object> getAttributes() {
return decorated.getAttributes();
}
/**
* Return the session attribute value if present.
*
* @param name the attribute name
* @return the attribute value
*/
@Nullable
@Override
public <T> T getAttribute(String name) {
return decorated.getAttribute(name);
}
/**
* Return the session attribute value or if not present raise an
* {@link IllegalArgumentException}.
*
* @param name the attribute name
* @return the attribute value
*/
@Override
public <T> T getRequiredAttribute(String name) {
return decorated.getRequiredAttribute(name);
}
/**
* Return the session attribute value, or a default, fallback value.
*
* @param name the attribute name
* @param defaultValue a default value to return instead
* @return the attribute value
*/
@Override
public <T> T getAttributeOrDefault(String name, T defaultValue) {
return decorated.getAttributeOrDefault(name, defaultValue);
}
/**
* Force the creation of a session causing the session id to be sent when
* {@link #save()} is called.
*/
@Override
public void start() {
decorated.start();
}
/**
* Whether a session with the client has been started explicitly via
* {@link #start()} or implicitly by adding session attributes.
* If "false" then the session id is not sent to the client and the
* {@link #save()} method is essentially a no-op.
*/
@Override
public boolean isStarted() {
return decorated.isStarted();
}
/**
* Generate a new id for the session and update the underlying session
* storage to reflect the new id. After a successful call {@link #getId()}
* reflects the new session id.
*
* @return completion notification (success or error)
*/
@Override
public Mono<Void> changeSessionId() {
return decorated.changeSessionId();
}
/**
* Invalidate the current session and clear session storage.
*
* @return completion notification (success or error)
*/
@Override
public Mono<Void> invalidate() {
// 后续处理
Mono<Void> consumer = Mono.defer(() -> listeners.stream()
.map(listener -> listener.onSessionInvalidated(this.decorated))
.reduce(Mono::then)
.orElse(Mono.empty()));
return decorated.invalidate().then(consumer);
}
/**
* Save the session through the {@code WebSessionStore} as follows:
* <ul>
* <li>If the session is new (i.e. created but never persisted), it must have
* been started explicitly via {@link #start()} or implicitly by adding
* attributes, or otherwise this method should have no effect.
* <li>If the session was retrieved through the {@code WebSessionStore},
* the implementation for this method must check whether the session was
* {@link #invalidate() invalidated} and if so return an error.
* </ul>
* <p>Note that this method is not intended for direct use by applications.
* Instead it is automatically invoked just before the response is
* committed.
*
* @return {@code Mono} to indicate completion with success or error
*/
@Override
public Mono<Void> save() {
return decorated.save();
}
/**
* Return {@code true} if the session expired after {@link #getMaxIdleTime()
* maxIdleTime} elapsed.
* <p>Typically expiration checks should be automatically made when a session
* is accessed, a new {@code WebSession} instance created if necessary, at
* the start of request processing so that applications don't have to worry
* about expired session by default.
*/
@Override
public boolean isExpired() {
return decorated.isExpired();
}
/**
* Return the time when the session was created.
*/
@Override
public Instant getCreationTime() {
return decorated.getCreationTime();
}
/**
* Return the last time of session access as a result of user activity such
* as an HTTP request. Together with {@link #getMaxIdleTime()
* maxIdleTimeInSeconds} this helps to determine when a session is
* {@link #isExpired() expired}.
*/
@Override
public Instant getLastAccessTime() {
return decorated.getLastAccessTime();
}
/**
* Configure the max amount of time that may elapse after the
* {@link #getLastAccessTime() lastAccessTime} before a session is considered
* expired. A negative value indicates the session should not expire.
*
* @param maxIdleTime
*/
@Override
public void setMaxIdleTime(Duration maxIdleTime) {
decorated.setMaxIdleTime(maxIdleTime);
}
/**
* Return the maximum time after the {@link #getLastAccessTime()
* lastAccessTime} before a session expires. A negative time indicates the
* session doesn't expire.
*/
@Override
public Duration getMaxIdleTime() {
return decorated.getMaxIdleTime();
}
}

View File

@ -0,0 +1,21 @@
package dev.flyfish.boot.cas.config.session;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
/**
* web session监听器
*
* @author wangyu
* 基于装饰器增强实现可灵活处理
*/
public interface WebSessionListener {
default Mono<Void> onSessionCreated(WebSession session) {
return Mono.empty();
}
default Mono<Void> onSessionInvalidated(WebSession session) {
return Mono.empty();
}
}

View File

@ -0,0 +1,16 @@
package dev.flyfish.boot.cas.controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.HashMap;
@RestController
public class IndexController {
@GetMapping("/hello")
public Object index() {
return new HashMap<>();
}
}

View File

@ -13,17 +13,16 @@ import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession; import org.springframework.web.server.WebSession;
import org.springframework.web.server.session.WebSessionManager;
import org.springframework.web.server.session.WebSessionStore;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.net.URI;
import java.util.Map; import java.util.Map;
/** /**
* cas 上下文 * cas 上下文
*/ */
@RequiredArgsConstructor(access = AccessLevel.PRIVATE) @RequiredArgsConstructor(access = AccessLevel.PRIVATE)
class CASContext { public class CASContext {
@Getter @Getter
private String ticket; private String ticket;
@ -38,6 +37,10 @@ class CASContext {
private final Map<String, Mono<String>> parameters = new ConcurrentReferenceHashMap<>(); private final Map<String, Mono<String>> parameters = new ConcurrentReferenceHashMap<>();
@Setter
@Getter
private String username;
static Mono<CASContext> create(ServerWebExchange exchange, WebFilterChain chain) { static Mono<CASContext> create(ServerWebExchange exchange, WebFilterChain chain) {
return new CASContext(exchange, chain).init(); return new CASContext(exchange, chain).init();
} }
@ -50,21 +53,20 @@ class CASContext {
Mono<WebSession> sessionMono = exchange.getSession() Mono<WebSession> sessionMono = exchange.getSession()
.doOnNext(session -> this.session = session); .doOnNext(session -> this.session = session);
return Mono.zipDelayError(ticketMono, sessionMono) return Mono.zipDelayError(ticketMono, sessionMono)
.onErrorContinue((e, v) -> e.printStackTrace())
.thenReturn(this); .thenReturn(this);
} }
boolean isTokenRequest() {
return StringUtils.hasText(ticket);
}
Mono<Void> filter() { Mono<Void> filter() {
return chain.filter(exchange); return chain.filter(exchange);
} }
/** Mono<Void> redirect(String url) {
* 获取参数 return chain.filter(exchange.mutate().request(builder -> builder.uri(URI.create(url))).build());
*
* @param key
* @return 异步结果
*/
Mono<String> getParameter(String key) {
return parameters.computeIfAbsent(key, this::computeParameter);
} }
ServerHttpRequest getRequest() { ServerHttpRequest getRequest() {
@ -93,6 +95,20 @@ class CASContext {
.mapNotNull(formData -> formData.getFirst(key)); .mapNotNull(formData -> formData.getFirst(key));
} }
void setSessionAttribute(String key, Object value) {
session.getAttributes().put(key, value);
}
/**
* 获取参数
*
* @param key
* @return 异步结果
*/
private Mono<String> getParameter(String key) {
return parameters.computeIfAbsent(key, this::computeParameter);
}
private Mono<String> computeParameter(String key) { private Mono<String> computeParameter(String key) {
return this.readParameter(key).cache(); return this.readParameter(key).cache();
} }

View File

@ -0,0 +1,13 @@
package dev.flyfish.boot.cas.filter;
/**
* 上下文初始化逻辑
*
* @author wangyu
*/
public interface CASContextInit {
String getTranslatorUser(String username);
void initContext(CASContext casContext, String username);
}

View File

@ -1,13 +1,15 @@
package dev.flyfish.boot.cas.filter; package dev.flyfish.boot.cas.filter;
import edu.yale.its.tp.cas.client.*; import edu.yale.its.tp.cas.client.CASAuthenticationException;
import edu.yale.its.tp.cas.client.CASReceipt;
import edu.yale.its.tp.cas.client.ProxyTicketValidator;
import edu.yale.its.tp.cas.util.XmlUtils; import edu.yale.its.tp.cas.util.XmlUtils;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
@ -15,12 +17,13 @@ import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession; import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method; import java.net.URI;
import java.net.URLEncoder; import java.net.URLEncoder;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects;
/** /**
* cas filter的webflux实现 * cas filter的webflux实现
@ -49,15 +52,32 @@ public class CASFilter implements WebFilter {
static final String CAS_FILTER_EXCLUSION = "edu.yale.its.tp.cas.client.filter.filterExclusion"; static final String CAS_FILTER_EXCLUSION = "edu.yale.its.tp.cas.client.filter.filterExclusion";
private final CASParameter parameter; private final CASParameter parameter;
private final CASContextInit initializer;
@Getter
private final SessionMappingStorage sessionMappingStorage = new SessionMappingStorage.HashMapBackedSessionStorage();
private static SessionMappingStorage SESSION_MAPPING_STORAGE = new SessionMappingStorage.HashMapBackedSessionStorage(); public CASFilter(CASParameter parameter) {
this.parameter = parameter.checked();
this.initializer = createInitializer();
}
private List<HttpMessageReader<?>> messageReaders; private CASContextInit createInitializer() {
if (null != parameter.casInitContextClass) {
public CASFilter(CASParameter parameter, ServerCodecConfigurer codecConfigurer) { try {
parameter.check(); // 未正确配置类型抛弃
this.parameter = parameter; Class<? extends CASContextInit> cls = parameter.casInitContextClass.asSubclass(CASContextInit.class);
this.messageReaders = codecConfigurer.getReaders(); // 实例化对象并返回
return cls.getConstructor().newInstance();
} catch (ClassCastException e) {
log.warn("cas context init class not implements CASContextInit", e);
} catch (IllegalArgumentException | IllegalAccessException | InstantiationException
| SecurityException | NoSuchMethodException e) {
log.error("error when initialize the context init class", e);
} catch (InvocationTargetException e) {
log.error("error when create the cas context initializer's instance!", e);
}
}
return null;
} }
private boolean isReceiptAcceptable(CASReceipt receipt) { private boolean isReceiptAcceptable(CASReceipt receipt) {
@ -70,12 +90,12 @@ public class CASFilter implements WebFilter {
} }
} }
private CASReceipt getAuthenticatedUser(ServerHttpRequest request, String ticket) throws CASAuthenticationException { private CASReceipt getAuthenticatedUser(CASContext context) throws CASAuthenticationException {
log.trace("entering getAuthenticatedUser()"); log.trace("entering getAuthenticatedUser()");
ProxyTicketValidator pv = new ProxyTicketValidator(); ProxyTicketValidator pv = new ProxyTicketValidator();
pv.setCasValidateUrl(parameter.casValidate); pv.setCasValidateUrl(parameter.casValidate);
pv.setServiceTicket(ticket); pv.setServiceTicket(context.getTicket());
pv.setService(this.getService(request)); pv.setService(this.getService(context));
pv.setRenew(parameter.casRenew); pv.setRenew(parameter.casRenew);
if (parameter.casProxyCallbackUrl != null) { if (parameter.casProxyCallbackUrl != null) {
pv.setProxyCallbackUrl(parameter.casProxyCallbackUrl); pv.setProxyCallbackUrl(parameter.casProxyCallbackUrl);
@ -86,16 +106,18 @@ public class CASFilter implements WebFilter {
return CASReceipt.getReceipt(pv); return CASReceipt.getReceipt(pv);
} }
private String getService(ServerHttpRequest request) { private String getService(CASContext context) {
log.trace("entering getService()"); log.trace("entering getService()");
if (parameter.casServerName == null && parameter.casServiceUrl == null) { if (parameter.casServerName == null && parameter.casServiceUrl == null) {
throw new IllegalArgumentException("need one of the following configuration parameters: edu.yale.its.tp.cas.client.filter.serviceUrl or edu.yale.its.tp.cas.client.filter.serverName"); throw new IllegalArgumentException("need one of the following configuration parameters: edu.yale.its.tp.cas.client.filter.serviceUrl or edu.yale.its.tp.cas.client.filter.serverName");
} else { }
String serviceString; String serviceString;
if (parameter.casServiceUrl != null) { if (parameter.casServiceUrl != null) {
serviceString = URLEncoder.encode(parameter.casServiceUrl, StandardCharsets.UTF_8); serviceString = URLEncoder.encode(parameter.casServiceUrl, StandardCharsets.UTF_8);
} else { } else {
serviceString = Util.getService(request, parameter.casServerName); serviceString = computeService(context, parameter.casServerName);
} }
if (log.isTraceEnabled()) { if (log.isTraceEnabled()) {
@ -103,6 +125,52 @@ public class CASFilter implements WebFilter {
} }
return serviceString; return serviceString;
} }
/**
* 计算服务地址主要是替换url中server的部分并去除ticket
*
* @param context 上下文
* @param server 服务
* @return 结果
*/
public static String computeService(CASContext context, String server) {
if (log.isTraceEnabled()) {
log.trace("entering getService({}, {})", context, server);
}
if (server == null) {
log.error("getService() argument \"server\" was illegally null.");
throw new IllegalArgumentException("name of server is required");
}
URI uri = context.getRequest().getURI();
StringBuilder sb = new StringBuilder();
sb.append(uri.getScheme()).append("://").append(server).append(uri.getPath());
if (uri.getQuery() != null) {
String query = uri.getQuery();
int ticketLoc = query.indexOf("ticket=");
if (ticketLoc == -1) {
sb.append("?").append(query);
} else if (ticketLoc > 0) {
ticketLoc = query.indexOf("&ticket=");
if (ticketLoc == -1) {
sb.append("?").append(query);
} else if (ticketLoc > 0) {
sb.append("?").append(query, 0, ticketLoc);
}
}
}
String encodedService = URLEncoder.encode(sb.toString(), StandardCharsets.UTF_8);
if (log.isTraceEnabled()) {
log.trace("returning from getService() with encoded service [{}]", encodedService);
}
return encodedService;
} }
/** /**
@ -110,71 +178,58 @@ public class CASFilter implements WebFilter {
* *
* @param context 上下文 * @param context 上下文
* @return 结果 * @return 结果
* @throws IOException 异常
*/ */
private Mono<Void> redirectToCAS(CASContext context) { private Mono<Void> redirectToCAS(CASContext context) {
ServerHttpRequest request = context.getRequest(); ServerHttpRequest request = context.getRequest();
ServerHttpResponse response = context.getResponse();
String sessionId = context.getSession().getId(); String sessionId = context.getSession().getId();
log.trace("entering redirectToCAS()"); log.trace("entering redirectToCAS()");
String casLoginString = parameter.casLogin + "?service=" + this.getService(request) + StringBuilder casLoginString = new StringBuilder()
(parameter.casRenew ? "&renew=true" : "") + (parameter.casGateway ? "&gateway=true" : ""); .append(parameter.casLogin)
String sCookie; .append("?service=").append(this.getService(context))
sCookie = parameter.casServerName + request.getPath().contextPath().value(); .append(parameter.casRenew ? "&renew=true" : "")
casLoginString = casLoginString + "&appId=" + sCookie + "&sessionId=" + sessionId; .append(parameter.casGateway ? "&gateway=true" : "");
sCookie = request.getHeaders().getFirst("Cookie"); if (StringUtils.hasText(sessionId)) {
String cookie = null; String appId = parameter.casServerName + request.getPath().contextPath().value();
if (sCookie != null) { casLoginString.append("&appId=").append(appId)
String[] sCookies = sCookie.split(";"); .append("&sessionId=").append(sessionId);
for (int i = 0; i < sCookies.length; ++i) {
if (sCookies[i].indexOf("JSESSIONID=") != -1) {
cookie = sCookies[i].split("JSESSIONID=")[1];
}
}
} }
if (cookie != null && !cookie.equals("null") && !cookie.equals(request.getSession().getId())) { List<HttpCookie> cookies = request.getCookies().get("JSESSIONID");
casLoginString = casLoginString + "&timeOut=" + cookie; if (!CollectionUtils.isEmpty(cookies)) {
if (log.isDebugEnabled()) { cookies.stream()
log.debug("Session is timeout. The timeout session is {}", cookie); .filter(Objects::nonNull)
} .map(HttpCookie::getValue)
.filter(cookie -> !cookie.equals("null") && !cookie.equals(sessionId))
.peek(cookie -> log.debug("Session is timeout. The timeout session is {}", cookie))
.findFirst()
.ifPresent(cookie -> casLoginString.append("&timeOut=").append(cookie));
} }
if (log.isDebugEnabled()) { log.debug("Redirecting browser to [{})", casLoginString);
log.debug("Redirecting browser to [" + casLoginString + ")");
}
response.sendRedirect(casLoginString);
if (log.isTraceEnabled()) {
log.trace("returning from redirectToCAS()"); log.trace("returning from redirectToCAS()");
return context.redirect(casLoginString.toString());
} }
} private Mono<Void> redirectToInitFailure(CASContext context, String cause) {
private void redirectToInitFailure(HttpServletRequest request, HttpServletResponse response, String cause) throws IOException, ServletException {
log.trace("entering redirectToInitFailure()"); log.trace("entering redirectToInitFailure()");
String casLoginString = this.casLogin + "?action=initFailure"; String casLoginString = parameter.casLogin + "?action=initFailure";
if (cause != null && cause.equals("Illegal user")) { if (cause != null && cause.equals("Illegal user")) {
casLoginString = casLoginString + "&userIllegal=true"; casLoginString += "&userIllegal=true";
} }
String locale = request.getParameter("locale"); String locale = context.getQuery("locale");
if (locale != null) { if (locale != null) {
casLoginString = casLoginString + "&locale=" + locale; casLoginString += "&locale=" + locale;
} }
log.debug("Redirecting browser to [" + casLoginString + ")"); log.debug("Redirecting browser to [{})", casLoginString);
response.sendRedirect(casLoginString);
log.trace("returning from redirectToInitFailure()"); log.trace("returning from redirectToInitFailure()");
} return context.redirect(casLoginString);
public static SessionMappingStorage getSessionMappingStorage() {
return SESSION_MAPPING_STORAGE;
} }
private boolean isExclusion(String url) { private boolean isExclusion(String url) {
@ -186,11 +241,14 @@ public class CASFilter implements WebFilter {
} }
private Mono<Void> translate(CASContext context) { private Mono<Void> translate(CASContext context) {
// 是代理回调地址通过
if (parameter.casProxyCallbackUrl != null && parameter.casProxyCallbackUrl.endsWith(context.getPath()) if (parameter.casProxyCallbackUrl != null && parameter.casProxyCallbackUrl.endsWith(context.getPath())
&& context.getQuery("pgtId") != null && context.getQuery("pgtIou") != null) { && context.getQuery("pgtId") != null && context.getQuery("pgtIou") != null) {
log.trace("passing through what we hope is CAS's request for proxy ticket receptor."); log.trace("passing through what we hope is CAS's request for proxy ticket receptor.");
return context.filter(); return context.filter();
} else { }
// 请求包装增强请求并完成自定义功能
if (parameter.wrapRequest) { if (parameter.wrapRequest) {
log.trace("Wrapping request with CASFilterRequestWrapper."); log.trace("Wrapping request with CASFilterRequestWrapper.");
// todo 暂时啥也不干看看有无问题 // todo 暂时啥也不干看看有无问题
@ -198,105 +256,99 @@ public class CASFilter implements WebFilter {
} }
WebSession session = context.getSession(); WebSession session = context.getSession();
Map<String, Object> sessionAttributes = session.getAttributes();
// 使用了用户标记快速跳过 // 使用了用户标记快速跳过
if (parameter.userLoginMark != null && session.getAttribute(parameter.userLoginMark) != null) { if (parameter.userLoginMark != null && session.getAttribute(parameter.userLoginMark) != null) {
return context.filter(); return context.filter();
} }
// 获取receipt // 获取receipt若存在则通过
CASReceipt receipt = session.getAttribute(CAS_FILTER_RECEIPT); CASReceipt receipt = session.getAttribute(CAS_FILTER_RECEIPT);
if (receipt != null && this.isReceiptAcceptable(receipt)) { if (receipt != null && this.isReceiptAcceptable(receipt)) {
log.trace("CAS_FILTER_RECEIPT attribute was present and acceptable - passing request through filter.."); log.trace("CAS_FILTER_RECEIPT attribute was present and acceptable - passing request through filter..");
return context.filter(); return context.filter();
} }
// 跳过请求
// 命中排除地址跳过请求
if (this.isExclusion(context.getPath())) { if (this.isExclusion(context.getPath())) {
return context.filter(); return context.filter();
} }
// 判断票据 // 判断票据
String ticket = context.getTicket(); String ticket = context.getTicket();
// 存在票据时验证票据
if (StringUtils.hasText(ticket)) { if (StringUtils.hasText(ticket)) {
try { try {
receipt = this.getAuthenticatedUser(context.getRequest(), ticket); receipt = this.getAuthenticatedUser(context);
} catch (CASAuthenticationException var22) { } catch (CASAuthenticationException e) {
return this.redirectToCAS(context); return this.redirectToCAS(context);
} }
if (!this.isReceiptAcceptable(receipt)) { if (!this.isReceiptAcceptable(receipt)) {
throw new IllegalStateException("Authentication was technically successful but rejected as a matter of policy. [" + receipt + "]"); throw new IllegalStateException("Authentication was technically successful but rejected as a matter of policy. [" + receipt + "]");
} else { }
// 记录receipt
String pt = context.getQuery("pt"); String pt = context.getQuery("pt");
if (StringUtils.hasText(pt)) { if (StringUtils.hasText(pt)) {
session.getAttributes().put(pt, receipt); context.setSessionAttribute(pt, receipt);
} }
// 获取到用户名
String userName = receipt.getUserName(); String userName = receipt.getUserName();
if (StringUtils.hasText(parameter.casInitContextClass)) { // 尝试初始化
if (null != initializer) {
try { try {
Class<?> cls = Class.forName(parameter.casInitContextClass); String translated = initializer.getTranslatorUser(userName);
Object obj = cls.getConstructor().newInstance(); log.debug("translated username: {} to {}", userName, translated);
if (obj instanceof IContextInit) { initializer.initContext(context, translated);
Method translatorMethod = cls.getMethod("getTranslatorUser", String.class); } catch (Exception e) {
userName = (String) translatorMethod.invoke(obj, userName);
Method initContextMethod = cls.getMethod("initContext", ServletRequest.class, ServletResponse.class, FilterChain.class, String.class);
initContextMethod.invoke(obj, request, response, fc, userName);
}
} catch (ClassNotFoundException | IllegalArgumentException | IllegalAccessException |
InstantiationException | SecurityException | NoSuchMethodException e) {
e.printStackTrace();
} catch (InvocationTargetException var19) {
InvocationTargetException e = var19;
String cause = e.getCause().getMessage(); String cause = e.getCause().getMessage();
session.setAttribute("initFailure", cause); context.setSessionAttribute("initFailure", cause);
this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause); return this.redirectToInitFailure(context, cause);
e.printStackTrace();
return Mono.empty();
} }
} }
session.setAttribute("edu.yale.its.tp.cas.client.filter.user", userName); sessionAttributes.put(CAS_FILTER_USER, userName);
session.setAttribute("edu.yale.its.tp.cas.client.filter.receipt", receipt); sessionAttributes.put(CAS_FILTER_RECEIPT, receipt);
session.removeAttribute(CAS_FILTER_GATEWAYED); sessionAttributes.remove(CAS_FILTER_GATEWAYED);
if (log.isTraceEnabled()) {
log.trace("validated ticket to get authenticated receipt [{}], now passing request along filter chain.", receipt); log.trace("validated ticket to get authenticated receipt [{}], now passing request along filter chain.", receipt);
log.trace("returning from doFilter()"); log.trace("returning from doFilter()");
}
return context.filter(); return context.filter();
} }
return Mono.empty();
} else { // 不存在票据跳转验证
log.trace("CAS ticket was not present on request."); log.trace("CAS ticket was not present on request.");
boolean didGateway = Boolean.valueOf((String) session.getAttribute(CAS_FILTER_GATEWAYED)); boolean didGateway = Boolean.parseBoolean(session.getAttribute(CAS_FILTER_GATEWAYED));
if (parameter.casLogin == null) { if (parameter.casLogin == null) {
log.error("casLogin was not set, so filter cannot redirect request for authentication."); log.error("casLogin was not set, so filter cannot redirect request for authentication.");
throw new IllegalArgumentException("When CASFilter protects pages that do not receive a 'ticket' parameter, it needs a edu.yale.its.tp.cas.client.filter.loginUrl filter parameter"); throw new IllegalArgumentException("When CASFilter protects pages that do not receive a 'ticket' parameter, it needs a edu.yale.its.tp.cas.client.filter.loginUrl filter parameter");
} else if (!didGateway) { }
if (!didGateway) {
log.trace("Did not previously gateway. Setting session attribute to true."); log.trace("Did not previously gateway. Setting session attribute to true.");
((HttpServletRequest) request).setAttribute("sessionId", session.getId()); sessionAttributes.put(CAS_FILTER_GATEWAYED, "true");
session.setAttribute(CAS_FILTER_GATEWAYED, "true"); return this.redirectToCAS(context);
this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); }
} else {
log.trace("Previously gatewayed."); log.trace("Previously gatewayed.");
if (!this.casGateway && session.getAttribute("edu.yale.its.tp.cas.client.filter.user") == null) { if (!parameter.casGateway && session.getAttribute(CAS_FILTER_USER) == null) {
if (session.getAttribute("initFailure") != null) { if (session.getAttribute("initFailure") != null) {
String cause = (String) session.getAttribute("initFailure"); String cause = session.getAttribute("initFailure");
this.redirectToInitFailure((HttpServletRequest) request, (HttpServletResponse) response, cause); return this.redirectToInitFailure(context, cause);
} else { }
((HttpServletRequest) request).setAttribute("sessionId", session.getId());
session.setAttribute(CAS_FILTER_GATEWAYED, "true"); sessionAttributes.put(CAS_FILTER_GATEWAYED, "true");
this.redirectToCAS((HttpServletRequest) request, (HttpServletResponse) response); return this.redirectToCAS(context);
} }
} else {
log.trace("casGateway was true and CAS_FILTER_USER set: passing request along filter chain."); log.trace("casGateway was true and CAS_FILTER_USER set: passing request along filter chain.");
fc.doFilter((ServletRequest) request, response); return context.filter();
}
}
}
}
} }
/** /**
@ -306,19 +358,30 @@ public class CASFilter implements WebFilter {
* @return 结果 * @return 结果
*/ */
private Mono<Void> handle(CASContext context) { private Mono<Void> handle(CASContext context) {
// 优先处理token请求
if (context.isTokenRequest()) {
String sessionId = context.getSession().getId();
log.debug("Storing session identifier for {}", sessionId);
// 包括ticket尝试重新替换session
return sessionMappingStorage.removeBySessionById(sessionId)
.onErrorContinue((e, v) -> log.debug("error when remove session"))
.then(Mono.defer(() -> sessionMappingStorage.addSessionById(context.getTicket(), context.getSession())
.then(translate(context))));
}
// post请求需要特殊处理 // post请求需要特殊处理
if (context.getMethod() == HttpMethod.POST) { if (context.getMethod() == HttpMethod.POST) {
// 此处可能要求安全的获取参数单独针对退出请求 // 通过form表单获取注销请求处理注销逻辑
return context.getFormData("logoutRequest") return context.getFormData("logoutRequest")
.doOnNext(payload -> log.trace("Logout request=[{}]", payload)) .doOnNext(payload -> log.trace("Logout request=[{}]", payload))
.defaultIfEmpty("") .defaultIfEmpty("")
.flatMap(payload -> { .flatMap(payload -> {
if (StringUtils.hasText(payload)) { if (StringUtils.hasText(payload)) {
String sessionIdentifier = XmlUtils.getTextForElement(payload, "SessionIndex"); String token = XmlUtils.getTextForElement(payload, "SessionIndex");
if (StringUtils.hasText(sessionIdentifier)) { if (StringUtils.hasText(token)) {
// 满足条件时断路 // 满足条件时断路
return SESSION_MAPPING_STORAGE.removeSessionByMappingId(sessionIdentifier) return sessionMappingStorage.removeSessionByMappingId(token)
.doOnNext(session -> log.debug("Invalidating session [{}] for ST [{}]", session.getId(), sessionIdentifier)) .doOnNext(session -> log.debug("Invalidating session [{}] for ST [{}]", session.getId(), token))
.flatMap(WebSession::invalidate) .flatMap(WebSession::invalidate)
.doOnError(IllegalStateException.class, e -> log.debug(e.getMessage(), e)) .doOnError(IllegalStateException.class, e -> log.debug(e.getMessage(), e))
.onErrorComplete(); .onErrorComplete();
@ -327,21 +390,9 @@ public class CASFilter implements WebFilter {
// 继续执行 // 继续执行
return translate(context); return translate(context);
}); });
} else {
String ticket = context.getTicket();
String sessionId = context.getSession().getId();
log.debug("Storing session identifier for {}", sessionId);
// 包括ticket尝试重新替换session
if (StringUtils.hasText(ticket)) {
return SESSION_MAPPING_STORAGE.removeBySessionById(sessionId)
.onErrorResume(e -> SESSION_MAPPING_STORAGE.addSessionById(ticket, context.getSession()))
.then(Mono.defer(() -> translate(context)));
} }
return translate(context); return translate(context);
} }
}
@Override @Override
@ -353,7 +404,7 @@ public class CASFilter implements WebFilter {
if (log.isTraceEnabled()) { if (log.isTraceEnabled()) {
log.trace("entering doFilter()"); log.trace("entering doFilter()");
} }
// 执行中断策略 // 执行跳过策略
String pt = context.getQuery("pt"); String pt = context.getQuery("pt");
if (StringUtils.hasText(pt)) { if (StringUtils.hasText(pt)) {
if (session.getAttribute(pt) != null) { if (session.getAttribute(pt) != null) {

View File

@ -0,0 +1,20 @@
package dev.flyfish.boot.cas.filter;
/**
* 登录过滤器旨在缓存用户名
*
* @author wangyu
*/
public class CASLoginFilter implements CASContextInit {
public static String CONST_CAS_USERNAME = "const_cas_username";
@Override
public String getTranslatorUser(String username) {
return username;
}
@Override
public void initContext(CASContext casContext, String username) {
casContext.setSessionAttribute(CONST_CAS_USERNAME, username);
}
}

View File

@ -31,7 +31,7 @@ public class CASParameter {
String casProxyCallbackUrl; String casProxyCallbackUrl;
@JsonAlias(CASFilter.CAS_FILTER_INITCONTEXTCLASS) @JsonAlias(CASFilter.CAS_FILTER_INITCONTEXTCLASS)
String casInitContextClass; Class<?> casInitContextClass;
@JsonAlias(CASFilter.RENEW_INIT_PARAM) @JsonAlias(CASFilter.RENEW_INIT_PARAM)
boolean casRenew; boolean casRenew;
@ -63,7 +63,7 @@ public class CASParameter {
/** /**
* 检查配置参数是否有误 * 检查配置参数是否有误
*/ */
public void check() { public CASParameter checked() {
if (this.casGateway && this.casRenew) { if (this.casGateway && this.casRenew) {
throw new IllegalArgumentException("gateway and renew cannot both be true in filter configuration"); throw new IllegalArgumentException("gateway and renew cannot both be true in filter configuration");
} else if (this.casServerName != null && this.casServiceUrl != null) { } else if (this.casServerName != null && this.casServiceUrl != null) {
@ -73,5 +73,6 @@ public class CASParameter {
} else if (this.casValidate == null) { } else if (this.casValidate == null) {
throw new IllegalArgumentException("validateUrl parameter must be set."); throw new IllegalArgumentException("validateUrl parameter must be set.");
} }
return this;
} }
} }

View File

@ -1 +0,0 @@
spring.application.name=cas

View File

@ -0,0 +1,13 @@
server:
port: 8080
spring:
application:
name: cas
cas:
filter:
cas-login: https://sdsfzt.sxu.edu.cn/authserver/login
cas-validate: https://sdsfzt.sxu.edu.cn/authserver/serviceValidate
cas-server-name: 127.0.0.1:8080
cas-init-context-class: dev.flyfish.boot.cas.filter.CASLoginFilter