import { type Client, createClient } from "graphql-ws";
import { remove } from "lodash";
import {
  RelayNetworkLayer,
  RelayNetworkLayerRequest
} from "react-relay-network-modern/es";
import {
  type CacheConfig,
  type FetchFunction,
  Network,
  Observable,
  type RequestParameters,
  type SubscribeFunction,
  type UploadableMap,
  type Variables
} from "relay-runtime";
import JSWebSocket from "ws";
import { logger as baseLogger } from "../../logger";
import type { Subscriber } from "../../types";
import { generateMiddleware } from "./http/generate-middleware";
import type { ProtocolOptions } from "./types";

/** Logger scoped to the graphql-ws transport. */
const logger = baseLogger.createLogger("GraphQLWS");

/**
 * WebSocket constructor handed to graphql-ws: the runtime's global
 * `WebSocket` when one exists, otherwise the implementation from `ws`.
 */
const WebSocketImpl =
  typeof WebSocket !== "undefined" ? WebSocket : JSWebSocket;

/**
 * Creates a graphql-ws client connected to `graphqlSocket`.
 *
 * The connection-params callback resolves the token (via `getTokenFn`) and
 * then the request headers (`requestHeaders` may be a plain object or a
 * function whose result is awaited), and merges them, adding an
 * `Authorization: Bearer <token>` entry only when a token is present.
 *
 * Lifecycle events are reported through `onSocketChange` ("opened",
 * "connected", "disconnected"). Socket errors are reported through
 * `onSocketError` as a `TypeError`. If the `connected` payload carries a
 * `next-valid-token` value, it is stringified and passed to `setTokenFn`.
 */
export async function makeClient({
  graphqlSocket,
  getTokenFn,
  setTokenFn,
  onSocketError,
  requestHeaders = {},
  onSocketChange
}: ProtocolOptions) {
  return createClient({
    url: graphqlSocket,
    webSocketImpl: WebSocketImpl,
    keepAlive: 10_000,
    connectionParams: async () => {
      // The token is resolved before the header factory (if any) is invoked.
      const token = await getTokenFn?.();
      const baseHeaders =
        typeof requestHeaders === "function"
          ? await requestHeaders()
          : requestHeaders;
      const authHeader = token ? { Authorization: `Bearer ${token}` } : {};
      return { ...baseHeaders, ...authHeader };
    },
    on: {
      opened: () => onSocketChange?.("opened"),
      connected: (_socket, params) => {
        logger.info("socket connected", params);
        const nextValidToken = params?.["next-valid-token"];
        if (nextValidToken) {
          setTokenFn?.(String(nextValidToken));
        }
        onSocketChange?.("connected");
      },
      closed: () => onSocketChange?.("disconnected"),
      error: () => {
        onSocketError?.(new TypeError("Websocket errored"));
      }
    }
  });
}

/**
 * Builds a Relay fetch function that routes by the presence of uploadables:
 * operations with uploadables go to `uploadNetwork.execute`, and all others
 * go to `nonUploadHandler` (which is called without the uploadables argument).
 */
function generateSwitchHandler(
  nonUploadHandler: FetchFunction,
  uploadNetwork: RelayNetworkLayer
): FetchFunction {
  return function switchHandler(
    operation: RequestParameters,
    variables: Variables,
    cacheConfig: CacheConfig,
    uploadables?: UploadableMap | null
  ) {
    if (!uploadables) {
      return nonUploadHandler(operation, variables, cacheConfig);
    }
    return uploadNetwork.execute(
      operation,
      variables,
      cacheConfig,
      uploadables
    );
  };
}

/** Drain bookkeeping shared by every handler built on the same client. */
interface DrainController {
  /** Start tracking a non-subscription request; returns its tracking id. */
  track: (operationName: string) => symbol;
  /**
   * Stop tracking `id` (a no-op if it is no longer tracked). If a drain is
   * pending and nothing is left in flight, dispose the client.
   */
  untrack: (id: symbol) => void;
}

/**
 * Drain state is keyed by `Client`, not created per handler. The same client
 * can back several handlers: `generateGraphqlWsNetwork` builds one for
 * queries/mutations and another (through `generateGraphqlWsSubscriber`) for
 * subscriptions. With per-handler state, the handler that tracks nothing saw
 * zero active requests and disposed the socket as soon as the drain signal
 * arrived, ignoring requests tracked by the other handler.
 */
const drainControllers = new WeakMap<Client, DrainController>();

/**
 * Returns the drain controller for `client`. On first use it creates the
 * controller and registers the client's single drain `pong` listener.
 */
function getDrainController(client: Client): DrainController {
  const existing = drainControllers.get(client);
  if (existing) return existing;

  let isDraining = false;
  const activeRequests = new Set<symbol>();

  const handleDraining = (nextIsDraining = isDraining) => {
    isDraining = nextIsDraining;
    if (!isDraining) return;
    if (activeRequests.size > 0) {
      logger.info(
        `delaying drain due to ${activeRequests.size} active requests`
      );
      return;
    }
    logger.info("draining");
    // This dispose completes the drain. Clear the flag so that requests
    // finishing on a later connection do not dispose that connection too.
    isDraining = false;
    // dispose() may return a promise; keep a rejection from going unhandled.
    Promise.resolve(client.dispose()).catch((error: unknown) => {
      logger.error("failed to dispose client while draining", error);
    });
  };

  const controller: DrainController = {
    track: (operationName) => {
      const id = Symbol(operationName);
      const size = activeRequests.size;
      activeRequests.add(id);
      logger.debug(`trackActiveRequest (${size}➡️${size + 1})`, id);
      return id;
    },
    untrack: (id) => {
      const size = activeRequests.size;
      // Already released (e.g. complete/error followed by teardown).
      if (!activeRequests.delete(id)) return;
      logger.debug(`untrackActiveRequest (${size}➡️${size - 1})`, id);
      handleDraining();
    }
  };

  // The server signals an upcoming shutdown through a pong payload.
  client.on("pong", (received, payload) => {
    if (
      received &&
      payload?.type === "server_control" &&
      payload?.action === "draining"
    ) {
      handleDraining(true);
    }
  });

  drainControllers.set(client, controller);
  return controller;
}

/**
 * Converts DOM-style socket events passed to a sink's `error` into `Error`s.
 * Close events become `Error("Connection closed")` and other events become
 * `TypeError("Failed to fetch")`. Non-events pass through unchanged.
 *
 * Both globals are checked before use because some runtimes (e.g. certain
 * Node.js versions) expose `WebSocket`/`Event` but not `CloseEvent`, and a
 * bare reference to a missing global throws a `ReferenceError` here. When
 * `CloseEvent` is unavailable, close events are recognised by their
 * `code`/`reason` fields.
 */
function normalizeSocketError(error: Error): Error {
  if (typeof Event === "undefined" || !(error instanceof Event)) return error;
  const isCloseEvent =
    (typeof CloseEvent !== "undefined" && error instanceof CloseEvent) ||
    ("code" in error && "reason" in error);
  return isCloseEvent
    ? new Error("Connection closed")
    : new TypeError("Failed to fetch");
}

/**
 * Builds a Relay fetch/subscribe function that sends operations over the
 * given graphql-ws `client`.
 *
 * - `onRequest` is called with a `RelayNetworkLayerRequest` each time the
 *   returned function is invoked.
 * - `onResponse` is called with every payload before it is forwarded.
 * - `extensions` (a value or a thunk) is merged into the request extensions,
 *   together with `persistedQuery: operation.id` when the operation has an id.
 *
 * Non-subscription requests are tracked so that a server "draining" signal
 * delays `client.dispose()` until they finish. Subscriptions are not
 * tracked, so they never delay a drain.
 */
export async function generateGraphqlWsHandler(
  client: Client,
  { onRequest, onResponse, extensions: extensionsThunk }: ProtocolOptions
) {
  const drain = getDrainController(client);

  return (
    operation: RequestParameters,
    variables: Variables,
    cacheConfig: CacheConfig,
    uploadables?: UploadableMap | null
  ) => {
    const isSubscription = operation.operationKind === "subscription";
    onRequest?.(
      new RelayNetworkLayerRequest(
        operation as any,
        variables,
        cacheConfig as any,
        uploadables ?? null
      )
    );
    return Observable.create<any>((sink) => {
      const extensions =
        typeof extensionsThunk === "function"
          ? extensionsThunk()
          : extensionsThunk;
      const operationName = operation.name;
      const query = operation.text ?? operation.id ?? "";
      // Track once per Observable subscription, just before sending. This
      // way a throwing extensions thunk, or an Observable that is never
      // subscribed, cannot leave a phantom entry that blocks draining forever.
      const requestId = isSubscription
        ? undefined
        : drain.track(operationName);
      const release = () => {
        if (requestId) drain.untrack(requestId);
      };
      const customSink: typeof sink = {
        next: (data) => {
          onResponse?.(data);
          logger.info(operationName, "sink next", data);
          sink.next(data);
        },
        error: (error) => {
          const normalized = normalizeSocketError(error);
          release();
          logger.error(operationName, "sink error", normalized);
          sink.error(normalized);
        },
        complete: () => {
          release();
          logger.info(operationName, "sink complete");
          sink.complete();
        },
        get closed() {
          return sink.closed;
        }
      };
      const unsubscribe = client.subscribe(
        {
          operationName,
          query,
          variables,
          extensions: {
            ...extensions,
            ...(operation.id ? { persistedQuery: operation.id } : {})
          }
        },
        customSink
      );
      return () => {
        unsubscribe();
        // Also release on teardown (e.g. the consumer unsubscribes early).
        // This does nothing if complete/error already released the request.
        release();
      };
    });
  };
}

/**
 * Builds a subscribe handler on a graphql-ws client, plus a small registry of
 * connection listeners.
 *
 * Each registered `Subscriber` is called with `true` whenever the client
 * emits `connected`. It is also called immediately on registration if a
 * `connected` event has been seen with no `closed` event since.
 *
 * @param options - Protocol options, also forwarded to the request handler.
 * @param client - Client (or a promise of one) to use. Defaults to a new
 *   client created by `makeClient(options)`.
 * @returns `handler` (cast to Relay's `SubscribeFunction`) and the
 *   `registerSubscriber` / `deregisterSubscriber` functions.
 */
export async function generateGraphqlWsSubscriber(
  options: ProtocolOptions,
  client: Client | Promise<Client> = makeClient(options)
) {
  const subscribers: Subscriber[] = [];
  const clientInstance = await client;
  let activeConnection = false;

  clientInstance.on("connected", () => {
    activeConnection = true;
    logger.info("socket opened");
    // Iterate over a snapshot. `deregisterSubscriber` splices the array in
    // place (lodash `remove`), so a callback that deregisters itself would
    // otherwise make a live loop skip the next subscriber. Entries removed by
    // an earlier callback in this same pass are skipped.
    for (const subscriber of [...subscribers]) {
      if (subscribers.includes(subscriber)) subscriber(true);
    }
  });

  clientInstance.on("closed", () => {
    activeConnection = false;
  });

  const handler = (await generateGraphqlWsHandler(
    clientInstance,
    options
  )) as SubscribeFunction;

  // Register a subscriber; notify it right away if already connected.
  const registerSubscriber = (subscriber: Subscriber) => {
    subscribers.push(subscriber);
    if (activeConnection) subscriber(true);
  };

  // Deregister every registration of `subscriber`.
  const deregisterSubscriber = (subscriber: Subscriber) => {
    remove(subscribers, (registered) => registered === subscriber);
  };

  return {
    handler,
    registerSubscriber,
    deregisterSubscriber
  };
}

/**
 * Builds a Relay `Network` backed by a single graphql-ws client.
 *
 * Queries and mutations go over the socket, except operations carrying
 * uploadables, which go through an HTTP `RelayNetworkLayer` built from
 * `generateMiddleware(options)` with `noThrow: true`. Subscriptions use the
 * handler from `generateGraphqlWsSubscriber`, which shares the same client.
 *
 * @returns The network, a `dispose` function that disposes the shared client,
 *   and the connection-listener register/deregister functions.
 */
export async function generateGraphqlWsNetwork(options: ProtocolOptions) {
  const client = await makeClient(options);
  const fetchHandler = await generateGraphqlWsHandler(client, options);
  const httpNetwork = new RelayNetworkLayer(await generateMiddleware(options), {
    noThrow: true
  });
  const switchHandler = generateSwitchHandler(fetchHandler, httpNetwork);
  const {
    handler: subscriptionHandler,
    registerSubscriber,
    deregisterSubscriber
  } = await generateGraphqlWsSubscriber(options, client);
  const network = Network.create(switchHandler, subscriptionHandler);
  return {
    network,
    // Bound wrapper: callers usually invoke `dispose()` detached from the
    // client, so do not rely on graphql-ws never using `this` internally.
    dispose: () => client.dispose(),
    registerSubscriber,
    deregisterSubscriber
  };
}
