diff --git a/include/media_stream.hh b/include/media_stream.hh index e553f89..56ce48d 100644 --- a/include/media_stream.hh +++ b/include/media_stream.hh @@ -16,10 +16,13 @@ namespace uvgrtp { // forward declarations class rtp; - class zrtp; class rtcp; + + class zrtp; + class base_srtp; class srtp; class srtcp; + class pkt_dispatcher; class holepuncher; class socket; @@ -307,6 +310,9 @@ namespace uvgrtp { /* free all allocated resources */ rtp_error_t free_resources(rtp_error_t ret); + rtp_error_t init_srtp_with_zrtp(int flags, int type, uvgrtp::base_srtp* srtp, + uvgrtp::zrtp *zrtp); + uint32_t key_; uvgrtp::srtp *srtp_; diff --git a/src/media_stream.cc b/src/media_stream.cc index 6503271..1dde156 100644 --- a/src/media_stream.cc +++ b/src/media_stream.cc @@ -172,14 +172,47 @@ rtp_error_t uvgrtp::media_stream::create_media(rtp_format_t fmt) rtp_error_t uvgrtp::media_stream::free_resources(rtp_error_t ret) { - delete socket_; - delete rtcp_; - delete rtp_; - delete srtp_; - delete srtcp_; - delete pkt_dispatcher_; - delete holepuncher_; - delete media_; + if (socket_) + { + delete socket_; + socket_ = nullptr; + } + if (rtcp_) + { + delete rtcp_; + rtcp_ = nullptr; + } + if (rtp_) + { + delete rtp_; + rtp_ = nullptr; + } + if (srtp_) + { + delete srtp_; + srtp_ = nullptr; + } + if (srtcp_) + { + delete srtcp_; + srtcp_ = nullptr; + } + if (pkt_dispatcher_) + { + delete pkt_dispatcher_; + pkt_dispatcher_ = nullptr; + } + if (holepuncher_) + { + delete holepuncher_; + holepuncher_ = nullptr; + } + if (media_) + { + delete media_; + media_ = nullptr; + } + return ret; } @@ -237,18 +270,12 @@ rtp_error_t uvgrtp::media_stream::init(uvgrtp::zrtp *zrtp) } srtp_ = new uvgrtp::srtp(); - - if ((ret = srtp_->init_zrtp(SRTP, ctx_config_.flags, zrtp)) != RTP_OK) { - LOG_WARN("Failed to initialize SRTP for media stream!"); - return free_resources(ret); - } + if ((ret = init_srtp_with_zrtp(ctx_config_.flags, SRTP, srtp_, zrtp)) != RTP_OK) + return free_resources(ret); srtcp_ = new uvgrtp::srtcp(); - - if ((ret = srtcp_->init_zrtp(SRTCP, ctx_config_.flags, zrtp)) != RTP_OK) { - LOG_ERROR("Failed to initialize SRTCP for media stream!"); - return free_resources(ret); - } + if ((ret = init_srtp_with_zrtp(ctx_config_.flags, SRTCP, srtcp_, zrtp)) != RTP_OK) + return free_resources(ret); rtcp_ = new uvgrtp::rtcp(rtp_, srtcp_, ctx_config_.flags); @@ -303,14 +330,15 @@ rtp_error_t uvgrtp::media_stream::add_srtp_ctx(uint8_t *key, uint8_t *salt) srtp_ = new uvgrtp::srtp(); - if ((ret = srtp_->init_user(SRTP, ctx_config_.flags, key, salt)) != RTP_OK) { + // why are they local and remote key/salt the same? + if ((ret = srtp_->init(SRTP, ctx_config_.flags, key, key, salt, salt)) != RTP_OK) { LOG_WARN("Failed to initialize SRTP for media stream!"); return free_resources(ret); } srtcp_ = new uvgrtp::srtcp(); - if ((ret = srtcp_->init_user(SRTCP, ctx_config_.flags, key, salt)) != RTP_OK) { + if ((ret = srtcp_->init(SRTCP, ctx_config_.flags, key, key, salt, salt)) != RTP_OK) { LOG_WARN("Failed to initialize SRTCP for media stream!"); return free_resources(ret); } @@ -572,3 +600,36 @@ uvgrtp::rtcp *uvgrtp::media_stream::get_rtcp() { return rtcp_; } + +rtp_error_t uvgrtp::media_stream::init_srtp_with_zrtp(int flags, int type, uvgrtp::base_srtp* srtp, + uvgrtp::zrtp *zrtp) +{ + size_t key_size = srtp->get_key_size(flags); + + uint8_t* local_key = new uint8_t[key_size]; + uint8_t* remote_key = new uint8_t[key_size]; + uint8_t local_salt[UVG_SALT_LENGTH]; + uint8_t remote_salt[UVG_SALT_LENGTH]; + + rtp_error_t ret = zrtp->get_srtp_keys( + local_key, key_size * 8, + remote_key, key_size * 8, + local_salt, UVG_SALT_LENGTH * 8, + remote_salt, UVG_SALT_LENGTH * 8 + ); + + if (ret == RTP_OK) + { + ret = srtp->init(type, flags, local_key, remote_key, + local_salt, remote_salt); + } + else + { + LOG_WARN("Failed to initialize SRTP for media stream!"); + } + + delete[] local_key; + delete[] remote_key; + + return ret; +} diff --git a/src/srtp/base.cc b/src/srtp/base.cc index e57b92f..4ac1a30 100644 --- a/src/srtp/base.cc +++ b/src/srtp/base.cc @@ -1,7 +1,6 @@ #include "base.hh" #include "crypto.hh" -#include "../zrtp.hh" #include "debug.hh" #include @@ -101,13 +100,20 @@ bool uvgrtp::base_srtp::is_replayed_packet(uint8_t *digest) return false; } -rtp_error_t uvgrtp::base_srtp::init(int type, int flags, size_t key_size) +rtp_error_t uvgrtp::base_srtp::init(int type, int flags, uint8_t* local_key, uint8_t* remote_key, + uint8_t* local_salt, uint8_t* remote_salt) { srtp_ctx_->roc = 0; srtp_ctx_->rts = 0; srtp_ctx_->type = type; srtp_ctx_->hmac = HMAC_SHA1; + size_t key_size = get_key_size(flags); + + rtp_error_t ret = RTP_OK; + if ((ret = set_master_keys(key_size, local_key, remote_key, local_salt, remote_salt)) != RTP_OK) + return ret; + switch (key_size) { case AES128_KEY_SIZE: srtp_ctx_->enc = AES_128; @@ -201,7 +207,7 @@ rtp_error_t uvgrtp::base_srtp::init(int type, int flags, size_t key_size) UVG_SALT_LENGTH ); - return RTP_OK; + return ret; } rtp_error_t uvgrtp::base_srtp::allocate_crypto_ctx(size_t key_size) @@ -215,53 +221,6 @@ rtp_error_t uvgrtp::base_srtp::allocate_crypto_ctx(size_t key_size) return RTP_OK; } -rtp_error_t uvgrtp::base_srtp::init_zrtp(int type, int flags, uvgrtp::zrtp *zrtp) -{ - if (!zrtp) - return RTP_INVALID_VALUE; - - size_t key_size = get_key_size(flags); - - uint8_t* local_key = new uint8_t[key_size]; - uint8_t* remote_key = new uint8_t[key_size]; - uint8_t local_salt[UVG_SALT_LENGTH]; - uint8_t remote_salt[UVG_SALT_LENGTH]; - - /* ZRTP key derivation function expects the keys lengths to be given in bits */ - rtp_error_t ret = zrtp->get_srtp_keys( - local_key, key_size * 8, - remote_key, key_size * 8, - local_salt, UVG_SALT_LENGTH * 8, - remote_salt, UVG_SALT_LENGTH * 8 - ); - - if (ret == RTP_OK) - ret = set_master_keys(flags, local_key, remote_key, local_salt, remote_salt); - - delete[] local_key; - delete[] remote_key; - - if (ret != RTP_OK) { - LOG_ERROR("Failed to derive keys for SRTP session!"); - return ret; - } - - return init(type, flags, key_size); -} - -rtp_error_t uvgrtp::base_srtp::init_user(int type, int flags, uint8_t *key, uint8_t *salt) -{ - if (!key || !salt) - return RTP_INVALID_VALUE; - - rtp_error_t ret = RTP_OK; - if ((ret = set_master_keys(flags, key, key, salt, salt)) != RTP_OK) - return ret; - - size_t key_size = get_key_size(flags); - return init(type, flags, key_size); -} - size_t uvgrtp::base_srtp::get_key_size(int flags) { size_t key_size = AES128_KEY_SIZE; @@ -277,14 +236,12 @@ size_t uvgrtp::base_srtp::get_key_size(int flags) return key_size; } -rtp_error_t uvgrtp::base_srtp::set_master_keys(int flags, uint8_t* local_key, uint8_t* remote_key, +rtp_error_t uvgrtp::base_srtp::set_master_keys(size_t key_size, uint8_t* local_key, uint8_t* remote_key, uint8_t* local_salt, uint8_t* remote_salt) { if (!local_key || !remote_key || !local_salt || !remote_salt) return RTP_INVALID_VALUE; - size_t key_size = get_key_size(flags); - rtp_error_t ret = RTP_OK; if ((ret = allocate_crypto_ctx(key_size)) != RTP_OK) return ret; diff --git a/src/srtp/base.hh b/src/srtp/base.hh index 8d8dd0e..5125bd5 100644 --- a/src/srtp/base.hh +++ b/src/srtp/base.hh @@ -33,11 +33,6 @@ enum { namespace uvgrtp { - - - class zrtp; - class rtp; - /* Vector of buffers that contain a full RTP frame */ typedef std::vector> buf_vec; @@ -131,14 +126,7 @@ namespace uvgrtp { base_srtp(); virtual ~base_srtp(); - /* Setup Secure RTP/RTCP connection using ZRTP - * - * Return RTP_OK if SRTP setup was successful - * Return RTP_INVALID_VALUE if "zrtp" is nullptr - * Return RTP_MEMORY allocation failed */ - rtp_error_t init_zrtp(int type, int flags, uvgrtp::zrtp *zrtp); - - /* Setup Secure RTP/RTCP connection using user-managed keys + /* Setup Secure RTP/RTCP connection with master keys * * Length of the "key" must be either 128, 192, or 256 bits * Length of "salt" must be SALT_LENGTH (14 bytes, 112 bits) @@ -146,7 +134,8 @@ namespace uvgrtp { * Return RTP_OK if SRTP setup was successful * Return RTP_INVALID_VALUE if "key" or "salt" is nullptr * Return RTP_MEMORY allocation failed */ - rtp_error_t init_user(int type, int flags, uint8_t *key, uint8_t *salt); + rtp_error_t init(int type, int flags, uint8_t *local_key, uint8_t *remote_key, + uint8_t *local_salt, uint8_t *remote_salt); /* Has RTP packet encryption been disabled? */ bool use_null_cipher(); @@ -161,8 +150,8 @@ namespace uvgrtp { * Returns false if replay protection has not been enabled */ bool is_replayed_packet(uint8_t *digest); - rtp_error_t set_master_keys(int flags, uint8_t* local_key, uint8_t* remote_key, - uint8_t* local_salt, uint8_t* remote_salt); + size_t get_key_size(int flags); + protected: /* Create IV for the packet that is about to be encrypted @@ -179,15 +168,14 @@ namespace uvgrtp { bool use_null_cipher_; private: - /* Internal init method that initialize the SRTP context using values in key_ctx_.master */ - rtp_error_t init(int type, int flags, size_t key_size); + rtp_error_t set_master_keys(size_t key_size, uint8_t* local_key, uint8_t* remote_key, + uint8_t* local_salt, uint8_t* remote_salt); rtp_error_t derive_key(int label, uint8_t *key, uint8_t *salt, uint8_t *out, size_t len); /* Allocate space for master/session encryption keys */ rtp_error_t allocate_crypto_ctx(size_t key_size); - size_t get_key_size(int flags); /* By default RTP packet authentication is disabled but by * giving RCE_SRTP_AUTHENTICATE_RTP to create_stream() user can enable it.