#include "irp_buffer_helper.h"
#include "except.h"

/*
 * GetProvidedInputSize
 *
 * Returns the input buffer length recorded in the current stack location
 * of pIrp.
 *
 * pIrp - the request being inspected.
 *
 * Returns Parameters.DeviceIoControl.InputBufferLength for
 * IRP_MJ_DEVICE_CONTROL, Parameters.FileSystemControl.InputBufferLength
 * for IRP_MJ_FILE_SYSTEM_CONTROL, and 0 for every other major function.
 */
ULONG
GetProvidedInputSize(
	IN PIRP pIrp
)
{
	PIO_STACK_LOCATION pStack = IoGetCurrentIrpStackLocation(pIrp);

	if (pStack->MajorFunction == IRP_MJ_DEVICE_CONTROL)
		return pStack->Parameters.DeviceIoControl.InputBufferLength;

	if (pStack->MajorFunction == IRP_MJ_FILE_SYSTEM_CONTROL)
		return pStack->Parameters.FileSystemControl.InputBufferLength;

	/* No input length is defined here for any other major function. */
	return 0;
}

/*
 * GetProvidedOutputSize
 *
 * Returns the output length recorded in the current stack location of
 * pIrp, read from the Parameters member that matches the major function.
 *
 * pIrp - the request being inspected.
 *
 * Returns the matching length for IRP_MJ_DEVICE_CONTROL,
 * IRP_MJ_DIRECTORY_CONTROL, IRP_MJ_FILE_SYSTEM_CONTROL,
 * IRP_MJ_QUERY_INFORMATION and IRP_MJ_QUERY_SECURITY; 0 otherwise.
 */
ULONG
GetProvidedOutputSize(
	IN PIRP pIrp
)
{
	PIO_STACK_LOCATION pStack = IoGetCurrentIrpStackLocation(pIrp);
	ULONG outputLength = 0;

	switch (pStack->MajorFunction)
	{
	case IRP_MJ_DEVICE_CONTROL:
		outputLength = pStack->Parameters.DeviceIoControl.OutputBufferLength;
		break;
	case IRP_MJ_DIRECTORY_CONTROL:
		outputLength = pStack->Parameters.QueryDirectory.Length;
		break;
	case IRP_MJ_FILE_SYSTEM_CONTROL:
		outputLength = pStack->Parameters.FileSystemControl.OutputBufferLength;
		break;
	case IRP_MJ_QUERY_INFORMATION:
		outputLength = pStack->Parameters.QueryFile.Length;
		break;
	case IRP_MJ_QUERY_SECURITY:
		outputLength = pStack->Parameters.QuerySecurity.Length;
		break;
	default:
		/* Unhandled major function: report no output space. */
		break;
	}

	return outputLength;
}

/*
 * GetInputBuffer
 *
 * Locates the input buffer of pIrp.
 *
 * The default is pIrp->AssociatedIrp.SystemBuffer. For
 * IRP_MJ_DEVICE_CONTROL and IRP_MJ_FILE_SYSTEM_CONTROL requests whose
 * control code carries METHOD_NEITHER in its transfer-type bits, the
 * Type3InputBuffer of the stack location is used instead.
 *
 * When the requestor is not KernelMode and the chosen buffer is non-NULL
 * and is not the system buffer, the buffer is probed for read access over
 * GetProvidedInputSize(pIrp) bytes.
 *
 * pIrp - the request being inspected.
 *
 * Returns the buffer pointer, or NULL when none is present or when the
 * probe raised an exception accepted by PTFSExceptionFilter.
 */
PVOID
GetInputBuffer(
	IN PIRP pIrp
)
{
	const ULONG transferTypeBits = METHOD_BUFFERED | METHOD_IN_DIRECT | METHOD_NEITHER;
	PIO_STACK_LOCATION pStack = IoGetCurrentIrpStackLocation(pIrp);
	PCHAR pInput = pIrp->AssociatedIrp.SystemBuffer;

	switch (pStack->MajorFunction)
	{
	case IRP_MJ_DEVICE_CONTROL:
		if ((pStack->Parameters.DeviceIoControl.IoControlCode & transferTypeBits) == METHOD_NEITHER)
			pInput = pStack->Parameters.DeviceIoControl.Type3InputBuffer;
		break;
	case IRP_MJ_FILE_SYSTEM_CONTROL:
		if ((pStack->Parameters.FileSystemControl.FsControlCode & transferTypeBits) == METHOD_NEITHER)
			pInput = pStack->Parameters.FileSystemControl.Type3InputBuffer;
		break;
	default:
		break;
	}

	/* Only a non-NULL, non-system buffer of a non-kernel requestor is probed. */
	if (pIrp->RequestorMode == KernelMode ||
		pInput == NULL ||
		pInput == pIrp->AssociatedIrp.SystemBuffer)
	{
		return pInput;
	}

	__try
	{
		ProbeForRead(pInput, GetProvidedInputSize(pIrp), sizeof(char));
	}
	__except (PTFSExceptionFilter(pIrp, GetExceptionInformation()))
	{
		/* The probe failed: do not hand the pointer to the caller. */
		pInput = NULL;
	}

	return pInput;
}

/*
 * GetOutputBuffer
 *
 * Locates the output buffer of pIrp.
 *
 * The default is pIrp->AssociatedIrp.SystemBuffer. pIrp->UserBuffer is
 * used instead for:
 *   - IRP_MJ_DEVICE_CONTROL / IRP_MJ_FILE_SYSTEM_CONTROL requests whose
 *     control code carries METHOD_NEITHER in its transfer-type bits;
 *   - IRP_MJ_QUERY_SECURITY;
 *   - IRP_MJ_DIRECTORY_CONTROL when pIrp->MdlAddress is not set.
 * For IRP_MJ_DIRECTORY_CONTROL with an MDL, the address comes from
 * MmGetSystemAddressForMdlNormalSafe (defined elsewhere).
 *
 * When the requestor is not KernelMode and the chosen pointer equals
 * pIrp->UserBuffer, it is probed for write access over
 * GetProvidedOutputSize(pIrp) bytes.
 *
 * pIrp - the request being inspected.
 *
 * Returns the buffer pointer, or NULL when none is present or when the
 * probe raised an exception accepted by PTFSExceptionFilter.
 */
PVOID
GetOutputBuffer(
	IN PIRP pIrp
)
{
	const ULONG transferTypeBits = METHOD_BUFFERED | METHOD_OUT_DIRECT | METHOD_NEITHER;
	PIO_STACK_LOCATION pStack = IoGetCurrentIrpStackLocation(pIrp);
	PCHAR pOutput = pIrp->AssociatedIrp.SystemBuffer;

	switch (pStack->MajorFunction)
	{
	case IRP_MJ_DEVICE_CONTROL:
		if ((pStack->Parameters.DeviceIoControl.IoControlCode & transferTypeBits) == METHOD_NEITHER)
			pOutput = pIrp->UserBuffer;
		break;
	case IRP_MJ_FILE_SYSTEM_CONTROL:
		if ((pStack->Parameters.FileSystemControl.FsControlCode & transferTypeBits) == METHOD_NEITHER)
			pOutput = pIrp->UserBuffer;
		break;
	case IRP_MJ_QUERY_SECURITY:
		pOutput = pIrp->UserBuffer;
		break;
	case IRP_MJ_DIRECTORY_CONTROL:
		if (pIrp->MdlAddress)
			pOutput = MmGetSystemAddressForMdlNormalSafe(pIrp->MdlAddress);
		else
			pOutput = pIrp->UserBuffer;
		break;
	default:
		break;
	}

	/* Kernel requestors and buffers other than UserBuffer are returned unprobed. */
	if (pIrp->RequestorMode == KernelMode || pOutput != pIrp->UserBuffer)
		return pOutput;

	__try
	{
		ProbeForWrite(pOutput, GetProvidedOutputSize(pIrp), sizeof(char));
	}
	__except (PTFSExceptionFilter(pIrp, GetExceptionInformation()))
	{
		/* The probe failed: do not hand the pointer to the caller. */
		pOutput = NULL;
	}

	return pOutput;
}

/*
 * PrepareOutputWithSize
 *
 * Reserves the first Size bytes of the output buffer of pIrp: zeroes them
 * and sets pIrp->IoStatus.Information to Size.
 *
 * pIrp                    - the request whose output buffer is prepared.
 * Size                    - number of bytes required at the buffer start.
 * SetInformationOnFailure - when the buffer is too small, store Size in
 *                           IoStatus.Information if TRUE, otherwise 0.
 *
 * Returns the output buffer on success. Returns NULL when no output
 * buffer is available (IoStatus.Information is left untouched) or when
 * the provided size is smaller than Size.
 */
PVOID
PrepareOutputWithSize(
	IN OUT PIRP pIrp,
	IN ULONG Size,
	IN BOOLEAN SetInformationOnFailure)
{
	PCHAR pOutput = GetOutputBuffer(pIrp);
	ULONG capacity = GetProvidedOutputSize(pIrp);

	if (pOutput != NULL && capacity >= Size)
	{
		RtlZeroMemory(pOutput, Size);
		pIrp->IoStatus.Information = Size;
		return pOutput;
	}

	if (pOutput == NULL)
	{
		KdPrint(("[PTFS]::PrepareOutputWithSize Null GetOutputBuffer Size[%d]\n", capacity));
	}
	else
	{
		/* Buffer exists but is too small; optionally report the needed size. */
		if (SetInformationOnFailure)
			pIrp->IoStatus.Information = Size;
		else
			pIrp->IoStatus.Information = 0;
		KdPrint(("[PTFS]::PrepareOutputWithSize Invalid Size[%d(need), %d]\n", Size, capacity));
	}

	return NULL;
}

/*
 * PrepareOutputHelper
 *
 * Calls PrepareOutputWithSize and stores its result through ppBuffer.
 *
 * pIrp                    - the request whose output buffer is prepared.
 * ppBuffer                - receives the prepared buffer, or NULL on failure.
 * Size                    - number of bytes required.
 * SetInformationOnFailure - forwarded to PrepareOutputWithSize.
 *
 * Returns TRUE when a non-NULL buffer was stored, FALSE otherwise.
 */
BOOLEAN 
PrepareOutputHelper(
	IN OUT PIRP pIrp,
	OUT VOID** ppBuffer,
	IN ULONG Size,
	IN BOOLEAN SetInformationOnFailure
) 
{
	PVOID pPrepared = PrepareOutputWithSize(pIrp, Size, SetInformationOnFailure);

	*ppBuffer = pPrepared;
	if (pPrepared == NULL)
		return FALSE;

	return TRUE;
}

/*
 * ExtendOutputBufferBySize
 *
 * Grows the used portion of the output buffer of pIrp by AdditionalSize
 * bytes. The used portion is tracked in pIrp->IoStatus.Information; on
 * success the newly claimed bytes are zeroed and Information is advanced.
 *
 * pIrp                       - the request whose output buffer is extended.
 * AdditionalSize             - number of extra bytes to claim.
 * UpdateInformationOnFailure - when the buffer is too small, still add
 *                              AdditionalSize to IoStatus.Information.
 *
 * Returns TRUE when the extension fits in the provided output size,
 * FALSE when there is no output buffer or not enough room.
 */
BOOLEAN
ExtendOutputBufferBySize(
	IN OUT  PIRP pIrp,
	IN ULONG AdditionalSize,
	IN BOOLEAN UpdateInformationOnFailure)
{
	PCHAR pBuffer = GetOutputBuffer(pIrp);
	ULONG providedSize = GetProvidedOutputSize(pIrp);
	ULONG_PTR usedSize = pIrp->IoStatus.Information;

	if (pBuffer == NULL)
		return FALSE;

	/*
	 * Compare without forming usedSize + AdditionalSize: that sum can wrap
	 * (ULONG_PTR is 32 bits wide on x86), which would let an oversized
	 * request pass the check and zero memory past the end of the buffer.
	 * usedSize may already exceed providedSize after an earlier failed
	 * extension that updated Information, so test that first.
	 */
	if (usedSize > providedSize || AdditionalSize > providedSize - usedSize)
	{
		if (UpdateInformationOnFailure)
			pIrp->IoStatus.Information += AdditionalSize;
		return FALSE;
	}

	RtlZeroMemory(pBuffer + usedSize, AdditionalSize);
	pIrp->IoStatus.Information += AdditionalSize;
	return TRUE;
}

/*
 * AppendVarSizeOutputString
 *
 * Copies the bytes of Str to Dest, which must lie inside the currently
 * used part of the output buffer of pIrp (between the buffer start and
 * start + pIrp->IoStatus.Information, inclusive). If the string does not
 * fit in the space already claimed, the used part is grown through
 * ExtendOutputBufferBySize.
 *
 * pIrp                       - the request owning the output buffer.
 * Dest                       - destination inside the output buffer.
 * Str                        - string to copy; Str->Length bytes are used.
 * UpdateInformationOnFailure - forwarded to ExtendOutputBufferBySize.
 * FillSpaceWithPartialString - when the whole string does not fit, copy
 *                              as much as fits (rounded down to an even
 *                              byte count) and set IoStatus.Information
 *                              to the end of the copied data.
 *
 * Returns TRUE only when the entire string was copied (or Str is empty).
 * Returns FALSE when there is no output buffer, Dest is out of range,
 * the string did not fit, or only a partial copy was made.
 */
BOOLEAN
AppendVarSizeOutputString(
	IN OUT PIRP pIrp,
	IN OUT PVOID Dest,
	IN const UNICODE_STRING* Str,
	IN BOOLEAN UpdateInformationOnFailure,
	IN BOOLEAN FillSpaceWithPartialString
)
{
	PCHAR pBuffer = GetOutputBuffer(pIrp);
	ULONG_PTR allocatedSize = pIrp->IoStatus.Information;
	ULONG_PTR destOffset = 0, remainingSize = 0, copySize = 0;

	/* GetOutputBuffer returns NULL when no buffer exists or the probe failed. */
	if (pBuffer == NULL)
		return FALSE;

	if ((PCHAR)Dest < pBuffer || (PCHAR)Dest > pBuffer + allocatedSize)
		return FALSE;

	if (Str->Length == 0)
		return TRUE;

	destOffset = (PCHAR)Dest - pBuffer;
	remainingSize = allocatedSize - destOffset;
	copySize = Str->Length;

	if (remainingSize < copySize) 
	{
		ULONG additionalSize = (ULONG)(copySize - remainingSize);
		if (!ExtendOutputBufferBySize(pIrp, additionalSize, UpdateInformationOnFailure)) 
		{
			ULONG providedSize = 0;

			if (!FillSpaceWithPartialString)
				return FALSE;

			providedSize = GetProvidedOutputSize(pIrp);

			/*
			 * Dest can sit beyond the provided buffer when an earlier
			 * failure inflated IoStatus.Information. The subtraction below
			 * would then wrap to a huge copy length, so refuse to copy.
			 */
			if (providedSize < destOffset)
				return FALSE;

			/* Keep an even byte count: UNICODE_STRING characters are two bytes wide. */
			copySize = (providedSize - destOffset) & ~(ULONG_PTR)1;
			pIrp->IoStatus.Information = copySize + destOffset;
		}
	}

	RtlCopyMemory(Dest, Str->Buffer, copySize);
	return copySize == Str->Length;
}
