Skip to content

Commit

Permalink
fix(pkcs11): fix PKCS#11 module initialization and finalization by ma…
Browse files Browse the repository at this point in the history
…king PKCS11CardManager instances singletons per module path

Signed-off-by: Mart Somermaa <[email protected]>
  • Loading branch information
mrts committed May 29, 2023
1 parent ef96b94 commit 6069209
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 46 deletions.
102 changes: 69 additions & 33 deletions src/electronic-ids/pkcs11/PKCS11CardManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@
#include <cstring>
#include <string>
#include <vector>
#include <unordered_map>
#include <algorithm>
#include <filesystem>
#include <functional>
#include <mutex>

#ifdef _WIN32
#include <Windows.h>
Expand All @@ -61,42 +63,38 @@ namespace electronic_id
class PKCS11CardManager
{
public:
PKCS11CardManager(const std::filesystem::path& module)
/**
* Returns a shared instance of PKCS11CardManager for a given PKCS#11 module.
*
* This method implements a "per-module singleton" pattern: for each distinct module path,
* only one instance of PKCS11CardManager is created. All subsequent requests for that
* module will return a shared pointer to the initially created instance.
*
* This function is thread-safe.
*
* @param module Path to the PKCS11 module.
* @return Shared pointer to the corresponding PKCS11CardManager.
*/
static std::shared_ptr<PKCS11CardManager> instance(const std::filesystem::path& module)
{
CK_C_GetFunctionList C_GetFunctionList = nullptr;
std::string error;
#ifdef _WIN32
library = LoadLibraryW(module.c_str());
if (library) {
C_GetFunctionList = CK_C_GetFunctionList(GetProcAddress(library, "C_GetFunctionList"));
} else {
LPSTR msg = nullptr;
FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM
| FORMAT_MESSAGE_IGNORE_INSERTS,
nullptr, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
LPSTR(&msg), 0, nullptr);
error = msg;
LocalFree(msg);
}
#else
library = dlopen(module.c_str(), RTLD_LOCAL | RTLD_NOW);
if (library) {
C_GetFunctionList = CK_C_GetFunctionList(dlsym(library, "C_GetFunctionList"));
} else {
error = dlerror();
}
#endif
static std::mutex mutex;
static std::unordered_map<std::string, std::shared_ptr<PKCS11CardManager>> instances;

if (!C_GetFunctionList) {
THROW(SmartCardChangeRequiredError,
"C_GetFunctionList loading failed for module '" + module.string() + "', error "
+ error);
}
Call(__func__, __FILE__, __LINE__, "C_GetFunctionList", C_GetFunctionList, &fl);
if (!fl) {
THROW(SmartCardChangeRequiredError, "C_GetFunctionList: CK_FUNCTION_LIST_PTR is null");
// There is no std::hash for std::filesystem::path, use the string value.
// Note that two different path strings that refer to the same filesystem location
// will be treated as different keys (e.g. /path/to/module and /path/to/../to/module).
std::string moduleStr = module.string();

std::lock_guard<std::mutex> lock(mutex);

auto it = instances.find(moduleStr);
if (it != instances.end()) {
return it->second;
}
C(Initialize, nullptr);

auto newInstance = std::shared_ptr<PKCS11CardManager>(new PKCS11CardManager(module));
instances[moduleStr] = newInstance;
return newInstance;
}

~PKCS11CardManager()
Expand Down Expand Up @@ -233,6 +231,44 @@ class PKCS11CardManager
}

private:
PKCS11CardManager(const std::filesystem::path& module)
{
CK_C_GetFunctionList C_GetFunctionList = nullptr;
std::string error;
#ifdef _WIN32
library = LoadLibraryW(module.c_str());
if (library) {
C_GetFunctionList = CK_C_GetFunctionList(GetProcAddress(library, "C_GetFunctionList"));
} else {
LPSTR msg = nullptr;
FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM
| FORMAT_MESSAGE_IGNORE_INSERTS,
nullptr, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
LPSTR(&msg), 0, nullptr);
error = msg;
LocalFree(msg);
}
#else
library = dlopen(module.c_str(), RTLD_LOCAL | RTLD_NOW);
if (library) {
C_GetFunctionList = CK_C_GetFunctionList(dlsym(library, "C_GetFunctionList"));
} else {
error = dlerror();
}
#endif

if (!C_GetFunctionList) {
THROW(SmartCardChangeRequiredError,
"C_GetFunctionList loading failed for module '" + module.string() + "', error "
+ error);
}
Call(__func__, __FILE__, __LINE__, "C_GetFunctionList", C_GetFunctionList, &fl);
if (!fl) {
THROW(SmartCardChangeRequiredError, "C_GetFunctionList: CK_FUNCTION_LIST_PTR is null");
}
C(Initialize, nullptr);
}

template <typename Func, typename... Args>
static void Call(const char* function, const char* file, int line, const char* apiFunction,
Func func, Args... args)
Expand Down
18 changes: 9 additions & 9 deletions src/electronic-ids/pkcs11/Pkcs11ElectronicID.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ const std::map<Pkcs11ElectronicIDType, Pkcs11ElectronicIDModule> SUPPORTED_PKCS1
ElectronicID::Type::LitEID, // type
lithuanianPKCS11ModulePath(false).make_preferred(), // path

JsonWebSignatureAlgorithm::RS256, // authSignatureAlgorithm
RSA_SIGNATURE_ALGOS(), // supportedSigningAlgorithms
-1,
JsonWebSignatureAlgorithm::ES384, // authSignatureAlgorithm
ELLIPTIC_CURVE_SIGNATURE_ALGOS(), // supportedSigningAlgorithms
3,
false,
}},
{Pkcs11ElectronicIDType::HrvEID,
Expand Down Expand Up @@ -187,12 +187,12 @@ const Pkcs11ElectronicIDModule& getModule(Pkcs11ElectronicIDType eidType)
Pkcs11ElectronicID::Pkcs11ElectronicID(pcsc_cpp::SmartCard::ptr _card,
Pkcs11ElectronicIDType type) :
ElectronicID(std::move(_card)),
module(getModule(type)), manager(module.path)
module(getModule(type)), manager(PKCS11CardManager::instance(module.path))
{
bool seenAuthToken = false;
bool seenSigningToken = false;

for (const auto& token : manager.tokens()) {
for (const auto& token : manager->tokens()) {
const auto certType = certificateType(token.cert);
if (certType.isAuthentication()) {
authToken = token;
Expand Down Expand Up @@ -229,8 +229,8 @@ pcsc_cpp::byte_vector Pkcs11ElectronicID::signWithAuthKey(const pcsc_cpp::byte_v
validateAuthHashLength(authSignatureAlgorithm(), name(), hash);

const auto signature =
manager.sign(authToken, hash, authSignatureAlgorithm().hashAlgorithm(),
reinterpret_cast<const char*>(pin.data()), pin.size());
manager->sign(authToken, hash, authSignatureAlgorithm().hashAlgorithm(),
reinterpret_cast<const char*>(pin.data()), pin.size());
return signature.first;
} catch (const VerifyPinFailed& e) {
// Catch and rethrow the VerifyPinFailed error with -1 to inform the caller of the special
Expand Down Expand Up @@ -262,8 +262,8 @@ ElectronicID::Signature Pkcs11ElectronicID::signWithSigningKey(const pcsc_cpp::b
validateSigningHash(*this, hashAlgo, hash);

// TODO: add step for supported algo detection before sign(), see if () below.
auto signature = manager.sign(signingToken, hash, hashAlgo,
reinterpret_cast<const char*>(pin.data()), pin.size());
auto signature = manager->sign(signingToken, hash, hashAlgo,
reinterpret_cast<const char*>(pin.data()), pin.size());

if (!module.supportedSigningAlgorithms.count(signature.second)) {
THROW(SmartCardChangeRequiredError,
Expand Down
2 changes: 1 addition & 1 deletion src/electronic-ids/pkcs11/Pkcs11ElectronicID.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class Pkcs11ElectronicID : public ElectronicID
Type type() const override { return module.type; }

const Pkcs11ElectronicIDModule& module;
PKCS11CardManager manager;
const std::shared_ptr<PKCS11CardManager> manager;
PKCS11CardManager::Token authToken;
PKCS11CardManager::Token signingToken;
};
Expand Down
6 changes: 3 additions & 3 deletions src/electronic-ids/x509.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ inline CertificateType certificateType(const pcsc_cpp::byte_vector& cert)
}
auto keyUsage = SCOPE_GUARD(ASN1_BIT_STRING, extension(x509.get(), NID_key_usage));
if (!keyUsage) {
return electronic_id::CertificateType::NONE;
return CertificateType::NONE;
}

static const int KEY_USAGE_NON_REPUDIATION = 1;
Expand All @@ -48,11 +48,11 @@ inline CertificateType certificateType(const pcsc_cpp::byte_vector& cert)
auto extKeyUsage =
SCOPE_GUARD(EXTENDED_KEY_USAGE, extension(x509.get(), NID_ext_key_usage));
if (extKeyUsage && hasClientAuthExtendedKeyUsage(extKeyUsage.get())) {
return electronic_id::CertificateType::AUTHENTICATION;
return CertificateType::AUTHENTICATION;
}
}

return electronic_id::CertificateType::NONE;
return CertificateType::NONE;
}

} // namespace electronic_id

0 comments on commit 6069209

Please sign in to comment.