kpmpgsmkii/shared/packet.c
2021-12-06 21:31:09 -06:00

384 lines
12 KiB
C

/*
* Kangaroo Punch MultiPlayer Game Server Mark II
* Copyright (C) 2020-2021 Scott Duensing
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
#include "packet.h"
#include "primes.h"
// Transmit callback used for every outgoing packet. NULL until packetSenderRegister() is called.
static packetSender _packetSender = NULL;
// Forward declarations for the file-local helpers defined below.
static uint8_t packetCRC(char *data, uint16_t length, uint8_t startAt);
static uint16_t packetDHCompute(uint16_t a, uint16_t m, uint16_t n);
// Fold a run of bytes into a one byte XOR checksum.
//
// data    - bytes to add to the checksum (not dereferenced when length is 0).
// length  - number of bytes to read from data.
// startAt - running checksum to continue from; pass 0 to start a new one.
//
// Returns the updated checksum, so several buffers can be chained by feeding
// each result back in as startAt.
static uint8_t packetCRC(char *data, uint16_t length, uint8_t startAt) {
	uint8_t   checksum  = startAt;
	char     *cursor    = data;
	uint16_t  remaining = length;

	while (remaining-- > 0) {
		checksum ^= *cursor++;
	}

	return checksum;
}
// Queue raw received bytes and extract the next complete packet from the queue.
//
// threadData - per-connection state: input queue, decode buffer, send history, DH values.
// data       - receives packetType, channel, length and a malloc()ed copy of the payload.
// input      - newly received raw bytes, or NULL to continue processing queued data.
// length     - number of bytes in input (0 when input is NULL).
//
// Returns 1 when a packet is ready in data, 0 when the queue ran dry first.
// NAK, DH_REQUEST and DH_RESPONSE packets are handled here and never returned.
// On a return of 1, data->data is NULL (no payload) or a heap buffer that the
// caller owns; packetDecodeDataDestroy() releases it.
//
// Framing, as produced by packetEncode(): a literal PACKET_FRAME byte is sent
// doubled, and PACKET_FRAME followed by any other byte ends the packet.
uint8_t packetDecode(PacketThreadDataT *threadData, PacketDecodeDataT *data, char *input, uint16_t length) {
	uint8_t            sequence = 0;
	int32_t            x        = 0;
	char               c        = 0;
	PacketEncodeDataT  encoded  = { 0 };

	// Is there input data to add to the queue?
	if (input != NULL && length > 0) {
		for (x=0; x<length; x++) {
			threadData->decodeQueue[threadData->decodeQueueHead++] = input[x];
			if (threadData->decodeQueueHead >= PACKET_INPUT_QUEUE_SIZE) {
				threadData->decodeQueueHead = 0;
			}
		}
	}

	while (1) {

		// Do we have data to process?
		if (threadData->decodeQueueHead == threadData->decodeQueueTail) return 0;

		// Get next byte.
		c = threadData->decodeQueue[threadData->decodeQueueTail++];
		if (threadData->decodeQueueTail >= PACKET_INPUT_QUEUE_SIZE) {
			threadData->decodeQueueTail = 0;
		}

		// New packet?
		if (threadData->newPacket) {
			threadData->newPacket = 0;
			threadData->inEscape  = 0;
			data->length          = 0;
		}

		// Are we escaped?
		if (threadData->inEscape) {
			threadData->inEscape = 0;

			// Is this the end of the packet?
			if (c != PACKET_FRAME) {
				threadData->newPacket = 1;

				// A packet needs at least the 3 byte header and the 1 byte CRC.
				// Shorter frames are dropped; the indexing below would otherwise
				// run off the front of the decode buffer.
				if (data->length < 4) continue;

				// Check CRC.
				if ((uint8_t)threadData->decodeBuffer[data->length - 1] != packetCRC(threadData->decodeBuffer, data->length - 1, 0)) continue;

				// Get sequence value.
				sequence = ((uint8_t)threadData->decodeBuffer[0]) & 0x1f;

				// Is this a NAK?
				if ((((uint8_t)threadData->decodeBuffer[0]) & 0xc0) >> 6 == PACKET_CONTROL_NAK) {
					// Rewind until we find the packet we need in history.
					x = threadData->historyPosition - 1;
					if (x < 0) x = PACKET_SEQUENCE_MAX - 1;
					while (threadData->history[x].sequence != sequence && x != threadData->historyPosition) {
						x--;
						if (x < 0) x = PACKET_SEQUENCE_MAX - 1;
					}
					// Did we find it?
					if (x == threadData->historyPosition) {
						// No!  BAD!
						logWrite("Unable to locate missing packet in history!\n\r");
					} else {
						// Yes.  Replay missing packets.
						while (x != threadData->historyPosition) {
							logWrite("Resending %d!\n\r", threadData->history[x].sequence);
							_packetSender(threadData->history[x].data, threadData->history[x].length, threadData->senderData);
							x++;
							if (x >= PACKET_SEQUENCE_MAX) x = 0;
						}
					}
					continue;
				}

				// Is this the sequence number we're expecting?
				// NOTE(review): sequence is masked to 5 bits above; this assumes
				// lastRemoteSequence wraps at the same width - confirm in packet.h.
				if (sequence == threadData->lastRemoteSequence) {
					// Yes!
					threadData->lastRemoteSequence++;
				} else {
					// No!  NAK it!
					logWrite("Packet out of sequence!  Got %d wanted %d!\n\r", sequence, threadData->lastRemoteSequence);
					encoded.control    = PACKET_CONTROL_NAK;               // Negative acknowledge.
					encoded.packetType = PACKET_TYPE_NONE;                 // Not destined for the app.
					encoded.channel    = 0;                                // Channel doesn't matter for NAK.
					encoded.encrypt    = 0;                                // Encryption doesn't matter for NAK.
					encoded.sequence   = threadData->lastRemoteSequence;   // The last good packet we saw.
					if (packetEncode(threadData, &encoded, NULL, 0)) {
						packetSend(threadData, &encoded);
					}
					continue;
				}

				// Fill decoded data fields.
				data->packetType = threadData->decodeBuffer[1];
				data->channel    = threadData->decodeBuffer[2];

				// ***TODO*** Blowfish Decryption.
				if (threadData->decodeBuffer[0] & 32) {
				}

				// Copy packet data to new buffer, if any.
				if (data->length > 4) {
					data->data = (char *)malloc(data->length - 4); // 4 for 3 byte header and 1 byte CRC.
					if (!data->data) {
						logWrite("Unable to allocate packet payload!\n\r");
						continue;
					}
					memcpy(data->data, &threadData->decodeBuffer[3], data->length - 4); // Skip header and CRC.
				} else {
					// No payload.
					data->data = NULL;
				}

				// Fix length to remove header and checksum.
				data->length -= 4;

				// Is this a DH_REQUEST?
				if (data->packetType == PACKET_TYPE_DH_REQUEST) {
					// The payload comes from the peer; make sure all three key
					// arrays are really there before copying them out.
					if (data->data == NULL || data->length < PACKET_ENCRYPT_KEY_SIZE * 3 * sizeof(uint16_t)) {
						logWrite("DH_REQUEST payload too short!\n\r");
						if (data->data != NULL) {
							DEL(data->data);
						}
						continue;
					}
					memcpy(threadData->dhModulus,     data->data,                               PACKET_ENCRYPT_KEY_SIZE * sizeof(uint16_t));
					memcpy(threadData->dhBase,        &data->data[PACKET_ENCRYPT_KEY_SIZE * 2], PACKET_ENCRYPT_KEY_SIZE * sizeof(uint16_t));
					memcpy(threadData->dhTheirPublic, &data->data[PACKET_ENCRYPT_KEY_SIZE * 4], PACKET_ENCRYPT_KEY_SIZE * sizeof(uint16_t));
					DEL(data->data);
					// A zero modulus would divide by zero inside packetDHCompute().
					for (x=0; x<PACKET_ENCRYPT_KEY_SIZE; x++) {
						if (threadData->dhModulus[x] == 0) break;
					}
					if (x < PACKET_ENCRYPT_KEY_SIZE) {
						logWrite("DH_REQUEST contains a zero modulus!\n\r");
						continue;
					}
					logWrite("Server Key: ");
					for (x=0; x<PACKET_ENCRYPT_KEY_SIZE; x++) {
						threadData->dhMySecret[x]  = rand();
						threadData->dhMyPublic[x]  = packetDHCompute(threadData->dhBase[x], threadData->dhMySecret[x], threadData->dhModulus[x]);
						threadData->dhSharedKey[x] = packetDHCompute(threadData->dhTheirPublic[x], threadData->dhMySecret[x], threadData->dhModulus[x]);
						logWrite("%d ", threadData->dhSharedKey[x]);
					}
					logWrite("\n\r");
					encoded.control    = PACKET_CONTROL_DAT;
					encoded.packetType = PACKET_TYPE_DH_RESPONSE;
					encoded.channel    = 0;   // Doesn't matter for DH_RESPONSE.
					encoded.encrypt    = 0;   // Must be 0 for DH_RESPONSE.
					if (packetEncode(threadData, &encoded, (char *)threadData->dhMyPublic, PACKET_ENCRYPT_KEY_SIZE * sizeof(uint16_t))) {
						packetSend(threadData, &encoded);
					}
					continue;
				}

				// Is this a DH_RESPONSE?
				if (data->packetType == PACKET_TYPE_DH_RESPONSE) {
					// Same validation as above: never copy more than the peer sent.
					if (data->data == NULL || data->length < PACKET_ENCRYPT_KEY_SIZE * sizeof(uint16_t)) {
						logWrite("DH_RESPONSE payload too short!\n\r");
						if (data->data != NULL) {
							DEL(data->data);
						}
						continue;
					}
					memcpy(threadData->dhTheirPublic, data->data, PACKET_ENCRYPT_KEY_SIZE * sizeof(uint16_t));
					DEL(data->data);
					logWrite("Client Key: ");
					for (x=0; x<PACKET_ENCRYPT_KEY_SIZE; x++) {
						threadData->dhSharedKey[x] = packetDHCompute(threadData->dhTheirPublic[x], threadData->dhMySecret[x], threadData->dhModulus[x]);
						logWrite("%d ", threadData->dhSharedKey[x]);
					}
					logWrite("\n\r");
					continue;
				}

				// Done!
				break;
			}

			// Escaped PACKET_FRAME: this is a literal data byte.  Fall through
			// so it is stored like any other byte.

		} else {
			// Are we escaping?
			if (c == PACKET_FRAME) {
				// Yes.  Don't add this byte.
				threadData->inEscape = 1;
				continue;
			}
		}

		// Did we overflow?
		if (data->length >= PACKET_MAX) {
			// Yup.  Dump it.
			threadData->newPacket = 1;
			continue;
		}

		// Add byte to packet.
		threadData->decodeBuffer[data->length++] = c;
	}

	return 1;
}
// Release a decoded packet and its payload, then clear the caller's pointer.
//
// packet - address of the pointer to destroy.  A NULL address or a pointer
//          that is already NULL is ignored, so calling this twice is harmless.
void packetDecodeDataDestroy(PacketDecodeDataT **packet) {
	if (packet == NULL || *packet == NULL) {
		return;
	}

	free((*packet)->data);   // free(NULL) is fine when there was no payload.
	free(*packet);
	*packet = NULL;
}
// Modular exponentiation for the Diffie-Hellman exchange: (a ^ m) mod n.
// See: https://www.techiedelight.com/c-program-demonstrate-diffie-hellman-algorithm/
//
// a - base.
// m - exponent.
// n - modulus.
//
// Returns (a ^ m) mod n.  An exponent of 0 yields 1.  A modulus of 0 has no
// defined result and returns 0 rather than dividing by zero.
static uint16_t packetDHCompute(uint16_t a, uint16_t m, uint16_t n) {
	// The products are formed in 32 bits.  uint16_t operands would be promoted
	// to (signed) int, and 65535 * 65535 does not fit in a 32 bit int.
	uint32_t base   = a;
	uint32_t result = 1;

	if (n == 0) {
		return 0;
	}

	// Fast exponentiation: square the base for every bit of the exponent and
	// multiply it into the result for every bit that is set.
	while (m > 0) {
		if ((m & 1) == 1) {
			result = (result * base) % n;
		}
		base = (base * base) % n;
		m >>= 1;
	}

	return (uint16_t)result;
}
// Build a framed, escaped packet in threadData->encodeBuffer.
//
// threadData - per-connection state; supplies the encode buffer and the
//              outgoing sequence counter.
// data       - control, packetType, channel and encrypt are read from here.
//              For PACKET_CONTROL_DAT the sequence is taken from threadData;
//              for other control types the caller's sequence is used.
//              dataPointer and length are filled in with the result.
// input      - payload bytes; may be NULL when length is 0.
// length     - payload size in bytes.
//
// Returns 1 on success, 0 when the payload is larger than PACKET_MAX (in
// which case data is left untouched).
uint8_t packetEncode(PacketThreadDataT *threadData, PacketEncodeDataT *data, char *input, uint16_t length) {
	uint8_t   headerByte = 0;
	uint8_t   checksum   = 0;
	uint16_t  remaining  = 0;
	char     *source     = input;

	// Packet too large?
	if (length > PACKET_MAX) {
		return 0;
	}

	data->dataPointer = NULL;
	data->length      = 0;

	// Only data packets consume a sequence number.
	if (data->control == PACKET_CONTROL_DAT) {
		data->sequence = threadData->sequence++;
	}

	// First header byte: control in the top two bits, then the encrypt flag,
	// then the five bit sequence number.
	headerByte = (((uint8_t)data->control) << 6) + (data->encrypt << 5) + (data->sequence & 0x1f);

	// The checksum covers the three unescaped header bytes and the payload.
	checksum = packetCRC((char *)&headerByte, 1, checksum);
	checksum = packetCRC((char *)&data->packetType, 1, checksum);
	checksum = packetCRC((char *)&data->channel, 1, checksum);
	checksum = packetCRC(input, length, checksum);

	// Header bytes.  Any byte equal to PACKET_FRAME is written twice.
	threadData->encodeBuffer[data->length++] = headerByte;
	if (headerByte == PACKET_FRAME) {
		threadData->encodeBuffer[data->length++] = PACKET_FRAME;
	}
	threadData->encodeBuffer[data->length++] = data->packetType;
	if (data->packetType == PACKET_FRAME) {
		threadData->encodeBuffer[data->length++] = PACKET_FRAME;
	}
	threadData->encodeBuffer[data->length++] = data->channel;
	if (data->channel == PACKET_FRAME) {
		threadData->encodeBuffer[data->length++] = PACKET_FRAME;
	}

	// ***TODO*** Blowfish Encryption.
	if (data->encrypt) {
	}

	// Payload, with the same doubling of frame bytes.
	for (remaining = length; remaining > 0; remaining--) {
		if (*source == PACKET_FRAME) {
			threadData->encodeBuffer[data->length++] = PACKET_FRAME;
		}
		threadData->encodeBuffer[data->length++] = *source++;
	}

	// Checksum byte.
	threadData->encodeBuffer[data->length++] = checksum;
	if (checksum == PACKET_FRAME) {
		threadData->encodeBuffer[data->length++] = PACKET_FRAME;
	}

	// End of packet marker: a frame byte followed by a zero byte.
	threadData->encodeBuffer[data->length++] = PACKET_FRAME;
	threadData->encodeBuffer[data->length++] = 0;

	data->dataPointer = threadData->encodeBuffer;

	return 1;
}
// Start the Diffie-Hellman exchange by sending a DH_REQUEST to the peer.
//
// The payload is three consecutive arrays of PACKET_ENCRYPT_KEY_SIZE 16 bit
// values taken from threadData: modulus, base, and our public values.
//
// threadData - per-connection state holding the DH values to send.
void packetEncryptionSetup(PacketThreadDataT *threadData) {
	PacketEncodeDataT encoded = { 0 };
	uint16_t          dhData[PACKET_ENCRYPT_KEY_SIZE * 3] = { 0 };

	memcpy(&dhData[0],                           threadData->dhModulus,  PACKET_ENCRYPT_KEY_SIZE * sizeof(uint16_t));
	memcpy(&dhData[PACKET_ENCRYPT_KEY_SIZE],     threadData->dhBase,     PACKET_ENCRYPT_KEY_SIZE * sizeof(uint16_t));
	memcpy(&dhData[PACKET_ENCRYPT_KEY_SIZE * 2], threadData->dhMyPublic, PACKET_ENCRYPT_KEY_SIZE * sizeof(uint16_t));

	encoded.control    = PACKET_CONTROL_DAT;
	encoded.packetType = PACKET_TYPE_DH_REQUEST;
	encoded.channel    = 0;   // Doesn't matter for DH_REQUEST.
	encoded.encrypt    = 0;   // Must be 0 for DH_REQUEST.

	// packetEncode() returns 0 when the payload exceeds PACKET_MAX and leaves
	// encoded without a buffer; sending that would hand a NULL packet to the
	// sender and record it in the history.
	if (packetEncode(threadData, &encoded, (char *)dhData, sizeof(dhData))) {
		packetSend(threadData, &encoded);
	} else {
		logWrite("Unable to encode DH_REQUEST!\n\r");
	}
}
// Hand an encoded packet to the registered sender and remember it for replay.
//
// threadData - per-connection state; supplies senderData and the history ring.
// data       - packet previously filled in by packetEncode().  Its control
//              field is set to PACKET_CONTROL_BAD on return, so the caller
//              has to set it again before reusing the structure.
void packetSend(PacketThreadDataT *threadData, PacketEncodeDataT *data) {
	// Valid control type?
	if (data->control == PACKET_CONTROL_BAD || data->control > PACKET_CONTROL_COUNT) {
		logWrite("Invalid PACKET_CONTROL!\n\r");
	} else {
		_packetSender(data->dataPointer, data->length, threadData->senderData);
	}

	// Data packets are kept in the history ring so they can be resent later.
	if (data->control == PACKET_CONTROL_DAT) {
		memcpy(threadData->history[threadData->historyPosition].data, data->dataPointer, data->length);
		threadData->history[threadData->historyPosition].length   = data->length;
		threadData->history[threadData->historyPosition].sequence = data->sequence;
		// Advance to the next slot, wrapping at the end of the ring.
		if (++threadData->historyPosition >= PACKET_SEQUENCE_MAX) {
			threadData->historyPosition = 0;
		}
	}

	// Mark invalid so caller has to change it.
	data->control = PACKET_CONTROL_BAD;
}
// Install the callback used to transmit encoded packets.
//
// sender - function stored in _packetSender; packetSend() and the NAK replay
//          in packetDecode() call it with the packet bytes, their length and
//          the connection's senderData.
void packetSenderRegister(packetSender sender) {
_packetSender = sender;
}
PacketThreadDataT *packetThreadDataCreate(void *senderData) {
PacketThreadDataT *data = NULL;
uint8_t x = 0;
data = (PacketThreadDataT *)malloc(sizeof(PacketThreadDataT));
if (data) {
data->sequence = 0;
data->lastRemoteSequence = 0;
data->historyPosition = 0;
data->decodeQueueHead = 0;
data->decodeQueueTail = 0;
data->newPacket = 1;
data->senderData = senderData;
for (x=0; x<PACKET_ENCRYPT_KEY_SIZE; x++) {
data->dhModulus[x] = PRIMES[rand() % PRIME_COUNT];
data->dhBase[x] = (rand() < (RAND_MAX / 2) ? 2 : 5);
data->dhMySecret[x] = rand();
data->dhMyPublic[x] = packetDHCompute(data->dhBase[x], data->dhMySecret[x], data->dhModulus[x]);
}
}
return data;
}
// Free per-connection packet state and clear the caller's pointer.
//
// data - address of the pointer returned by packetThreadDataCreate().
//        The pointer it holds may be NULL.
void packetThreadDataDestroy(PacketThreadDataT **data) {
	free(*data);
	*data = NULL;
}