diff options
Diffstat (limited to 'src/core/hle')
| -rw-r--r-- | src/core/hle/service/ssl/ssl_backend_securetransport.cpp | 219 | 
1 files changed, 219 insertions, 0 deletions
| diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp new file mode 100644 index 000000000..be40a5aeb --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +#include <mutex> + +#include <Security/SecureTransport.h> + +// SecureTransport has been deprecated in its entirety in favor of +// Network.framework, but that does not allow layering TLS on top of an +// arbitrary socket. +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + +namespace { + +template <typename T> +struct CFReleaser { +    T ptr; + +    YUZU_NON_COPYABLE(CFReleaser); +    constexpr CFReleaser() : ptr(nullptr) {} +    constexpr CFReleaser(T ptr) : ptr(ptr) {} +    constexpr operator T() { +        return ptr; +    } +    ~CFReleaser() { +        if (ptr) { +            CFRelease(ptr); +        } +    } +}; + +std::string CFStringToString(CFStringRef cfstr) { +    CFReleaser<CFDataRef> cfdata( +        CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0)); +    ASSERT_OR_EXECUTE(cfdata, { return "???"; }); +    return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)), +                       CFDataGetLength(cfdata)); +} + +std::string OSStatusToString(OSStatus status) { +    CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr)); +    if (!cfstr) { +        return "[unknown error]"; +    } +    return CFStringToString(cfstr); +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend { +public: +    Result Init() { +        static std::once_flag once_flag; +        std::call_once(once_flag, []() { +            if (getenv("SSLKEYLOGFILE")) { +                LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not " +                                          "support exporting keys; not logging keys!"); +                // Not fatal. +            } +        }); + +        context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType); +        if (!context) { +            LOG_ERROR(Service_SSL, "SSLCreateContext failed"); +            return ResultInternalError; +        } + +        OSStatus status; +        if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) || +            (status = SSLSetConnection(context, this))) { +            LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}", +                      OSStatusToString(status)); +            return ResultInternalError; +        } + +        return ResultSuccess; +    } + +    void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override { +        socket = std::move(in_socket); +    } + +    Result SetHostName(const std::string& hostname) override { +        OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size()); +        if (status) { +            LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status)); +            return ResultInternalError; +        } +        return ResultSuccess; +    } + +    Result DoHandshake() override { +        OSStatus status = SSLHandshake(context); +        return HandleReturn("SSLHandshake", 0, status).Code(); +    } + +    ResultVal<size_t> Read(std::span<u8> data) override { +        size_t actual; +        OSStatus status = SSLRead(context, data.data(), data.size(), &actual); +        ; +        return HandleReturn("SSLRead", actual, status); +    } + +    ResultVal<size_t> Write(std::span<const u8> data) override { +        size_t actual; +        OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); +        ; +        return HandleReturn("SSLWrite", actual, status); +    } + +    ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) { +        switch (status) { +        case 0: +            return actual; +        case errSSLWouldBlock: +            return ResultWouldBlock; +        default: { +            std::string reason; +            if (got_read_eof) { +                reason = "server hung up"; +            } else { +                reason = OSStatusToString(status); +            } +            LOG_ERROR(Service_SSL, "{} failed: {}", what, reason); +            return ResultInternalError; +        } +        } +    } + +    ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { +        CFReleaser<SecTrustRef> trust; +        OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); +        if (status) { +            LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); +            return ResultInternalError; +        } +        std::vector<std::vector<u8>> ret; +        for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { +            SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); +            CFReleaser<CFDataRef> data(SecCertificateCopyData(cert)); +            ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); +            const u8* ptr = CFDataGetBytePtr(data); +            ret.emplace_back(ptr, ptr + CFDataGetLength(data)); +        } +        return ret; +    } + +    static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { +        return ReadOrWriteCallback(connection, data, dataLength, true); +    } + +    static OSStatus WriteCallback(SSLConnectionRef connection, const void* data, +                                  size_t* dataLength) { +        return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false); +    } + +    static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength, +                                        bool is_read) { +        auto self = +            static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection)); +        ASSERT_OR_EXECUTE_MSG( +            self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket", +            is_read ? "read" : "write"); + +        // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are +        // expected to read/write the full requested dataLength or return an +        // error, so we have to add a loop ourselves. +        size_t requested_len = *dataLength; +        size_t offset = 0; +        while (offset < requested_len) { +            std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset); +            auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0); +            LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset, +                         actual, cur.size(), static_cast<s32>(err)); +            switch (err) { +            case Network::Errno::SUCCESS: +                offset += actual; +                if (actual == 0) { +                    ASSERT(is_read); +                    self->got_read_eof = true; +                    return errSecEndOfData; +                } +                break; +            case Network::Errno::AGAIN: +                *dataLength = offset; +                return errSSLWouldBlock; +            default: +                LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}", +                          is_read ? "recv" : "send", err); +                return errSecIO; +            } +        } +        ASSERT(offset == requested_len); +        return 0; +    } + +private: +    CFReleaser<SSLContextRef> context = nullptr; +    bool got_read_eof = false; + +    std::shared_ptr<Network::SocketBase> socket; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { +    auto conn = std::make_unique<SSLConnectionBackendSecureTransport>(); +    const Result res = conn->Init(); +    if (res.IsFailure()) { +        return res; +    } +    return conn; +} + +} // namespace Service::SSL | 
