diff --git a/auth.c b/auth.c index e7f4c0d..ab5db90 100644 --- a/auth.c +++ b/auth.c @@ -50,18 +50,16 @@ struct lc_auth_ctx * lc_auth_ctx_new(const struct lc_auth_impl *impl) { struct lc_auth_ctx *ctx; - void *arg; ctx = malloc(sizeof(*ctx)); if (ctx == NULL) return NULL; - if (impl->argsz > 0) { - arg = malloc(impl->argsz); - if (arg == NULL) { + if (impl->ctx_new != NULL) { + ctx->arg = impl->ctx_new(NULL); + if (ctx->arg == NULL) { free(ctx); return NULL; } - ctx->arg = arg; } else ctx->arg = NULL; ctx->impl = impl; @@ -72,7 +70,11 @@ lc_auth_ctx_new(const struct lc_auth_impl *impl) void lc_auth_ctx_free(struct lc_auth_ctx *ctx) { - if (ctx != NULL) - free(ctx->arg); - free(ctx); + if (ctx != NULL && ctx->impl != NULL && ctx->impl->ctx_free != NULL) + ctx->impl->ctx_free(ctx); + else { + if (ctx != NULL) + free(ctx->arg); + free(ctx); + } } diff --git a/auth.h b/auth.h index 71e50f4..c5e6a98 100644 --- a/auth.h +++ b/auth.h @@ -19,13 +19,14 @@ struct lc_auth_impl { - int (*init)(void *, const uint8_t *, size_t); - int (*update)(void *, const uint8_t *, size_t); - int (*final)(void *, uint8_t *, size_t *); - int (*auth)(const uint8_t *, size_t, uint8_t *, size_t *, + int (*init)(void *, const uint8_t *, size_t); + int (*update)(void *, const uint8_t *, size_t); + int (*final)(void *, uint8_t *, size_t *); + int (*auth)(const uint8_t *, size_t, uint8_t *, size_t *, const uint8_t *, size_t); - size_t argsz; + void *(*ctx_new)(const void *); + void (*ctx_free)(void *); }; struct lc_auth_ctx { diff --git a/auth_poly1305.c b/auth_poly1305.c index 623022f..9c394fc 100644 --- a/auth_poly1305.c +++ b/auth_poly1305.c @@ -14,6 +14,8 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +#include + #include "lilcrypto.h" #include "auth.h" #include "auth_poly1305.h" @@ -151,6 +153,11 @@ poly1305_auth(const uint8_t *key, size_t keylen, uint8_t *out, size_t *outlen, poly1305_final(&ctx, out, outlen); } +static void * +poly1305_ctx_new(const void *arg) +{ + return malloc(sizeof(struct poly1305_ctx)); +} static struct lc_auth_impl poly1305_impl = { .init = &poly1305_init, @@ -158,7 +165,8 @@ static struct lc_auth_impl poly1305_impl = { .final = &poly1305_final, .auth = &poly1305_auth, - .argsz = sizeof(struct poly1305_ctx), + .ctx_new = &poly1305_ctx_new, + .ctx_free = NULL, }; const struct lc_auth_impl * diff --git a/cipher.c b/cipher.c index 127075b..a2c4a25 100644 --- a/cipher.c +++ b/cipher.c @@ -82,18 +82,16 @@ struct lc_cipher_ctx * lc_cipher_ctx_new(const struct lc_cipher_impl *impl) { struct lc_cipher_ctx *ctx; - void *arg; ctx = malloc(sizeof(*ctx)); if (ctx == NULL) return NULL; - if (impl->argsz > 0) { - arg = malloc(impl->argsz); - if (arg == NULL) { + if (impl->ctx_new != NULL) { + ctx->arg = impl->ctx_new(NULL); + if (ctx->arg == NULL) { free(ctx); return NULL; } - ctx->arg = arg; } else ctx->arg = NULL; ctx->impl = impl; @@ -104,7 +102,11 @@ lc_cipher_ctx_new(const struct lc_cipher_impl *impl) void lc_cipher_ctx_free(struct lc_cipher_ctx *ctx) { - if (ctx != NULL) - free(ctx->arg); - free(ctx); + if (ctx != NULL && ctx->impl != NULL && ctx->impl->ctx_free != NULL) + ctx->impl->ctx_free(ctx); + else { + if (ctx != NULL) + free(ctx->arg); + free(ctx); + } } diff --git a/cipher.h b/cipher.h index 73e8a8e..d20c4cb 100644 --- a/cipher.h +++ b/cipher.h @@ -19,23 +19,24 @@ struct lc_cipher_impl { - int (*encrypt_init)(void *, const uint8_t *, size_t, + int (*encrypt_init)(void *, const uint8_t *, size_t, const uint8_t *, size_t); - int (*encrypt_update)(void *, uint8_t *, size_t *, const uint8_t *, + int (*encrypt_update)(void *, uint8_t *, size_t *, const uint8_t *, size_t); - int (*encrypt_final)(void *, uint8_t *, size_t *); - int (*encrypt)(const uint8_t *, size_t, const uint8_t *, size_t, + int (*encrypt_final)(void *, uint8_t *, size_t *); + int (*encrypt)(const uint8_t *, size_t, const uint8_t *, size_t, uint8_t *, size_t *, const uint8_t *, size_t); - int (*decrypt_init)(void *, const uint8_t *, size_t, + int (*decrypt_init)(void *, const uint8_t *, size_t, const uint8_t *, size_t); - int (*decrypt_update)(void *, uint8_t *, size_t *, const uint8_t *, + int (*decrypt_update)(void *, uint8_t *, size_t *, const uint8_t *, size_t); - int (*decrypt_final)(void *, uint8_t *, size_t *); - int (*decrypt)(const uint8_t *, size_t, const uint8_t *, size_t, + int (*decrypt_final)(void *, uint8_t *, size_t *); + int (*decrypt)(const uint8_t *, size_t, const uint8_t *, size_t, uint8_t *, size_t *, const uint8_t *, size_t); - size_t argsz; + void *(*ctx_new)(const void *); + void (*ctx_free)(void *); }; struct lc_cipher_ctx { diff --git a/cipher_chacha20.c b/cipher_chacha20.c index 31102a4..5e09375 100644 --- a/cipher_chacha20.c +++ b/cipher_chacha20.c @@ -15,6 +15,7 @@ */ #include +#include #include "lilcrypto.h" #include "cipher.h" @@ -177,6 +178,12 @@ chacha20_x(const uint8_t *key, size_t keylen, const uint8_t *iv, size_t ivlen, return rc; } +static void * +chacha20_ctx_new(const void *arg) +{ + return malloc(sizeof(struct chacha20_ctx)); +} + static struct lc_cipher_impl chacha20_impl = { .encrypt_init = &chacha20_x_init, @@ -189,7 +196,8 @@ static struct lc_cipher_impl chacha20_impl = { .decrypt_final = &chacha20_x_final, .decrypt = &chacha20_x, - .argsz = sizeof(struct chacha20_ctx), + .ctx_new = &chacha20_ctx_new, + .ctx_free = NULL, }; const struct lc_cipher_impl *