/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shenyu.plugin.mcp.server;

import com.fasterxml.jackson.databind.ObjectMapper;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.shenyu.common.dto.RuleData;
import org.apache.shenyu.common.dto.SelectorData;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.enums.RpcTypeEnum;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.context.ShenyuContext;
import org.apache.shenyu.plugin.api.utils.RequestUrlUtils;
import org.apache.shenyu.plugin.base.AbstractShenyuPlugin;
import org.apache.shenyu.plugin.mcp.server.handler.McpServerPluginDataHandler;
import org.apache.shenyu.plugin.mcp.server.holder.ShenyuMcpExchangeHolder;
import org.apache.shenyu.plugin.mcp.server.manager.ShenyuMcpServerManager;
import org.apache.shenyu.plugin.mcp.server.model.ShenyuMcpServer;
import org.apache.shenyu.plugin.mcp.server.transport.MessageHandlingResult;
import org.apache.shenyu.plugin.mcp.server.transport.ShenyuSseServerTransportProvider;
import org.apache.shenyu.plugin.mcp.server.transport.ShenyuStreamableHttpServerTransportProvider;
import org.apache.shenyu.plugin.mcp.server.transport.SseEventFormatter;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

/**
 * ShenYu plugin that serves MCP (Model Context Protocol) requests over either the SSE transport
 * or the Streamable HTTP transport.
 *
 * <p>Requests whose raw path {@link ShenyuMcpServerManager#canRoute(String)} rejects are passed
 * down the plugin chain unchanged. Exchanges carrying the {@code MCP_TOOL_CALL} attribute are
 * skipped so that tool calls re-entering the gateway do not loop back into this plugin.
 */
public class McpServerPlugin extends AbstractShenyuPlugin {

    private static final Logger LOG = LoggerFactory.getLogger(McpServerPlugin.class);

    /** Shared JSON mapper; creating an ObjectMapper per response is needlessly expensive. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    private static final String MESSAGE_ENDPOINT = "/message";

    private static final String STREAMABLE_HTTP_PATH = "/streamablehttp";

    private static final String SSE_PATH = "/sse";

    private static final String MCP_TOOL_CALL_ATTR = "MCP_TOOL_CALL";

    private static final String MCP_SESSION_ID_ATTR = "MCP_SESSION_ID";

    private static final String SESSION_ID_PARAM = "sessionId";

    /** Header names probed, in order, for a session id when no query parameter is present. */
    private static final String[] SESSION_ID_HEADERS = new String[]{"X-Session-Id", "Mcp-Session-Id"};

    private static final String AUTHORIZATION_HEADER = "Authorization";

    private static final String BEARER_PREFIX = "Bearer ";

    private final ShenyuMcpServerManager shenyuMcpServerManager;

    private final List<HttpMessageReader<?>> messageReaders;

    /**
     * Creates the plugin.
     *
     * @param shenyuMcpServerManager manager deciding routability and providing transports
     * @param messageReaders         readers used to build a {@link ServerRequest} from the exchange
     */
    public McpServerPlugin(final ShenyuMcpServerManager shenyuMcpServerManager, final List<HttpMessageReader<?>> messageReaders) {
        this.shenyuMcpServerManager = shenyuMcpServerManager;
        this.messageReaders = messageReaders;
    }

    /**
     * Returns the request path after any rewrite applied earlier in the chain.
     *
     * @param exchange current exchange
     * @return rewritten raw path as computed by {@link RequestUrlUtils#getRewrittenRawPath}
     */
    protected String getRawPath(final ServerWebExchange exchange) {
        return RequestUrlUtils.getRewrittenRawPath(exchange);
    }

    /**
     * Dispatches a matched request to the MCP transport, or continues the chain when the URI is
     * not an MCP route.
     *
     * @throws NullPointerException if the exchange has no {@code "context"} attribute
     */
    protected Mono<Void> doExecute(final ServerWebExchange exchange, final ShenyuPluginChain chain,
                                   final SelectorData selector, final RuleData rule) {
        ShenyuContext shenyuContext = exchange.getAttribute("context");
        Objects.requireNonNull(shenyuContext, "ShenyuContext must not be null");
        String uri = exchange.getRequest().getURI().getRawPath();
        LOG.debug("Processing MCP request with URI: {}", uri);
        if (!shenyuMcpServerManager.canRoute(uri)) {
            LOG.debug("URI not handled by MCP server, continuing chain: {}", uri);
            return chain.execute(exchange);
        }
        LOG.debug("Handling MCP request for URI: {}", uri);
        ServerRequest request = ServerRequest.create(exchange, messageReaders);
        return routeByProtocol(exchange, chain, request, selector, uri);
    }

    /**
     * Returns this plugin's registered name.
     *
     * @return the {@link PluginEnum#MCP_SERVER} name
     */
    public String named() {
        return PluginEnum.MCP_SERVER.getName();
    }

    /**
     * Skips MCP tool-call re-entries (to avoid infinite loops) and any non-HTTP RPC type.
     *
     * @param exchange current exchange
     * @return {@code true} when the plugin must not handle this exchange
     */
    public boolean skip(final ServerWebExchange exchange) {
        if (Boolean.TRUE.equals(exchange.getAttribute(MCP_TOOL_CALL_ATTR))) {
            LOG.debug("Skipping MCP plugin for tool call to prevent infinite loop");
            return true;
        }
        return skipExcept(exchange, new RpcTypeEnum[]{RpcTypeEnum.HTTP});
    }

    /**
     * Returns the plugin's position in the chain.
     *
     * @return the {@link PluginEnum#MCP_SERVER} code
     */
    public int getOrder() {
        return PluginEnum.MCP_SERVER.getCode();
    }

    /** Sends Streamable HTTP URIs to that transport; everything else uses SSE. */
    private Mono<Void> routeByProtocol(final ServerWebExchange exchange, final ShenyuPluginChain chain,
                                       final ServerRequest request, final SelectorData selector, final String uri) {
        if (isStreamableHttpProtocol(uri)) {
            return handleStreamableHttpRequest(exchange, chain, request, uri);
        }
        if (!isSseProtocol(uri)) {
            LOG.debug("Using default SSE protocol for URI: {}", uri);
        }
        return handleSseRequest(exchange, chain, request, selector, uri);
    }

    /**
     * Looks up the client session id. The sources are checked in this order: the
     * {@code sessionId} query parameter, then {@link #SESSION_ID_HEADERS}, then a
     * {@code Bearer} Authorization header.
     *
     * @return the session id, or {@code null} when none is present
     */
    private String extractSessionId(final ServerWebExchange exchange) {
        String fromQuery = exchange.getRequest().getQueryParams().getFirst(SESSION_ID_PARAM);
        if (Objects.nonNull(fromQuery)) {
            LOG.debug("Found sessionId in query parameters: {}", fromQuery);
            return fromQuery;
        }
        for (String headerName : SESSION_ID_HEADERS) {
            String fromHeader = exchange.getRequest().getHeaders().getFirst(headerName);
            if (Objects.nonNull(fromHeader)) {
                LOG.debug("Found sessionId in {} header: {}", headerName, fromHeader);
                return fromHeader;
            }
        }
        String authHeader = exchange.getRequest().getHeaders().getFirst(AUTHORIZATION_HEADER);
        if (Objects.nonNull(authHeader) && authHeader.startsWith(BEARER_PREFIX)) {
            // Do not log the value: a bearer token is a credential, not just a session handle.
            LOG.debug("Found sessionId in Authorization header");
            return authHeader.substring(BEARER_PREFIX.length());
        }
        LOG.debug("No sessionId found in request for path: {}", exchange.getRequest().getPath().value());
        return null;
    }

    /** {@code contains} already covers the "ends with" case. */
    private boolean isStreamableHttpProtocol(final String uri) {
        return uri.contains(STREAMABLE_HTTP_PATH);
    }

    /** True for URIs containing {@code /sse} or ending with {@code /message}. */
    private boolean isSseProtocol(final String uri) {
        return uri.contains(SSE_PATH) || uri.endsWith(MESSAGE_ENDPOINT);
    }

    private Mono<Void> handleStreamableHttpRequest(final ServerWebExchange exchange, final ShenyuPluginChain chain,
                                                   final ServerRequest request, final String uri) {
        LOG.debug("Handling Streamable HTTP MCP request for URI: {}", uri);
        ShenyuStreamableHttpServerTransportProvider transportProvider = shenyuMcpServerManager.getOrCreateStreamableHttpTransport(uri);
        setupSessionContext(exchange, chain);
        return processStreamableHttpEndpoint(exchange, transportProvider, request);
    }

    /**
     * Handles an SSE-transport request. URIs ending with the server's message endpoint are
     * treated as client messages; all others open the SSE stream. When no cached server exists
     * for the selector, the chain continues.
     */
    private Mono<Void> handleSseRequest(final ServerWebExchange exchange, final ShenyuPluginChain chain,
                                        final ServerRequest request, final SelectorData selector, final String uri) {
        LOG.debug("Handling SSE MCP request for URI: {}", uri);
        ShenyuMcpServer server = (ShenyuMcpServer) McpServerPluginDataHandler.CACHED_SERVER.get().obtainHandle(selector.getId());
        if (Objects.isNull(server)) {
            return chain.execute(exchange);
        }
        String messageEndpoint = server.getMessageEndpoint();
        ShenyuSseServerTransportProvider transportProvider = shenyuMcpServerManager.getOrCreateMcpServerTransport(uri, messageEndpoint);
        // Guard against a server configured without a message endpoint (endsWith(null) throws NPE).
        if (Objects.nonNull(messageEndpoint) && uri.endsWith(messageEndpoint)) {
            setupSessionContext(exchange, chain);
            return handleMessageEndpoint(exchange, transportProvider, request);
        }
        return handleSseEndpoint(exchange, transportProvider, request);
    }

    /**
     * Stores the session id and plugin chain on the exchange and registers the exchange in
     * {@link ShenyuMcpExchangeHolder} under that id. Does nothing when no session id is found.
     */
    private void setupSessionContext(final ServerWebExchange exchange, final ShenyuPluginChain chain) {
        String sessionId = extractSessionId(exchange);
        if (Objects.isNull(sessionId)) {
            return;
        }
        exchange.getAttributes().put(MCP_SESSION_ID_ATTR, sessionId);
        exchange.getAttributes().put("chain", chain);
        // NOTE(review): nothing in this class removes the entry; confirm the holder evicts it.
        ShenyuMcpExchangeHolder.put(sessionId, exchange);
        LOG.debug("Set up session context for sessionId: {}", sessionId);
    }

    /** Streamable HTTP accepts POST only; GET gets 405, other methods get 400. */
    private Mono<Void> processStreamableHttpEndpoint(final ServerWebExchange exchange,
                                                     final ShenyuStreamableHttpServerTransportProvider transportProvider,
                                                     final ServerRequest request) {
        LOG.debug("Processing Streamable HTTP endpoint for request: {}", request.path());
        String method = exchange.getRequest().getMethod().name();
        if ("GET".equalsIgnoreCase(method)) {
            return handleStreamableHttpGetRequest(exchange);
        }
        if ("POST".equalsIgnoreCase(method)) {
            return handleStreamableHttpPostRequest(exchange, transportProvider, request);
        }
        return handleUnsupportedMethod(exchange);
    }

    private Mono<Void> handleStreamableHttpGetRequest(final ServerWebExchange exchange) {
        LOG.debug("Rejecting Streamable HTTP GET request (protocol does not support GET)");
        setErrorResponse(exchange, HttpStatus.METHOD_NOT_ALLOWED, "POST, OPTIONS",
                createJsonError(-32601, "Streamable HTTP does not support GET requests. Please use POST requests for all MCP operations."));
        return writeJsonResponse(exchange);
    }

    private Mono<Void> handleStreamableHttpPostRequest(final ServerWebExchange exchange,
                                                       final ShenyuStreamableHttpServerTransportProvider transportProvider,
                                                       final ServerRequest request) {
        LOG.debug("Processing Streamable HTTP POST request for message handling");
        return transportProvider.handleMessageEndpoint(exchange, request)
                .flatMap(result -> processStreamableHttpResult(exchange, (MessageHandlingResult) result))
                .doOnSuccess(aVoid -> LOG.debug("Streamable HTTP message processing completed"))
                .doOnError(error -> LOG.error("Error in Streamable HTTP message processing: {}", error.getMessage(), error));
    }

    /** Applies status/headers from {@code result} and writes its JSON body as UTF-8. */
    private Mono<Void> processStreamableHttpResult(final ServerWebExchange exchange, final MessageHandlingResult result) {
        LOG.debug("Processing Streamable HTTP result - Status: {}, SessionId: {}", result.getStatusCode(), result.getSessionId());
        configureStreamableHttpResponse(exchange, result);
        String responseBodyJson = result.getResponseBodyAsJson();
        // An absent body is written as an empty payload instead of failing with an NPE.
        byte[] responseBytes = Objects.isNull(responseBodyJson) ? new byte[0] : responseBodyJson.getBytes(StandardCharsets.UTF_8);
        LOG.debug("Writing response body with {} bytes", responseBytes.length);
        return writeBytes(exchange, responseBytes)
                .doOnSuccess(aVoid -> LOG.debug("Response transmission completed successfully"))
                .doOnError(error -> LOG.error("Error writing response: {}", error.getMessage(), error));
    }

    /**
     * Sets the status from {@code result}, echoes a session id in {@code Mcp-Session-Id} (from
     * the result, else from the request), adds CORS headers and a JSON content type.
     */
    private void configureStreamableHttpResponse(final ServerWebExchange exchange, final MessageHandlingResult result) {
        // HttpStatusCode.valueOf accepts any 3-digit code; HttpStatus.valueOf throws for non-enum codes.
        exchange.getResponse().setStatusCode(HttpStatusCode.valueOf(result.getStatusCode()));
        String sessionId = Objects.nonNull(result.getSessionId()) ? result.getSessionId() : extractSessionId(exchange);
        if (Objects.nonNull(sessionId)) {
            exchange.getResponse().getHeaders().set("Mcp-Session-Id", sessionId);
        }
        setCorsHeaders(exchange);
        exchange.getResponse().getHeaders().set("Content-Type", "application/json");
        exchange.getResponse().getHeaders().remove("Transfer-Encoding");
        exchange.getResponse().getHeaders().remove("Content-Length");
        LOG.debug("Configured Streamable HTTP response headers");
    }

    private Mono<Void> handleUnsupportedMethod(final ServerWebExchange exchange) {
        LOG.debug("Unsupported HTTP method: {}", exchange.getRequest().getMethod());
        setErrorResponse(exchange, HttpStatus.BAD_REQUEST, null, createJsonError(-32600, "Unsupported HTTP method"));
        return writeJsonResponse(exchange);
    }

    /** Streams SSE events from the transport, each formatted by {@link SseEventFormatter}. */
    private Mono<Void> handleSseEndpoint(final ServerWebExchange exchange, final ShenyuSseServerTransportProvider transportProvider,
                                         final ServerRequest request) {
        LOG.debug("Setting up SSE endpoint for request: {}", request.path());
        configureSseHeaders(exchange);
        return exchange.getResponse().writeWith(transportProvider.createSseFlux(request)
                .doOnNext(event -> {
                    String eventType = event.event();
                    LOG.debug("SSE Event - Type: {}", Objects.isNull(eventType) ? "data" : eventType);
                })
                .map(event -> SseEventFormatter.formatEvent(event, exchange))
                .doOnSubscribe(subscription -> LOG.debug("SSE stream subscribed"))
                .doOnComplete(() -> LOG.debug("SSE stream completed"))
                .doOnError(error -> LOG.error("SSE stream error: {}", error.getMessage(), error)));
    }

    private void configureSseHeaders(final ServerWebExchange exchange) {
        exchange.getResponse().getHeaders().set("Content-Type", "text/event-stream");
        exchange.getResponse().getHeaders().set("Cache-Control", "no-cache");
        exchange.getResponse().getHeaders().set("Connection", "keep-alive");
        setCorsHeaders(exchange);
        LOG.debug("Configured SSE headers");
    }

    /**
     * Forwards a client message to the SSE transport and replies with
     * {@code {"message": <responseBody>}}. The value is JSON-escaped; a null body yields
     * {@code "message": null}.
     */
    private Mono<Void> handleMessageEndpoint(final ServerWebExchange exchange, final ShenyuSseServerTransportProvider transportProvider,
                                             final ServerRequest request) {
        LOG.debug("Processing message endpoint request");
        return transportProvider.handleMessageEndpoint(request).flatMap(result -> {
            String body = result.getResponseBody();
            LOG.debug("Message handling result - Status: {}, Body length: {} chars",
                    result.getStatusCode(), Objects.nonNull(body) ? body.length() : 0);
            exchange.getResponse().setStatusCode(HttpStatusCode.valueOf(result.getStatusCode()));
            exchange.getResponse().getHeaders().set("Content-Type", "application/json");
            setCorsHeaders(exchange);
            // Build via Jackson: String.format produced invalid JSON for bodies containing
            // quotes, backslashes or control characters.
            String responseBody = MAPPER.createObjectNode().put("message", body).toString();
            LOG.debug("Sending message response with length: {} chars", responseBody.length());
            return writeBytes(exchange, responseBody.getBytes(StandardCharsets.UTF_8));
        }).doOnSuccess(aVoid -> LOG.debug("Message response completed"))
                .doOnError(error -> LOG.error("Error in message response: {}", error.getMessage(), error));
    }

    private void setCorsHeaders(final ServerWebExchange exchange) {
        exchange.getResponse().getHeaders().set("Access-Control-Allow-Origin", "*");
        exchange.getResponse().getHeaders().set("Access-Control-Allow-Headers", "Content-Type, Mcp-Session-Id, Authorization, Last-Event-ID");
        exchange.getResponse().getHeaders().set("Access-Control-Allow-Methods", "GET, POST, OPTIONS");
    }

    /**
     * Sets the error status and headers, and stashes {@code errorBody} in the {@code "errorBody"}
     * exchange attribute for {@link #writeJsonResponse}.
     *
     * @param allowHeader value for the {@code Allow} header, or {@code null} to omit it
     */
    private void setErrorResponse(final ServerWebExchange exchange, final HttpStatus status, final String allowHeader,
                                  final Map<String, Object> errorBody) {
        exchange.getResponse().setStatusCode(status);
        // set (not add) so a header already present is replaced rather than duplicated.
        exchange.getResponse().getHeaders().set("Content-Type", "application/json");
        setCorsHeaders(exchange);
        if (Objects.nonNull(allowHeader)) {
            exchange.getResponse().getHeaders().set("Allow", allowHeader);
        }
        exchange.getAttributes().put("errorBody", errorBody);
    }

    /**
     * Serializes the {@code "errorBody"} attribute as UTF-8 JSON into the response. Completes
     * empty when no body is stored or serialization fails (the failure is logged).
     */
    private Mono<Void> writeJsonResponse(final ServerWebExchange exchange) {
        Map<String, Object> errorBody = exchange.getAttribute("errorBody");
        if (Objects.isNull(errorBody)) {
            return Mono.empty();
        }
        try {
            // writeValueAsBytes emits UTF-8, unlike String#getBytes() with the platform charset.
            return writeBytes(exchange, MAPPER.writeValueAsBytes(errorBody));
        } catch (Exception e) {
            LOG.error("Error writing JSON response: {}", e.getMessage(), e);
            return Mono.empty();
        }
    }

    /** Builds {@code {"error": {"code": code, "message": message}}}. */
    private Map<String, Object> createJsonError(final int code, final String message) {
        return Map.of("error", Map.of("code", code, "message", message));
    }

    /** Writes {@code bytes} as the complete response body. */
    private Mono<Void> writeBytes(final ServerWebExchange exchange, final byte[] bytes) {
        return exchange.getResponse().writeWith(Mono.just(exchange.getResponse().bufferFactory().wrap(bytes)));
    }
}

