/* SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only */
/* Copyright (c) 2020-2025 Brett Sheffield <bacs@librecast.net> */

#include "key.h"
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <sodium.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <unistd.h>

/* Decode the public encryption key from a public combo hex key.
 * The encryption key is the leading HEXKEY_PUBLIC_CRYPTBYTES hex characters
 * of phex_combo; at most len bytes are written to pek.
 * Returns the result of sodium_hex2bin() (0 on success, -1 on failure). */
int key_combo_hex2pek(unsigned char *pek, size_t len, char *phex_combo)
{
	const char *hex = phex_combo; /* encryption key comes first */
	int rc = sodium_hex2bin(pek, len, hex, HEXKEY_PUBLIC_CRYPTBYTES,
			NULL, NULL, NULL);

	return rc;
}

/* Decode the public signing key from a public combo hex key.
 * The signing key is the HEXKEY_PUBLIC_SIGNBYTES hex characters that follow
 * the HEXKEY_PUBLIC_CRYPTBYTES-character encryption key in phex_combo;
 * at most len bytes are written to psk.
 * Returns the result of sodium_hex2bin() (0 on success, -1 on failure). */
int key_combo_hex2psk(unsigned char *psk, size_t len, char *phex_combo)
{
	return sodium_hex2bin(psk, len, &phex_combo[HEXKEY_PUBLIC_CRYPTBYTES],
			HEXKEY_PUBLIC_SIGNBYTES, NULL, NULL, NULL);
}

/* Decode the secret encryption key from a secret combo hex key.
 * The encryption key is the leading HEXKEY_SECRET_CRYPTBYTES hex characters
 * of shex_combo; at most len bytes are written to sek.
 * Returns the result of sodium_hex2bin() (0 on success, -1 on failure). */
int key_combo_hex2sek(unsigned char *sek, size_t len, char *shex_combo)
{
	const char *hex = shex_combo; /* encryption key comes first */
	int rc = sodium_hex2bin(sek, len, hex, HEXKEY_SECRET_CRYPTBYTES,
			NULL, NULL, NULL);

	return rc;
}

/* Decode the secret signing key from a secret combo hex key.
 * The signing key is the HEXKEY_SECRET_SIGNBYTES hex characters that follow
 * the HEXKEY_SECRET_CRYPTBYTES-character encryption key in shex_combo;
 * at most len bytes are written to ssk.
 * Returns the result of sodium_hex2bin() (0 on success, -1 on failure). */
int key_combo_hex2ssk(unsigned char *ssk, size_t len, char *shex_combo)
{
	return sodium_hex2bin(ssk, len, &shex_combo[HEXKEY_SECRET_CRYPTBYTES],
			HEXKEY_SECRET_SIGNBYTES, NULL, NULL, NULL);
}

/* generate a combo keypair:
 *   keyring->e  - encryption keypair (LC_KEY_ENC)
 *   keyring->s  - signing keypair    (LC_KEY_SIG)
 * The rest of this file hex-encodes/decodes e with the crypto_box sizes and
 * s with the crypto_sign sizes, so each member must get the matching type.
 * Returns -1 if lc_keypair_new() fails for the encryption pair, otherwise
 * lc_keypair_new()'s result for the signing pair. */
int key_gen_keys(key_combo_t *keyring)
{
	if (lc_keypair_new(&keyring->e, LC_KEY_ENC) == -1) return -1;
	return lc_keypair_new(&keyring->s, LC_KEY_SIG);
}

/* Generate a fresh combo keypair and store its hex encodings in the keyring.
 * keyring->phex receives the public encryption key hex followed by the public
 * signing key hex; keyring->shex receives the secret keys in the same layout.
 * Returns 0 on success, -1 if key generation fails. */
int key_gen_combopair_hex(key_combo_t *keyring)
{
	char *pub = keyring->phex;
	char *sec = keyring->shex;

	if (key_gen_keys(keyring) != 0)
		return -1;

	/* each sodium_bin2hex() call writes a trailing NUL; the encryption half
	 * is written first so the signing half overwrites that NUL */
	sodium_bin2hex(pub, HEXKEY_PUBLIC_CRYPTBYTES + 1,
			keyring->e.pk, crypto_box_PUBLICKEYBYTES);
	sodium_bin2hex(pub + HEXKEY_PUBLIC_CRYPTBYTES, HEXKEY_PUBLIC_SIGNBYTES + 1,
			keyring->s.pk, crypto_sign_PUBLICKEYBYTES);

	sodium_bin2hex(sec, HEXKEY_SECRET_CRYPTBYTES + 1,
			keyring->e.sk, crypto_box_SECRETKEYBYTES);
	sodium_bin2hex(sec + HEXKEY_SECRET_CRYPTBYTES, HEXKEY_SECRET_SIGNBYTES + 1,
			keyring->s.sk, crypto_sign_SECRETKEYBYTES);

	return 0;
}

/* Build the key configuration directory path "<home>/KEY_PATH".
 * On success returns a malloc'd string the caller must free(), and *len is
 * set to its buffer size (string length + 1). Returns NULL if formatting or
 * allocation fails. */
char *key_config_dir(char *home, size_t *len)
{
	char *path;
	int need = snprintf(NULL, 0, "%s/%s", home, KEY_PATH);

	if (need < 0)
		return NULL;
	*len = (size_t)need + 1;
	path = malloc(*len);
	if (path == NULL)
		return NULL;
	/* second pass must produce exactly the measured length */
	if ((size_t)snprintf(path, *len, "%s/%s", home, KEY_PATH) == *len - 1)
		return path;
	free(path);
	return NULL;
}

int key_gen_combopair_write(char *home, key_combo_t *keyring)
{
	char *str;
	size_t len;
	int fd, err, rc;

	/* create config directory */
	rc = snprintf(NULL, 0, "mkdir -p %s/%s", home, KEY_PATH);
	if (rc < 0) return -1;
	len = (size_t)rc + 1;
	str = malloc(len);
	if (str == NULL) return -1;
	rc = snprintf(str, len, "mkdir -p %s/%s", home, KEY_PATH);
	rc = system(str);
	if (rc == -1) goto err_free_str;

	/* generate keypair */
	rc = key_gen_combopair_hex(keyring);
	if (rc == -1) goto err_free_str;
	free(str); str = NULL;

	/* create file for writing, named for public key */
	rc = snprintf(NULL, 0, "%s/%s/%s", home, KEY_PATH, keyring->phex);
	len = rc + 1;
	str = malloc(len);
	if (str == NULL) return -1;
	rc = snprintf(str, len, "%s/%s/%s", home, KEY_PATH, keyring->phex);
	fd = open(str, O_CREAT|O_EXCL|O_WRONLY|O_SYNC, S_IRUSR|S_IWUSR);
	if (fd == -1) goto err_free_str;

	/* write secret key to file */
	ssize_t byt = write(fd, keyring->shex, KEY_SECRET_HEXLEN);
	close(fd);
	if ((size_t)byt != KEY_SECRET_HEXLEN) goto err_free_str;

	/* create default symlink if not exist */
	char *linkpath;
	rc = snprintf(NULL, 0, "%s/%s/default", home, KEY_PATH);
	if (rc == -1) goto err_free_str;
	len = (size_t)rc + 1;
	linkpath = malloc(len);
	if (linkpath == NULL) {
		rc = -1;
		goto err_free_str;
	}
	rc = snprintf(linkpath, len, "%s/%s/default", home, KEY_PATH);
	if (rc == -1) goto err_free_linkpath;
	rc = symlink(keyring->phex, linkpath);
	if (rc == -1 && errno == EEXIST) rc = 0;
err_free_linkpath:
	free(linkpath);
err_free_str:
	err = errno;
	free(str);
	errno = err;
	return rc;
}

/* Decode the hex combo keys in keyring->phex and keyring->shex into the
 * binary keypairs: keyring->e (encryption) and keyring->s (signing).
 * Returns 0 on success, -1 as soon as any key fails to decode. */
int key_combo_hex_decode(key_combo_t *keyring)
{
	if (key_combo_hex2pek(keyring->e.pk, sizeof keyring->e.pk, keyring->phex))
		return -1;
	if (key_combo_hex2psk(keyring->s.pk, sizeof keyring->s.pk, keyring->phex))
		return -1;
	if (key_combo_hex2sek(keyring->e.sk, sizeof keyring->e.sk, keyring->shex))
		return -1;
	/* s.sk comes only from the secret combo key; nothing from phex goes there */
	return key_combo_hex2ssk(keyring->s.sk, sizeof keyring->s.sk, keyring->shex);
}

/* Load the default combo keypair into keyring, generating and saving a new
 * one if the default key file does not exist.
 * The default key path is "<state->dir_home>/KEY_PATH/KEY_DEFAULT". The
 * public combo key hex is read as that link's target name (readlink), and the
 * secret combo key hex is read from the file it resolves to. This presumably
 * matches the "default" symlink created by key_gen_combopair_write().
 * Returns 0 on success, -1 on error, including a short or missing key. */
int key_load_default(state_t *state, key_combo_t *keyring)
{
	char *configdir;
	char *keyfile;
	ssize_t byt;
	size_t len;
	int fd;
	int rc = -1;

	configdir = key_config_dir(state->dir_home, &len);
	if (configdir == NULL) return -1;
	/* len already counts configdir's NUL (used for the '/'); add name + NUL */
	len += strlen(KEY_DEFAULT) + 1;
	keyfile = malloc(len);
	if (keyfile == NULL) goto err_free_configdir;
	snprintf(keyfile, len, "%s/%s", configdir, KEY_DEFAULT);
	fd = open(keyfile, O_RDONLY);
	if (fd == -1) {
		if (errno != ENOENT) goto err_free_keyfile;
		/* no key, create it */
		free(keyfile);
		free(configdir);
		E(STATE_VERBOSE, "generating keys\n");
		return key_gen_combopair_write(state->dir_home, keyring);
	}

	/* load keys */
	E(STATE_VERBOSE, "loading keys from %s\n", keyfile);
	memset(keyring->phex, 0, KEY_PUBLIC_HEXLEN + 1);
	/* public key: the name the default link points to */
	byt = readlink(keyfile, keyring->phex, KEY_PUBLIC_HEXLEN);
	if (byt != (ssize_t)KEY_PUBLIC_HEXLEN) goto err_close_fd;
	/* secret key: contents of the file; a short read is an error, not a count */
	byt = read(fd, keyring->shex, KEY_SECRET_HEXLEN);
	if (byt != (ssize_t)KEY_SECRET_HEXLEN) goto err_close_fd;
	rc = key_combo_hex_decode(keyring);
err_close_fd:
	close(fd);
err_free_keyfile:
	free(keyfile);
err_free_configdir:
	free(configdir);
	return rc;
}

/* Build the path of the saved capability token for a channel:
 * "<state->dir_state>/<hex(chanhash)>.token", where chanhash is len bytes.
 * Returns a malloc'd string the caller must free(), or NULL on error
 * (errno = EINVAL if len is too large to hex-encode).
 * The hex buffer is heap-allocated rather than a VLA so a caller-supplied
 * len cannot overflow the stack. */
char *key_cap_path(state_t *state, unsigned char *chanhash, size_t len)
{
	char *hexchan;
	char *pathname = NULL;
	size_t hexlen;
	int rc;

	/* len * 2 + 1 must not wrap */
	if (len > (SIZE_MAX - 1) / 2) return (errno = EINVAL), NULL;
	if (sodium_init() == -1) return NULL;
	hexlen = len * 2 + 1;
	hexchan = malloc(hexlen);
	if (!hexchan) return NULL;
	sodium_bin2hex(hexchan, hexlen, chanhash, len);

	rc = snprintf(NULL, 0, "%s/%s.token", state->dir_state, hexchan);
	if (rc >= 0) {
		pathname = malloc((size_t)rc + 1);
		if (pathname)
			snprintf(pathname, (size_t)rc + 1, "%s/%s.token", state->dir_state, hexchan);
	}
	free(hexchan);
	return pathname;
}

/* Read a previously saved capability token for the channel identified by
 * chanhash (len bytes) from the path given by key_cap_path().
 * Returns 0 if exactly sizeof *token bytes were read into token, else -1. */
int key_cap_load(lc_token_t *token, state_t *state, unsigned char *chanhash, size_t len)
{
	char *path;
	ssize_t got;
	int fd;
	int ret = -1;

	path = key_cap_path(state, chanhash, len);
	if (path == NULL)
		return -1;
	fd = open(path, O_RDONLY);
	if (fd != -1) {
		got = read(fd, token, sizeof *token);
		close(fd);
		if (got == (ssize_t)sizeof *token)
			ret = 0;
	}
	free(path);
	return ret;
}

/* Write token to the state directory as "CHANNELHEX.token", where CHANNELHEX
 * is the hex encoding of token->channel (see key_cap_path()). Any existing
 * file is truncated; a new file is created with mode 0600. If flags contains
 * KEY_CAP_PATH the pathname is printed to stderr, whether or not the write
 * succeeded.
 * Returns 0 if the whole token was written, else -1. */
static int key_cap_save(state_t *state, lc_token_t *token, int flags)
{
	char *path;
	int fd;
	int ret = -1;

	path = key_cap_path(state, token->channel, sizeof token->channel);
	if (path == NULL)
		return -1;
	fd = creat(path, S_IRUSR | S_IWUSR);
	if (fd != -1) {
		if (write(fd, token, sizeof *token) == (ssize_t)sizeof *token)
			ret = 0;
		close(fd);
	}
	if (flags & KEY_CAP_PATH)
		fprintf(stderr, "%s\n", path);
	free(path);

	return ret;
}

/* Create a capability token for chan, signed with signing_key and issued to
 * bearer_key, then act on flags:
 *   KEY_CAP_SAVE - store the token via key_cap_save()
 *   KEY_CAP_SEND - not implemented yet: fails with errno = ENOSYS
 * Returns -1 if lc_token_new() fails, key_cap_save()'s nonzero result if
 * saving fails, otherwise the most recent status (lc_token_new() or
 * key_cap_save()). */
int key_cap_issue(state_t *state, lc_keypair_t *signing_key, uint8_t *bearer_key,
		lc_channel_t *chan, uint8_t capbits, uint64_t valid_sec, int flags)
{
	lc_token_t token = {0};
	int status;

	/* sign token */
	status = lc_token_new(&token, signing_key, bearer_key, chan, capbits, valid_sec);
	if (status == -1)
		return -1;

	if (flags & KEY_CAP_SAVE) {
		status = key_cap_save(state, &token, flags);
		if (status != 0)
			return status;
	}
	if (flags & KEY_CAP_SEND) {
		/* TODO send token */
		errno = ENOSYS;
		return -1;
	}
	return status;
}

/* Build "<state->dir_state>/KEY_AUTH_FILE", the authorized keyring path.
 * Returns a malloc'd string the caller must free(), or NULL if there is no
 * state directory or formatting/allocation fails. */
static char *keyring_path(state_t *state)
{
	char *path;
	int n;

	if (state->dir_state == NULL)
		return NULL;
	n = snprintf(NULL, 0, "%s/%s", state->dir_state, KEY_AUTH_FILE);
	if (n < 0)
		return NULL;
	path = malloc((size_t)n + 1);
	if (path != NULL)
		snprintf(path, (size_t)n + 1, "%s/%s", state->dir_state, KEY_AUTH_FILE);
	return path;
}

/* Open the authorized keyring file in "a+" mode: it is created if missing,
 * readable, and written only at end of file.
 * Returns the stream, or NULL on error. */
static FILE *keyring_open(state_t *state)
{
	FILE *fp;
	char *path = keyring_path(state);

	if (path == NULL)
		return NULL;
	E(STATE_VERBOSE, "opening keyring at '%s'\n", path);
	fp = fopen(path, "a+");
	free(path);
	return fp;
}

/* Free every key buffer held in keyring->key[0 .. nkeys-1].
 * The key array itself and keyring->nkeys are left untouched. */
void keyring_freekeys(lc_keyring_t *keyring)
{
	size_t n;

	for (n = 0; n < keyring->nkeys; ++n) {
		free(keyring->key[n]);
	}
}

/* Append hexkey (a public combo key in hex) to the authorized keyring file.
 * The file is a sequence of fixed-size records, each KEY_PUBLIC_HEXLEN hex
 * characters plus '\n'. keyring_load()/keyring_del() rely on that layout, so
 * a key of the wrong length or with non-hex characters is rejected.
 * Returns 0 on success; -1 with errno = EINVAL for a malformed key,
 * errno = EEXIST if the key is already present, or errno from a failed
 * stdio call (including the flush in fclose()). */
int keyring_add(state_t *state, char *hexkey)
{
	char keybuf[KEY_PUBLIC_HEXLEN + 1];
	FILE *f;
	int rc;

	if (strlen(hexkey) != (size_t)KEY_PUBLIC_HEXLEN
	 || strspn(hexkey, "0123456789abcdefABCDEF") != (size_t)KEY_PUBLIC_HEXLEN)
		return (errno = EINVAL), -1;

	f = keyring_open(state);
	if (!f) {
		perror("keyring_open");
		return -1;
	}
	rewind(f);
	while (fread(keybuf, sizeof keybuf, 1, f) == 1) {
		if (strncmp(keybuf, hexkey, KEY_PUBLIC_HEXLEN) != 0) continue;
		fclose(f);
		return (errno = EEXIST), -1; /* key already in file */
	}
	rc = fseek(f, 0, SEEK_END);
	if (rc == 0) rc = fprintf(f, "%s\n", hexkey);
	/* the record is buffered: a failed flush here means it was not stored */
	if (fclose(f) == EOF) rc = -1;

	return (rc < 0) ? -1 : 0;
}

/* Load every authorized public key from the keyring file into keyring.
 * keyring is zeroed, sized with lc_keyring_init() for the number of complete
 * records, then filled with malloc'd binary public signing keys
 * (crypto_sign_PUBLICKEYBYTES each) decoded by key_combo_hex2psk().
 * Returns 0 on success, -1 if the file can't be opened, a key can't be
 * allocated, decoded or added, or lc_keyring_init()'s nonzero result.
 * On failure, keys added before the error remain in keyring. */
int keyring_load(state_t *state, lc_keyring_t *keyring)
{
	char keybuf[KEY_PUBLIC_HEXLEN + 1];
	uint8_t *psk;
	FILE *f;
	int nkeys = 0;
	int rc = 0;

	f = keyring_open(state);
	if (!f) return -1;
	/* first pass: count complete records */
	rewind(f);
	while (fread(keybuf, sizeof keybuf, 1, f) == 1) nkeys++;
	memset(keyring, 0, sizeof *keyring);
	if (nkeys) {
		rc = lc_keyring_init(keyring, nkeys);
		rewind(f);
		while (!rc && fread(keybuf, sizeof keybuf, 1, f) == 1) {
			/* terminate the record, overwriting its trailing newline */
			keybuf[KEY_PUBLIC_HEXLEN] = 0;
			psk = malloc(crypto_sign_PUBLICKEYBYTES);
			if (!psk) { rc = -1; break; }
			rc = key_combo_hex2psk(psk, crypto_sign_PUBLICKEYBYTES, keybuf);
			if (rc == 0) rc = lc_keyring_add(keyring, psk);
			if (rc == -1) {
				/* not stored in the keyring (assumes lc_keyring_add() does
				 * not keep the pointer on failure): release it */
				free(psk);
				break;
			}
		}
	}
	fclose(f);
	return rc;
}

/* Remove hexkey from the authorized keyring file.
 * The file is a sequence of fixed-size records of sizeof keybuf bytes
 * (KEY_PUBLIC_HEXLEN hex characters + one trailing byte). The first matching
 * record is removed by reading all following records, truncating the file
 * by one record, and writing those records back at the matched position.
 * Returns 0 on success, including when no matching key is present.
 * Returns -1 on error; errno is set by the failing call where one applies. */
int keyring_del(state_t *state, char *hexkey)
{
	char keybuf[KEY_PUBLIC_HEXLEN + 1];
	FILE *f;
	struct stat sb;
	char *pathname;
	char *fbuf = NULL;
	size_t fbufsz;
	size_t nkeys = 0;
	int wrote = 0;
	int rc;

	pathname = keyring_path(state);
	if (!pathname) return -1;
	f = fopen(pathname, "r+");
	free(pathname);
	if (!f) return -1;
	rc = fstat(fileno(f), &sb);
	if (rc == -1) goto err_fclose;
	while (fread(keybuf, sizeof keybuf, 1, f) == 1) {
		if (strncmp(keybuf, hexkey, KEY_PUBLIC_HEXLEN) != 0) {
			nkeys++;
			continue;
		}
		rc = -1;
		fbufsz = (size_t)sb.st_size - (nkeys + 1) * sizeof keybuf;
		if (fbufsz) {
			/* load remaining keys into buffer */
			fbuf = malloc(fbufsz);
			if (!fbuf) break;
			if (fread(fbuf, fbufsz, 1, f) != 1) break;
			/* move back to the start of the matched record */
			if (fseek(f, (long)(nkeys * sizeof keybuf), SEEK_SET) == -1) break;
		}
		/* truncate file by one record */
		if (ftruncate(fileno(f), sb.st_size - (off_t)sizeof keybuf) == -1) break;
		if (fbuf) {
			/* write tail keys to new file position; failure here must
			 * not be reported as success, the tail has been truncated */
			if (fwrite(fbuf, fbufsz, 1, f) != 1) break;
			wrote = 1;
		}
		rc = 0;
		break;
	}
	free(fbuf);
err_fclose:
	/* the tail goes through stdio's buffer, so a write error may only
	 * surface when fclose() flushes it */
	if (fclose(f) == EOF && wrote) rc = -1;
	return rc;
}

/* Print each authorized key record from the keyring file to stdout, then the
 * number of keys found to stderr.
 * The file is opened read-only. A trailing fragment shorter than
 * KEY_PUBLIC_HEXLEN is ignored.
 * Returns 0 on success, -1 if the keyring can't be opened or a read error
 * occurs. */
int keyring_show(state_t *state)
{
	/* one record (KEY_PUBLIC_HEXLEN hex chars + '\n') plus a terminating NUL;
	 * the NUL slot is what the previous sizing lacked (off-by-one write) */
	char keybuf[KEY_PUBLIC_HEXLEN + 2];
	FILE *f;
	char *pathname;
	size_t byt;
	int nkeys = 0;
	int rc = 0;

	pathname = keyring_path(state);
	if (!pathname) return -1;
	f = fopen(pathname, "r");
	free(pathname);
	if (!f) return -1;
	/* fread() returns a short count only at EOF or on error, so each call
	 * yields a whole record except possibly the last */
	while ((byt = fread(keybuf, 1, KEY_PUBLIC_HEXLEN + 1, f)) >= (size_t)KEY_PUBLIC_HEXLEN) {
		keybuf[byt] = '\0';
		printf("%s", keybuf);
		nkeys++;
	}
	if (ferror(f)) rc = -1;
	fflush(stdout);
	fprintf(stderr, "%i authorized key(s) found\n", nkeys);
	fclose(f);
	return rc;
}
