#include "util.h"
#include <ntstrsafe.h>
#include "init.h"
#include "except.h"

/* Native-path prefix "\DosDevices\" (matched by StartsWithDosDevicesPrefix and IsMountPointDriveLetter). */
const UNICODE_STRING g_DosDevicesPrefix = RTL_CONSTANT_STRING(L"\\DosDevices\\");
/* Prefix "\??\Volume{" that begins a volume-GUID style name (matched by StartsWithVolumeGuidPrefix). */
const UNICODE_STRING g_VolumeGuidPrefix = RTL_CONSTANT_STRING(L"\\??\\Volume{");
/* Object-manager "\??\" prefix. NOTE(review): not referenced in this file section; presumably used elsewhere. */
const UNICODE_STRING g_ObjectManagerPrefix = RTL_CONSTANT_STRING(L"\\??\\");

void
PTFSFcbLock(
	IN PPTFSFCB	pFcb,
    IN BOOLEAN  bReadOnly
)
/*
 * Acquire the FCB's main resource (AdvancedFCBHeader.Resource) inside a
 * critical region: shared when bReadOnly is TRUE, exclusive otherwise.
 * Both paths wait until the resource is granted. Release with PTFSFcbUnlock.
 */
{
    PERESOURCE pResource = pFcb->AdvancedFCBHeader.Resource;

    if (!bReadOnly)
    {
        ExEnterCriticalRegionAndAcquireResourceExclusive(pResource);
        return;
    }

    KeEnterCriticalRegion();
    ExAcquireResourceSharedLite(pResource, TRUE);
}

void
PTFSFcbUnlock(
	IN PPTFSFCB	pFcb
)
/*
 * Release the FCB's main resource and leave the critical region entered by
 * PTFSFcbLock.
 */
{
    PERESOURCE pResource = pFcb->AdvancedFCBHeader.Resource;

    ExReleaseResourceAndLeaveCriticalRegion(pResource);
}

void
PTFSVcbLock(
    IN PPTFSVCB pVcb,
    IN BOOLEAN bReadOnly
)
/*
 * Acquire the VCB's embedded Resource inside a critical region: shared when
 * bReadOnly is TRUE, exclusive otherwise (waiting in both cases).
 * Release with PTFSVcbUnlock.
 */
{
    PERESOURCE pResource = &pVcb->Resource;

    if (!bReadOnly)
    {
        ExEnterCriticalRegionAndAcquireResourceExclusive(pResource);
        return;
    }

    KeEnterCriticalRegion();
    ExAcquireResourceSharedLite(pResource, TRUE);
}

void
PTFSVcbUnlock(
    IN PPTFSVCB pVcb
) 
/*
 * Release the VCB's Resource and leave the critical region entered by
 * PTFSVcbLock.
 */
{
    PERESOURCE pResource = &pVcb->Resource;

    ExReleaseResourceAndLeaveCriticalRegion(pResource);
}

void
PTFSPagingIoLock(
    IN PPTFSFCB pFcb,
    IN BOOLEAN bReadOnly
)
/*
 * Acquire the FCB's PagingIoResource inside a critical region: shared when
 * bReadOnly is TRUE, exclusive otherwise (waiting in both cases).
 * Release with PTFSPagingIoUnlock.
 */
{
    PERESOURCE pResource = &pFcb->PagingIoResource;

    if (!bReadOnly)
    {
        ExEnterCriticalRegionAndAcquireResourceExclusive(pResource);
        return;
    }

    KeEnterCriticalRegion();
    ExAcquireResourceSharedLite(pResource, TRUE);
}

void
PTFSPagingIoUnlock(
    IN PPTFSFCB pFcb
)
/*
 * Release the FCB's PagingIoResource and leave the critical region entered
 * by PTFSPagingIoLock.
 */
{
    PERESOURCE pResource = &pFcb->PagingIoResource;

    ExReleaseResourceAndLeaveCriticalRegion(pResource);
}


VOID 
PTFSFreeUnicodeString(
	IN PUNICODE_STRING pUnicodeString
) 
/*
 * Free a heap-allocated UNICODE_STRING: first its Buffer, then the structure
 * itself. A NULL argument is ignored.
 */
{
    if (pUnicodeString == NULL)
        return;

    PTFSFree(pUnicodeString->Buffer);
    PTFSFree(pUnicodeString);
}

VOID 
PTFSCompleteIrpRequest(
    IN PIRP pIrp,
    IN NTSTATUS status,
    IN ULONG_PTR Info
) 
/*
 * Complete pIrp with the given status and IoStatus.Information value.
 *
 * - A NULL pIrp is ignored.
 * - A status of -1 is mapped to STATUS_INVALID_PARAMETER.
 * - STATUS_PENDING leaves the IRP untouched; it is not completed here.
 */
{
    NTSTATUS finalStatus;

    if (pIrp == NULL)
        return;

    finalStatus = (status == -1) ? STATUS_INVALID_PARAMETER : status;
    if (finalStatus == STATUS_PENDING)
        return;

    pIrp->IoStatus.Information = Info;
    pIrp->IoStatus.Status = finalStatus;
    IoCompleteRequest(pIrp, IO_NO_INCREMENT);
}

VOID 
PTFSCompleteDispatchRoutine(
    IN PIRP pIrp,
    IN NTSTATUS status
) 
/*
 * Complete pIrp with the given status, leaving IoStatus.Information as the
 * caller set it.
 *
 * - A NULL pIrp is ignored.
 * - A status of -1 is mapped to STATUS_INVALID_PARAMETER.
 * - STATUS_PENDING leaves the IRP untouched; it is not completed here.
 */
{
    NTSTATUS finalStatus;

    if (pIrp == NULL)
        return;

    finalStatus = (status == -1) ? STATUS_INVALID_PARAMETER : status;
    if (finalStatus == STATUS_PENDING)
        return;

    pIrp->IoStatus.Status = finalStatus;
    IoCompleteRequest(pIrp, IO_NO_INCREMENT);
}

PVOID 
PTFSAllocateZero(
    IN SIZE_T size) 
/*
 * Allocate `size` bytes with PTFSAllocate and zero them.
 * Returns NULL if the allocation fails.
 */
{
    PVOID pMemory = PTFSAllocate(size);

    if (pMemory != NULL)
    {
        RtlZeroMemory(pMemory, size);
    }

    return pMemory;
}

BOOLEAN 
StartsWith (
    IN const UNICODE_STRING* Str,
    IN const UNICODE_STRING* Prefix) 
/*
 * Case-sensitive test whether Str begins with Prefix.
 *
 * - A NULL or empty Prefix matches anything (returns TRUE).
 * - A NULL Str, or a Prefix longer than Str, returns FALSE.
 *
 * The comparison is bounded by Prefix->Length. UNICODE_STRING buffers are not
 * required to be NUL-terminated, so the buffer is never scanned for a
 * terminator. The earlier length check guarantees Str->Buffer holds at least
 * that many characters.
 */
{
    USHORT usPrefixChars = 0;
    USHORT i = 0;

    if (Prefix == NULL || Prefix->Length == 0)
        return TRUE;

    if (Str == NULL || Prefix->Length > Str->Length)
        return FALSE;

    usPrefixChars = (USHORT)(Prefix->Length / sizeof(WCHAR));
    for (i = 0; i < usPrefixChars; ++i)
    {
        if (Prefix->Buffer[i] != Str->Buffer[i])
            return FALSE;
    }

    return TRUE;
}

BOOLEAN 
StartsWithDosDevicesPrefix(
    IN const UNICODE_STRING* Str) 
/* Return TRUE when Str begins with "\DosDevices\" (case-sensitive, via StartsWith). */
{
    const UNICODE_STRING* pPrefix = &g_DosDevicesPrefix;

    return StartsWith(Str, pPrefix);
}

BOOLEAN 
StartsWithVolumeGuidPrefix(
    IN const UNICODE_STRING* Str) 
/* Return TRUE when Str begins with "\??\Volume{" (case-sensitive, via StartsWith). */
{
    const UNICODE_STRING* pPrefix = &g_VolumeGuidPrefix;

    return StartsWith(Str, pPrefix);
}

BOOLEAN 
IsMountPointDriveLetter(
    IN const UNICODE_STRING* punstrMountPoint
) 
/*
 * Return TRUE when punstrMountPoint has the form "\DosDevices\X:", optionally
 * followed by a single terminating NUL that is counted in Length.
 *
 * The drive-letter character itself is not validated; only the ':' right
 * after it and the overall length are checked. punstrMountPoint must not be
 * NULL.
 */
{
    const size_t colonIndex = g_DosDevicesPrefix.Length / sizeof(WCHAR) + 1;
    const size_t driveLetterBytes = g_DosDevicesPrefix.Length + 2 * sizeof(WCHAR);
    const USHORT usLength = punstrMountPoint->Length;
    BOOLEAN bShapeOk = FALSE;

    if (usLength == driveLetterBytes)
    {
        /* "\DosDevices\X:" with no terminator counted. */
        bShapeOk = TRUE;
    }
    else if (usLength == driveLetterBytes + sizeof(WCHAR))
    {
        /* Same, but the trailing NUL is included in Length. */
        bShapeOk = (BOOLEAN)(punstrMountPoint->Buffer[colonIndex + 1] == L'\0');
    }

    if (!bShapeOk)
        return FALSE;

    if (!StartsWithDosDevicesPrefix(punstrMountPoint))
        return FALSE;

    return (BOOLEAN)(punstrMountPoint->Buffer[colonIndex] == L':');
}

UNICODE_STRING 
PTFSWrapUnicodeString(
    IN WCHAR* Buffer,
    IN USHORT Length
)
/*
 * Build a UNICODE_STRING view over an existing buffer without copying.
 * Both Length and MaximumLength are set to the given byte count. The caller
 * keeps ownership of Buffer.
 */
{
    UNICODE_STRING unstrView;

    unstrView.Length = Length;
    unstrView.MaximumLength = Length;
    unstrView.Buffer = Buffer;

    return unstrView;
}

ULONG
PointerAlignSize(
    IN ULONG ulSizeInBytes) 
/*
 * Round ulSizeInBytes up to the next multiple of sizeof(void*).
 * Values that are already aligned are returned unchanged. The result is
 * truncated to ULONG, so sizes within one pointer of ULONG's maximum wrap.
 */
{
    ULONG ulMisalignment = (ULONG)(ulSizeInBytes & (sizeof(void*) - 1));

    return (ulMisalignment == 0)
        ? ulSizeInBytes
        : (ULONG)(ulSizeInBytes + (sizeof(void*) - ulMisalignment));
}

LONG 
SearchStringWChar(
    IN PWCHAR pwszString, 
    IN ULONG Length,
    IN WCHAR wChar) 
/*
 * Find the first occurrence of wChar in pwszString.
 *
 * Length is the buffer size in BYTES; Length / sizeof(WCHAR) characters are
 * examined. Returns the zero-based character index of the first match, or -1
 * if wChar does not occur.
 */
{
    const ULONG ulCharCount = (ULONG)(Length / sizeof(WCHAR));
    ULONG ulIndex = 0;

    while (ulIndex < ulCharCount)
    {
        if (pwszString[ulIndex] == wChar)
            return (LONG)ulIndex;
        ++ulIndex;
    }

    return -1;
}

LONG 
SearchUnicodeStringWChar(
    IN PUNICODE_STRING pUnicodeString,
    IN WCHAR wChar)
/*
 * Find the first occurrence of wChar within the valid (Length) portion of a
 * UNICODE_STRING. Returns the character index, or -1 if not found.
 */
{
    PWCHAR pwszBuffer = pUnicodeString->Buffer;
    ULONG ulByteLength = pUnicodeString->Length;

    return SearchStringWChar(pwszBuffer, ulByteLength, wChar);
}

VOID 
PTFSSetFlag(
    IN PULONG Flags, 
    IN ULONG FlagBit) 
/*
 * Atomically OR FlagBit into *Flags.
 *
 * No post-condition is asserted on *Flags. Re-reading it non-atomically after
 * the interlocked operation can observe concurrent changes by other threads,
 * which is exactly the scenario the interlocked operation exists for. Such an
 * assertion could fire spuriously on checked builds.
 */
{
    (VOID)InterlockedOr((PLONG)Flags, (LONG)FlagBit);
}

VOID PTFSClearFlag(
    IN PULONG Flags, 
    IN ULONG FlagBit) 
/*
 * Atomically clear the FlagBit bits in *Flags.
 *
 * No post-condition is asserted on *Flags. A non-atomic re-read after the
 * interlocked AND can observe concurrent updates from other threads and make
 * such an assertion fire spuriously on checked builds.
 */
{
    (VOID)InterlockedAnd((PLONG)Flags, (LONG)~FlagBit);
}

PUNICODE_STRING
PTFSAllocateUnicodeString(
    IN PCWSTR String) 
/*
 * Allocate a new UNICODE_STRING holding a private copy of the NUL-terminated
 * String, terminator included.
 *
 * The structure and its buffer come from PTFSAllocate; release both with
 * PTFSFreeUnicodeString. Returns NULL if either allocation fails or if
 * RtlUnicodeStringInitEx rejects the string (e.g. it is too long). String
 * must not be NULL.
 */
{
    PUNICODE_STRING pUnicode = NULL;
    PWSTR pwszCopy = NULL;
    ULONG ulBytes = 0;

    pUnicode = PTFSAllocate(sizeof(UNICODE_STRING));
    if (pUnicode != NULL)
    {
        ulBytes = (ULONG)((ULONG)(wcslen(String) + 1) * sizeof(WCHAR));
        pwszCopy = PTFSAllocate(ulBytes);
        if (pwszCopy != NULL)
        {
            RtlCopyMemory(pwszCopy, String, ulBytes);
            if (NT_SUCCESS(RtlUnicodeStringInitEx(pUnicode, pwszCopy, 0)))
                return pUnicode;

            KdPrint(("[PTFS]::PTFSAllocateUnicodeString invalid string size received.\n"));
            PTFSFree(pwszCopy);
        }
        PTFSFree(pUnicode);
    }

    return NULL;
}

PUNICODE_STRING 
PTFSAllocDuplicateString(
    IN const PUNICODE_STRING Src
) 
/*
 * Allocate a zeroed UNICODE_STRING and fill it with a deep copy of Src (see
 * PTFSDuplicateUnicodeString). Returns NULL if either allocation fails.
 */
{
    PUNICODE_STRING pCopy = PTFSAllocateZero(sizeof(UNICODE_STRING));

    if (pCopy != NULL && !PTFSDuplicateUnicodeString(pCopy, Src))
    {
        PTFSFree(pCopy);
        pCopy = NULL;
    }

    return pCopy;
}

BOOLEAN
PTFSDuplicateUnicodeString(
	OUT PUNICODE_STRING Dest,
	IN const PUNICODE_STRING Src
)
/*
 * Deep-copy Src into Dest.
 *
 * Any existing Dest->Buffer is freed with PTFSFree first, so Dest must be
 * zero-initialised or own a PTFSAllocate'd buffer. A new buffer of
 * Src->MaximumLength bytes is allocated, and the whole MaximumLength range is
 * copied (not just Length). On allocation failure, Dest gets Length and
 * MaximumLength of 0 and a NULL Buffer, and FALSE is returned.
 */
{
    PVOID pNewBuffer = NULL;

    if (Dest->Buffer != NULL)
    {
        PTFSFree(Dest->Buffer);
    }

    pNewBuffer = PTFSAllocate(Src->MaximumLength);
    Dest->Buffer = pNewBuffer;
    if (pNewBuffer == NULL)
    {
        Dest->Length = 0;
        Dest->MaximumLength = 0;
        return FALSE;
    }

    Dest->MaximumLength = Src->MaximumLength;
    Dest->Length = Src->Length;
    RtlCopyMemory(Dest->Buffer, Src->Buffer, Dest->MaximumLength);

    return TRUE;
}

ULONG 
SearchWcharinUnicodeStringWithUlong(
    IN PUNICODE_STRING inputPUnicodeString, 
    IN WCHAR targetWchar,
    IN ULONG offsetPosition,
    IN int isIgnoreTargetWchar)
/*
 * Search backwards for targetWchar.
 *
 * The scan covers character indices offsetPosition-1 down to 0.
 * offsetPosition is a CHARACTER index (an exclusive upper bound). It is
 * clamped to the number of valid characters, Length / sizeof(WCHAR). The old
 * clamp compared it against MaximumLength and reset it to Length; both are
 * byte counts, so Buffer could be read past its valid range.
 *
 * Returns the index of the match, or the index + 1 when isIgnoreTargetWchar
 * is 1 (i.e. the position just after the match). Returns 0 when the
 * character is not found; note that a match at index 0 with
 * isIgnoreTargetWchar != 1 also returns 0.
 */
{
    ULONG ulCharCount = 0;

    ASSERT(inputPUnicodeString != NULL);

    ulCharCount = (ULONG)(inputPUnicodeString->Length / sizeof(WCHAR));
    if (offsetPosition > ulCharCount)
        offsetPosition = ulCharCount;

    while (offsetPosition > 0) 
    {
        offsetPosition--;
        if (inputPUnicodeString->Buffer[offsetPosition] == targetWchar) 
        {
            if (isIgnoreTargetWchar == 1)
                offsetPosition++;
            break;
        }
    }

    return offsetPosition;
}

PUNICODE_STRING 
ChangePrefix(
    IN const UNICODE_STRING* Str,
    IN const UNICODE_STRING* Prefix, 
    IN BOOLEAN HasPrefix,
    IN const UNICODE_STRING* NewPrefix) 
/*
 * Build a new string NewPrefix + (Str with Prefix removed).
 *
 * If Str starts with Prefix, the prefix is stripped before NewPrefix is
 * prepended. Otherwise Str is used whole, unless HasPrefix is TRUE, in which
 * case the missing prefix is an error.
 *
 * Returns a newly allocated UNICODE_STRING (struct and buffer from
 * PTFSAllocateZero; free with PTFSFreeUnicodeString), or NULL if:
 *   - the prefix is missing while HasPrefix is TRUE,
 *   - the result would exceed UNICODE_STRING_MAX_BYTES,
 *   - an allocation fails, or
 *   - the ntstrsafe copy/concatenation fails.
 */
{
    PUNICODE_STRING newStr = NULL;
    BOOLEAN startWithPrefix = FALSE;
    USHORT prefixLength = 0;
    ULONG length = 0;
    UNICODE_STRING strAfterPrefix;
    NTSTATUS status;

    startWithPrefix = StartsWith(Str, Prefix);
    if (!startWithPrefix && HasPrefix) 
    {
        KdPrint(("[PTFS]::ChangePrefix %wZ do not start with Prefix %wZ\n", Str, Prefix));
        return NULL;
    }

    /* StartsWith() also reports TRUE for a NULL prefix, so guard the dereference. */
    if (startWithPrefix && Prefix != NULL)
        prefixLength = Prefix->Length;

    /*
     * Compute in ULONG: a USHORT sum could wrap. The copy/cat below would then
     * truncate, and the truncated string used to be returned silently.
     */
    length = (ULONG)Str->Length - prefixLength + NewPrefix->Length;
    if (length > UNICODE_STRING_MAX_BYTES)
    {
        KdPrint(("[PTFS]::ChangePrefix resulting string is too long\n"));
        return NULL;
    }

    newStr = PTFSAllocateZero(sizeof(UNICODE_STRING));
    if (!newStr)
    {
        KdPrint(("[PTFS]::ChangePrefix Failed to allocate unicode_string\n"));
        return NULL;
    }

    newStr->Length = 0;
    newStr->MaximumLength = (USHORT)length;
    newStr->Buffer = PTFSAllocateZero(length);

    if (!newStr->Buffer) 
    {
        KdPrint(("[PTFS]::ChangePrefix Failed to allocate unicode_string buffer\n"));
        PTFSFree(newStr);
        return NULL;
    }

    status = RtlUnicodeStringCopy(newStr, NewPrefix);
    if (NT_SUCCESS(status))
    {
        strAfterPrefix = PTFSWrapUnicodeString((PWCHAR)((PCHAR)Str->Buffer + prefixLength), (USHORT)(Str->Length - prefixLength));
        status = RtlUnicodeStringCat(newStr, &strAfterPrefix);
    }

    if (!NT_SUCCESS(status))
    {
        KdPrint(("[PTFS]::ChangePrefix Failed to build new string status[0x%x]\n", status));
        PTFSFree(newStr->Buffer);
        PTFSFree(newStr);
        return NULL;
    }

    return newStr;
}

VOID 
RunAsSystem(
    IN PKSTART_ROUTINE StartRoutine, 
    PVOID pStartContext) 
/*
 * Run StartRoutine(pStartContext) on a new system thread and wait for that
 * thread to terminate.
 *
 * If the thread cannot be created, the failure is logged and the function
 * returns without running StartRoutine. If the thread object cannot be
 * referenced, the wait is done on the thread handle instead. Previously a
 * failed ObReferenceObjectByHandle left `thread` NULL, and that NULL was then
 * passed to KeWaitForSingleObject/ObDereferenceObject.
 */
{
    HANDLE handle = NULL;
    PKTHREAD thread = NULL;
    OBJECT_ATTRIBUTES objectAttribs;
    NTSTATUS status;

    InitializeObjectAttributes(&objectAttribs, NULL, OBJ_KERNEL_HANDLE, NULL, NULL);
    status = PsCreateSystemThread(&handle, THREAD_ALL_ACCESS, &objectAttribs, NULL, NULL, StartRoutine, pStartContext);
    if (!NT_SUCCESS(status)) 
    {
        KdPrint(("[PTFS]::RunAsSystem Failed to PsCreateSystemThread status[0x%x]\n", status));
        return;
    }

    status = ObReferenceObjectByHandle(handle, THREAD_ALL_ACCESS, NULL, KernelMode, (PVOID*)&thread, NULL);
    if (!NT_SUCCESS(status))
    {
        KdPrint(("[PTFS]::RunAsSystem Failed to ObReferenceObjectByHandle status[0x%x]\n", status));
        /* Keep the call synchronous: wait through the handle before closing it. */
        ZwWaitForSingleObject(handle, FALSE, NULL);
        ZwClose(handle);
        return;
    }

    /* The object reference keeps the thread alive after the handle is closed. */
    ZwClose(handle);
    KeWaitForSingleObject(thread, Executive, KernelMode, FALSE, NULL);
    ObDereferenceObject(thread);
}

NTSTATUS 
NotifyReportChange0(
    IN PPTFSFCB pFcb,
    IN PUNICODE_STRING FileName,
    IN ULONG FilterMatch,
    IN ULONG Action) 
/*
 * Report a directory-change notification for FileName through
 * FsRtlNotifyFullReportChange on the owning VCB's notify list.
 *
 * Names containing ':' are treated as stream names:
 *   - ADDED/REMOVED/MODIFIED actions become their *_STREAM counterparts.
 *   - Name/size/last-write filters are mapped to the STREAM_NAME/STREAM_SIZE/
 *     STREAM_WRITE filters, and every other filter bit is dropped.
 *
 * The final-component offset passed to FsRtl is the byte offset just past the
 * last '\' (the last character of the name is not considered).
 *
 * Returns STATUS_SUCCESS, or STATUS_OBJECT_NAME_INVALID if the report raised
 * an access violation.
 */
{
    const ULONG ulStreamFilterMask = FILE_NOTIFY_CHANGE_STREAM_NAME | FILE_NOTIFY_CHANGE_STREAM_SIZE | FILE_NOTIFY_CHANGE_STREAM_WRITE;
    USHORT usLastCharIndex;
    USHORT usNameOffsetChars;
    USHORT usNameOffsetBytes;

    KdPrint(("[PTFS]::NotifyReportChange0 Start\n"));

    ASSERT(pFcb != NULL);
    ASSERT(FileName != NULL);

    if (SearchUnicodeStringWChar(FileName, L':') != -1) 
    {
        if (Action == FILE_ACTION_ADDED)
            Action = FILE_ACTION_ADDED_STREAM;
        else if (Action == FILE_ACTION_REMOVED)
            Action = FILE_ACTION_REMOVED_STREAM;
        else if (Action == FILE_ACTION_MODIFIED)
            Action = FILE_ACTION_MODIFIED_STREAM;

        if ((FilterMatch & (FILE_NOTIFY_CHANGE_DIR_NAME | FILE_NOTIFY_CHANGE_FILE_NAME)) != 0)
            FilterMatch |= FILE_NOTIFY_CHANGE_STREAM_NAME;
        if ((FilterMatch & FILE_NOTIFY_CHANGE_SIZE) != 0)
            FilterMatch |= FILE_NOTIFY_CHANGE_STREAM_SIZE;
        if ((FilterMatch & FILE_NOTIFY_CHANGE_LAST_WRITE) != 0)
            FilterMatch |= FILE_NOTIFY_CHANGE_STREAM_WRITE;

        /* Keep only the stream-level filter bits. */
        FilterMatch &= ulStreamFilterMask;
    }

    /* Character index of the last character; wraps to 0xFFFF for an empty name (clamped by the search). */
    usLastCharIndex = (USHORT)(FileName->Length / sizeof(WCHAR) - 1);
    usNameOffsetChars = (USHORT)SearchWcharinUnicodeStringWithUlong(FileName, L'\\', (ULONG)usLastCharIndex, 1);
    usNameOffsetBytes = (USHORT)(usNameOffsetChars * sizeof(WCHAR));

    __try 
    {
        FsRtlNotifyFullReportChange(pFcb->pVcb->NotifySync, &pFcb->pVcb->DirNotifyList,
            (PSTRING)FileName, usNameOffsetBytes,
            NULL,
            NULL,
            FilterMatch,
            Action,
            NULL);
    }
    __except (GetExceptionCode() == STATUS_ACCESS_VIOLATION ? EXCEPTION_EXECUTE_HANDLER : EXCEPTION_CONTINUE_SEARCH) 
    {
        KdPrint(("[PTFS]::NotifyReportChange0 to Failed FsRtlNotifyFullReportChange Exception\n"));
        return STATUS_OBJECT_NAME_INVALID;
    }

    KdPrint(("[PTFS]::NotifyReportChange0 End\n"));
    return STATUS_SUCCESS;
}

NTSTATUS 
NotifyReportChange(
    IN PPTFSFCB pFcb, 
    IN ULONG FilterMatch,
    IN ULONG Action
) 
/* Report a change notification for the FCB's own file name (pFcb->unstrFileName). */
{
    PUNICODE_STRING pFileName = NULL;

    ASSERT(pFcb != NULL);

    pFileName = &pFcb->unstrFileName;
    return NotifyReportChange0(pFcb, pFileName, FilterMatch, Action);
}

NTSTATUS
PTFSCheckOplock(
	IN PPTFSFCB pFcb,
	IN PIRP pIrp,
	IN OPTIONAL PVOID Context,
	IN OPTIONAL POPLOCK_WAIT_COMPLETE_ROUTINE CompletionRoutine,
	IN OPTIONAL POPLOCK_FS_PREPOST_IRP PostIrpRoutine
)
/*
 * Run FsRtlCheckOplock on the FCB's oplock and return its status.
 *
 * Returns STATUS_SUCCESS without checking when the FCB has no VCB, the VCB
 * has no DCB, or oplocks are disabled on the DCB.
 */
{
	ASSERT(pFcb->pVcb != NULL);
	ASSERT(pFcb->pVcb->pDcb != NULL);

	if (pFcb->pVcb == NULL || pFcb->pVcb->pDcb == NULL || pFcb->pVcb->pDcb->bOplocksDisabled)
		return STATUS_SUCCESS;

	return FsRtlCheckOplock(PTFSGetFcbOplock(pFcb), pIrp, Context, CompletionRoutine, PostIrpRoutine);
}

BOOLEAN
PTFSCheckCCB(
    IN PPTFSDCB pDcb, 
    IN OPTIONAL PPTFSCCB pCcb
) 
/*
 * Validate that pCcb is usable with pDcb.
 *
 * Returns FALSE (with a debug trace) if any of the following hold:
 *   - pDcb is not a DCB,
 *   - pCcb is NULL,
 *   - pCcb's mount id differs from the DCB's,
 *   - the DCB has no VCB, or
 *   - the VCB has an unmount pending.
 * Otherwise returns TRUE.
 */
{
    PPTFSVCB pVcb = NULL;

    ASSERT(pDcb != NULL);

    if (GETIDENTIFIERTYPE(pDcb) != DCB)
    {
        KdPrint(("[PTFS]::PTFSCheckCCB Invalid identifier type\n"));
        return FALSE;
    }

    if (pCcb == NULL)
    {
        KdPrint(("[PTFS]::PTFSCheckCCB null Ccb\n"));
        return FALSE;
    }

    if (pCcb->ulMountId != pDcb->ulMountId)
    {
        KdPrint(("[PTFS]::PTFSCheckCCB Invalid MountID ccb MountID[%d], Dcb MountID[%d]\n", pCcb->ulMountId, pDcb->ulMountId));
        return FALSE;
    }

    pVcb = pDcb->pVcb;
    if (pVcb == NULL)
    {
        KdPrint(("[PTFS]::PTFSCheckCCB null Vcb\n"));
        return FALSE;
    }

    if (IsUnmountPendingVcb(pVcb))
    {
        KdPrint(("[PTFS]::PTFSCheckCCB Pending Unmount\n"));
        return FALSE;
    }

    return TRUE;
}

NTSTATUS
AllocateMdl(
    IN PIRP pIrp, 
    IN ULONG Length
) 
/*
 * Ensure pIrp has an MDL describing and locking pIrp->UserBuffer for
 * IoWriteAccess.
 *
 * Does nothing (STATUS_SUCCESS) if pIrp->MdlAddress is already set.
 * Otherwise allocates an MDL of Length bytes for UserBuffer and probes/locks
 * it in the IRP's requestor mode. On any failure the MDL is released,
 * pIrp->MdlAddress is reset to NULL, and STATUS_INSUFFICIENT_RESOURCES is
 * returned.
 */
{
    PMDL pMdl = NULL;

    if (pIrp->MdlAddress != NULL)
        return STATUS_SUCCESS;

    pMdl = IoAllocateMdl(pIrp->UserBuffer, Length, FALSE, FALSE, pIrp);
    pIrp->MdlAddress = pMdl;
    if (pMdl == NULL)
    {
        KdPrint(("[PTFS]::AllocateMdl IoAllocateMdl returned NULL\n"));
        return STATUS_INSUFFICIENT_RESOURCES;
    }

    __try
    {
        MmProbeAndLockPages(pMdl, pIrp->RequestorMode, IoWriteAccess);
    }
    __except (EXCEPTION_EXECUTE_HANDLER)
    {
        KdPrint(("[PTFS]::AllocateMdl MmProveAndLockPages error\n"));
        IoFreeMdl(pMdl);
        pIrp->MdlAddress = NULL;
        return STATUS_INSUFFICIENT_RESOURCES;
    }

    return STATUS_SUCCESS;
}

VOID FreeMdl(
    IN PIRP pIrp
) 
/*
 * Unlock and free pIrp->MdlAddress (if any) and clear the field.
 * Assumes the MDL's pages were locked, e.g. by AllocateMdl.
 */
{
    PMDL pMdl = pIrp->MdlAddress;

    if (pMdl == NULL)
        return;

    MmUnlockPages(pMdl);
    IoFreeMdl(pMdl);
    pIrp->MdlAddress = NULL;
}

PVOID
PTFSMapUserBuffer (
    IN OUT PIRP pIrp
)
/*
 * Return a usable address for the IRP's data buffer.
 *
 * If the IRP has an MDL, returns its system-space mapping; that mapping can
 * be NULL if MmGetSystemAddressForMdlSafe fails. Otherwise returns the raw
 * pIrp->UserBuffer.
 */
{
    PMDL pMdl = pIrp->MdlAddress;

    if (pMdl != NULL)
        return MmGetSystemAddressForMdlSafe(pMdl, NormalPagePriority);

    return pIrp->UserBuffer;
}
