import {useMutation, useQueryClient} from "@tanstack/react-query";
import {useQuery} from "@tanstack/react-query";
import {QueryResult} from "./types";
import {DefaultService, AIModel} from "./generated_client";
import {useRequireToken} from "./auth";

/** First segment shared by every query key built in this module. */
const BASE_QUERY_KEY = "models";

/**
 * Builds the query key used to cache a single model.
 *
 * @param modelId - Value placed in the second key segment; `undefined` is
 *                  passed through unchanged.
 * @returns A two-element array: the base key followed by `modelId`.
 */
function getQueryKey(modelId:number|undefined) {
    const detailKey = [BASE_QUERY_KEY, modelId];
    return detailKey;
}

/**
 * Mutation hook that saves an AI model and keeps the query cache in sync.
 *
 * After a successful save the returned model is written into the cache entry
 * for its detail query, and the model list query is invalidated so it is
 * refetched with the new or changed entry.
 *
 * @param modelId - `null` creates a new model; a number updates the model
 *                  with that id.
 * @returns The `useMutation` result. Its mutate functions take the `AIModel`
 *          payload; the mutation data is the model returned by the service.
 */
export function useModelMutation(modelId:number|null) {
    // NOTE(review): presumably enforces that an auth token is present — see ./auth.
    useRequireToken();
    const queryClient = useQueryClient();

    return useMutation({
        // Network call only; cache maintenance lives in onSuccess so a cache
        // problem can never be mistaken for (or retried as) a failed save.
        mutationFn: async (model:AIModel): Promise<AIModel> => {
            if (modelId === null) {
                return await DefaultService.createAiModelModelsPost(model);
            }
            return await DefaultService.updateAiModelModelsAiModelIdPost(modelId, model);
        },
        onSuccess: async (saved) => {
            const detailKey = getQueryKey(saved.id);

            // The service response is the authoritative state of the model, so
            // seed the detail cache directly instead of invalidating it first
            // (which would trigger a refetch whose result is then overwritten).
            queryClient.setQueryData(detailKey, saved);

            // The list query is keyed by the first segment of every detail
            // key. Invalidation matches by key prefix, so invalidating the
            // detail key never reaches the list; target the list explicitly.
            // `exact` keeps the freshly written detail entry untouched.
            await queryClient.invalidateQueries({
                queryKey: detailKey.slice(0, 1),
                exact: true,
            });
        },
    });
}

/**
 * Query hook that loads a single AI model by id.
 *
 * @param modelId - Id passed to the service call and used in the query key.
 * @returns The `useQuery` result for the requested model.
 */
export function useModelQuery(modelId:number):QueryResult<AIModel> {
    // NOTE(review): presumably enforces that an auth token is present — see ./auth.
    useRequireToken();

    // Fetcher closes over modelId so the request always matches the key below.
    const fetchModel = async () => DefaultService.getAiModelModelsAiModelIdGet(modelId);

    return useQuery({queryKey: getQueryKey(modelId), queryFn: fetchModel});
}

/**
 * Query hook that loads the list of AI models.
 *
 * The result is cached under the bare base key (no id segment).
 *
 * @returns The `useQuery` result holding the array of models.
 */
export function useModelListQuery():QueryResult<AIModel[]> {
    // NOTE(review): presumably enforces that an auth token is present — see ./auth.
    useRequireToken();

    const fetchModels = async () => DefaultService.listAiModelsModelsGet();

    return useQuery({queryKey: [BASE_QUERY_KEY], queryFn: fetchModels});
}

