1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
| #include <windows.h>
#include <stdio.h>
#include <string.h>
#include <assert.h>
DWORD GetPageSize(void)
{
SYSTEM_INFO si = {0};
GetSystemInfo(&si);
return si.dwPageSize;
}
size_t GetSizeInPagesEx(size_t sizeInBytes, size_t cbPageSize)
{
return ((sizeInBytes + cbPageSize-1) / cbPageSize);
}
#if 0
size_t GetSizeInPages(size_t sizeInBytes)
{
return GetSizeInPagesEx(sizeInBytes, GetPageSize());
}
#endif
/*
Fonction qui utilise les allocations bas niveau de Windows
pour appeler gets() en toute sécurité.
*/
int super_gets(char *szDestBuffer, size_t cchDestSize)
{
int ret = -1;
/* Calcul des tailles: On réserve une page de plus que nécessaire,
pour avoir une page de protection. */
DWORD const cbPageSize = GetPageSize();
size_t const cbDestSize = cchDestSize * sizeof *szDestBuffer;
size_t const sizeInPages = GetSizeInPagesEx(cbDestSize, cbPageSize);
size_t const cbSizeWithoutGuardPage = sizeInPages * cbPageSize;
size_t const cbFullSize = (sizeInPages+1) * cbPageSize;
/* Note: Ici, ça marche aussi avec PAGE_READWRITE,
donc toute tentative d'accès à une mémoire non-committée provoque une exception.
Mais je trouve que PAGE_NOACCESS est plus explicite. */
void* pFullBuffer = VirtualAlloc(NULL, cbFullSize, MEM_RESERVE, PAGE_NOACCESS);
if(pFullBuffer != NULL)
{
void *pCommitted = VirtualAlloc(pFullBuffer, cbSizeWithoutGuardPage, MEM_COMMIT, PAGE_READWRITE);
if(pCommitted != NULL)
{
/* On écrit juste avant la page de protection,
comme ça si on dépasse, une exception est lancée. */
unsigned char * pbyBuf = pCommitted;
assert(pCommitted == pFullBuffer);
pbyBuf += cbSizeWithoutGuardPage;
pbyBuf -= cbDestSize;
{
void * const pBuf = pbyBuf;
char * szBuf = pBuf;
int bOverflow = 0;
size_t cchTypedLength;
/* On appelle gets() et on attrape l'exception en cas de dépassement */
__try
{
#pragma warning(push)
#pragma warning(disable:4996) /*désactive le warning sur gets()*/
gets(szBuf);
#pragma warning(pop)
}
__except((GetExceptionCode()==EXCEPTION_ACCESS_VIOLATION ? EXCEPTION_EXECUTE_HANDLER : EXCEPTION_CONTINUE_SEARCH))
{
bOverflow = 1;
}
if(bOverflow)
{
/* Il y a eu dépassement: On écrit toute la longueur du buffer. */
cchTypedLength = cchDestSize-1;
szBuf[cchTypedLength] = '\0';
}
else
{
/* Il n'y a pas eu dépassement: On écrit juste la taille demandée. */
cchTypedLength = strlen(szBuf);
}
assert(cchTypedLength < cchDestSize);
memcpy(szDestBuffer, szBuf, (cchTypedLength+1) * sizeof *szBuf);
ret = 0;
}
}
VirtualFree(pFullBuffer, 0, MEM_RELEASE), pFullBuffer=NULL;
}
return ret;
}
void TestGets(void)
{
int magic1 = 0x12345678;
char buf[8];
int magic2 = 0x5678ABCD;
puts("Entrer une chaine :");
super_gets(buf, ARRAYSIZE(buf));
assert(magic1 == 0x12345678);
assert(magic2 == 0x5678ABCD);
printf("Chaine entree : \"%s\" - Longueur : %lu.\n", buf, (unsigned long)strlen(buf));
} |
Partager