Skip to content

Commit

Permalink
Add logfile and DTLS 1.3 functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
julek-wolfssl committed Oct 7, 2022
1 parent d232e94 commit b7c7d2e
Showing 1 changed file with 110 additions and 39 deletions.
149 changes: 110 additions & 39 deletions udp_proxy.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#define SOCKET_T int
#define SOCKLEN_T socklen_t
#define MY_EX_USAGE EX_USAGE
#define MY_EX_IOERR EX_IOERR
#define StartUDP()
#define INVALID_SOCKET (-1)
#else
Expand All @@ -56,14 +57,15 @@
#define SOCKET_T SOCKET
#define SOCKLEN_T int
#define MY_EX_USAGE 2
#define MY_EX_IOERR 3
#define StartUDP() { WSADATA wsd; WSAStartup(0x0002, &wsd); }
#endif

#include <event2/event.h>


/* datagram msg size */
#define MSG_SIZE 1500
#define MSG_SIZE 2000

#define SET_YELLOW printf("\033[0;33m")
#define SET_BLUE printf("\033[0;34m")
Expand All @@ -83,9 +85,22 @@ int delayByOne = 0; /* delay packet by 1 */
int dupePackets = 0; /* duplicate all packets */
int retxPacket = 0; /* specific seq to retransmit */
int injectAlert = 0; /* inject an alert at end of epoch 0 */
int isDtls13 = 0;
const char* selectedSide = NULL; /* Forced side to use */
const char* seqOrder = ""; /* how to reorder 0th epoch packets */

#define LOG(...) \
do { \
if (fp != NULL) { \
fprintf(fp, __VA_ARGS__); \
fflush(fp); \
} \
else \
printf(__VA_ARGS__); \
} while(0)
FILE *fp = NULL;
const char* logFile = NULL;

typedef struct proxy_ctx {
SOCKET_T clientFd; /* from client to proxy, downstream */
SOCKET_T serverFd; /* form server to proxy, upstream */
Expand Down Expand Up @@ -180,7 +195,6 @@ static int GetOpt(int argc, char** argv, const char* optstring)
return c;
}


static char* GetRecordType(const char* msg)
{
if (msg[0] == 0x16) {
Expand Down Expand Up @@ -242,9 +256,9 @@ static void IncrementRecordSeq(char* msg)
unsigned long seq = (int)( msg[7] << 24 | msg[8] << 16 |
msg[9] << 8 | msg[10] );

printf(" old seq: %lu\n", seq);
LOG(" old seq: %lu\n", seq);
seq++;
printf(" new seq: %lu\n", seq);
LOG(" new seq: %lu\n", seq);

msg[7] = (char)(seq >> 24);
msg[8] = (char)(seq >> 16);
Expand All @@ -253,27 +267,41 @@ static void IncrementRecordSeq(char* msg)
}
}

static void logMsg(char* side, char* msg, int msgSz)
static void logMsg(char* side, char* msg, int msgSz, int pktIdx)
{
printf("%s: E: %d Seq: %2d handshake: %2d got %s read %d bytes\n", side, GetRecordEpoch(msg), GetRecordSeq(msg), msg[18], GetRecordType(msg), msgSz);
if (!isDtls13)
LOG("%s: E: %d Seq: %2d handshake: %2d got %s read %d bytes\n", side,
GetRecordEpoch(msg), GetRecordSeq(msg), msg[18],
GetRecordType(msg), msgSz);
else
LOG("%d: %s: read %d bytes\n", pktIdx, side, msgSz);
}

typedef struct pkt {
char bin[MSG_SIZE];
int binSz;
int pktIdx;
struct pkt* next;
} pkt;
static pkt* pktStore = NULL;

static void pushPkt(char* msg, int msgSz)
static void pushPkt(char* msg, int msgSz, int peerIdx)
{
if (msg && msgSz > 0) {
pkt* tmp;
pkt* new = (pkt*)malloc(sizeof(pkt));
if (new == NULL)
return;
printf("Storing pkt with seq %d\n", GetRecordSeq(msg));
if (!isDtls13)
LOG("Storing pkt with seq %d\n", GetRecordSeq(msg));
else
LOG("Storing pkt %d\n", peerIdx);
memset(new, 0, sizeof(pkt));
new->pktIdx = peerIdx;
if (msgSz > MSG_SIZE) {
LOG("Truncating saved packet");
msgSz = MSG_SIZE;
}
memcpy(new->bin, msg, msgSz);
new->binSz = msgSz;
if (pktStore == NULL) {
Expand All @@ -293,7 +321,7 @@ static void pktStoreDrain(char* side, SOCKET_T peerFd) {
pkt* prev = NULL;
pktStore = NULL;
while (tmp != NULL) {
logMsg(side, tmp->bin, tmp->binSz);
logMsg(side, tmp->bin, tmp->binSz, tmp->pktIdx);
send(peerFd, tmp->bin, tmp->binSz, 0);
prev = tmp;
tmp = tmp->next;
Expand All @@ -307,8 +335,9 @@ static void pktStoreSend(char* side, SOCKET_T peerFd) {
pkt* prev = NULL;
int seq = *seqOrder - '0';
while (tmp != NULL) {
if (GetRecordSeq(tmp->bin) == seq) {
logMsg(side, tmp->bin, tmp->binSz);
if ((isDtls13 && tmp->pktIdx == seq) ||
(!isDtls13 && GetRecordSeq(tmp->bin) == seq)) {
logMsg(side, tmp->bin, tmp->binSz, tmp->pktIdx);
send(peerFd, tmp->bin, tmp->binSz, 0);
seqOrder++;
if (prev != NULL)
Expand All @@ -332,45 +361,68 @@ static void pktStoreSend(char* side, SOCKET_T peerFd) {
static void Msg(evutil_socket_t fd, short which, void* arg)
{
static int msgCount = 0;
static int peerIdx[2] = {-1, -1}; /* Number of packets seen from peer.
* [0] client [1] server */

char msg[MSG_SIZE];
proxy_ctx* ctx = (proxy_ctx*)arg;
int ret = recv(fd, msg, MSG_SIZE, 0);

if (ret == 0)
printf("read 0\n");
LOG("read 0\n");
else if (ret < 0)
printf("read < 0\n");
LOG("read < 0\n");
else {
SOCKET_T peerFd;
char* side; /* from message side */
int sideIdx;

if (ctx->serverFd == fd) {
peerFd = ctx->clientFd;
side = serverSide;
sideIdx = 1;
}
else {
peerFd = ctx->serverFd;
side = clientSide;
sideIdx = 0;
}

if (side == selectedSide && GetRecordEpoch(msg) == 0
&& *seqOrder != '\0') {
int seq = *seqOrder - '0';
if (GetRecordSeq(msg) != seq) {
pushPkt(msg, ret);
return;
peerIdx[sideIdx]++;

if (!isDtls13) {
if (side == selectedSide && GetRecordEpoch(msg) == 0
&& *seqOrder != '\0') {
int seq = *seqOrder - '0';
if (GetRecordSeq(msg) != seq) {
pushPkt(msg, ret, -1);
return;
}
else {
seqOrder++;
}
}
else {
seqOrder++;
}
else {
/* No way of knowing what the sequence number is so just blindly
* re-order the encrypted packets */
if (side == selectedSide && *seqOrder != '\0') {
int seq = *seqOrder - '0';
if (peerIdx[sideIdx] != seq) {
pushPkt(msg, ret, peerIdx[sideIdx]);
return;
}
else {
seqOrder++;
}
}
}

if (side == serverSide)
SET_BLUE;
else
SET_YELLOW;
logMsg(side, msg, ret);
logMsg(side, msg, ret, peerIdx[sideIdx]);
RESET_COLOR;

msgCount++;
Expand All @@ -380,11 +432,11 @@ static void Msg(evutil_socket_t fd, short which, void* arg)
GetRecordSeq(msg) == delayByOne &&
side == selectedSide) {

printf("*** delaying server packet %d\n", delayByOne);
LOG("*** delaying server packet %d\n", delayByOne);
if (currDelay == NULL)
currDelay = &tmpDelay;
else {
printf("*** oops, still have a packet in delay\n");
LOG("*** oops, still have a packet in delay\n");
assert(0);
}
memcpy(currDelay->msg, msg, ret);
Expand All @@ -397,7 +449,7 @@ static void Msg(evutil_socket_t fd, short which, void* arg)

/* is it now time to send along delayed packet */
if (delayPacket && currDelay && currDelay->sendCount == msgCount) {
printf("*** sending on delayed packet\n");
LOG("*** sending on delayed packet\n");
send(currDelay->peerFd, currDelay->msg, currDelay->msgLen, 0);
currDelay = NULL;
}
Expand All @@ -406,22 +458,22 @@ static void Msg(evutil_socket_t fd, short which, void* arg)
if (dropSpecific && side == selectedSide &&
GetRecordEpoch(msg) == dropSpecificEpoch &&
GetRecordSeq(msg) == dropSpecificSeq) {
printf("*** but dropping this packet specifically\n");
LOG("*** but dropping this packet specifically\n");
return;
}

if (dropNth && dropPacketNo == msgCount) {
printf("*** but dropping the %d packet\n", msgCount);
LOG("*** but dropping the %d packet\n", msgCount);
return;
}

/* should we delay the current packet */
if (delayPacket && (msgCount % delayPacket) == 0) {
printf("*** but delaying this packet\n");
LOG("*** but delaying this packet\n");
if (currDelay == NULL)
currDelay = &tmpDelay;
else {
printf("*** oops, still have a packet in delay\n");
LOG("*** oops, still have a packet in delay\n");
assert(0);
}
memcpy(currDelay->msg, msg, ret);
Expand All @@ -435,7 +487,7 @@ static void Msg(evutil_socket_t fd, short which, void* arg)
/* should we drop current packet altogether */
if (dropPacket && (msgCount % dropPacket) == 0
&& msg[0] != 0x17 /* But don't drop application data */) {
printf("*** but dropping this packet\n");
LOG("*** but dropping this packet\n");
return;
}

Expand All @@ -447,7 +499,7 @@ static void Msg(evutil_socket_t fd, short which, void* arg)
SET_BLUE;
else
SET_YELLOW;
if (GetRecordEpoch(msg) == 0 && *seqOrder != '\0')
if ((isDtls13 || GetRecordEpoch(msg) == 0) && *seqOrder != '\0')
pktStoreSend(side, peerFd);
else
pktStoreDrain(side, peerFd);
Expand All @@ -460,7 +512,7 @@ static void Msg(evutil_socket_t fd, short which, void* arg)
injectAlert = 2;
}
if (injectAlert == 2 && side == serverSide && msg[0] == 0x14) {
printf("*** injecting a bogus alert from client after "
LOG("*** injecting a bogus alert from client after "
"change cipher spec\n");
ret = send(ctx->serverFd, bogusAlert, sizeof(bogusAlert), 0);
if (ret < 0) {
Expand Down Expand Up @@ -489,7 +541,7 @@ static void Msg(evutil_socket_t fd, short which, void* arg)
side == selectedSide &&
currDelay) {

printf("*** sending on delayed packet\n");
LOG("*** sending on delayed packet\n");
send(currDelay->peerFd, currDelay->msg, currDelay->msgLen, 0);
currDelay = NULL;
}
Expand Down Expand Up @@ -518,7 +570,7 @@ static void newClient(evutil_socket_t fd, short which, void* arg)
'connection' again, also allows pairing with upStream 'connect' */
msgLen = recvfrom(fd, msg, MSG_SIZE, 0, (struct sockaddr*)&client, &len);
SET_YELLOW;
printf("%s: got %s, first msg\n", clientSide, GetRecordType(msg));
LOG("%s: got %s, first msg\n", clientSide, GetRecordType(msg));
RESET_COLOR;
ctx->clientFd = socket(AF_INET, SOCK_DGRAM, 0);
if (ctx->clientFd == INVALID_SOCKET) {
Expand Down Expand Up @@ -574,7 +626,7 @@ static void newClient(evutil_socket_t fd, short which, void* arg)
event_add(srvEvent, NULL);

if (dropNth && dropPacketNo == 0) {
printf("*** but dropping this packet\n");
LOG("*** but dropping this packet\n");
return;
}

Expand Down Expand Up @@ -607,6 +659,8 @@ static void Usage(void)
printf("-r <pkt seq> Re-order packets from zeroth epoch in this order\n"
" ex: 146523\n");
printf("-S <client|server> Force side (default: server)\n");
printf("-u Interpret traffic as DTLS 1.3\n");
printf("-l <log file> Use the provided argument as the log file\n");
}


Expand All @@ -618,7 +672,7 @@ int main(int argc, char** argv)
short port = -1;
char* serverString = NULL;

while ( (ch = GetOpt(argc, argv, "?Dap:s:d:y:x:b:R:S:r:f:")) != -1) {
while ( (ch = GetOpt(argc, argv, "?Dap:s:d:y:x:b:R:S:r:f:ul:")) != -1) {
switch (ch) {
case '?' :
Usage();
Expand Down Expand Up @@ -690,21 +744,38 @@ int main(int argc, char** argv)
dropNth = 1;
dropPacketNo = atoi(myoptarg);
break;

case 'u':
isDtls13 = 1;
break;

case 'l':
logFile = myoptarg;
break;

default:
Usage();
exit(MY_EX_USAGE);
break;
}
}

if (logFile != NULL) {
fp = fopen(logFile, "w");
if (fp == NULL) {
LOG("Can't open log file\n");
exit(MY_EX_IOERR);
}
}

if (port == -1) {
printf("need to set 'listen port'\n");
LOG("need to set 'listen port'\n");
Usage();
exit(MY_EX_USAGE);
}

if (serverString == NULL) {
printf("need to set server address string\n");
LOG("need to set server address string\n");
Usage();
exit(MY_EX_USAGE);
}
Expand Down Expand Up @@ -760,7 +831,7 @@ int main(int argc, char** argv)

event_base_dispatch(base);

printf("done with dispatching\n");
LOG("done with dispatching\n");

return 0;
}

0 comments on commit b7c7d2e

Please sign in to comment.