Skip to content

Commit

Permalink
remote: re-connect when server is disconnected
Browse files Browse the repository at this point in the history
  • Loading branch information
osy committed Feb 25, 2024
1 parent 0a8bff6 commit 4dca247
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 28 deletions.
44 changes: 41 additions & 3 deletions Platform/UTMData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ struct AlertMessage: Identifiable {
virtualMachines.forEach({ endObservingChanges(for: $0) })
virtualMachines = vms
vms.forEach({ beginObservingChanges(for: $0) })
selectedVM = nil
}

/// Add VM to list
Expand Down Expand Up @@ -1080,6 +1081,7 @@ enum UTMDataError: Error {
case jitStreamerAttachFailed
case jitStreamerUrlInvalid(String)
case notImplemented
case reconnectFailed
}

extension UTMDataError: LocalizedError {
Expand Down Expand Up @@ -1111,6 +1113,8 @@ extension UTMDataError: LocalizedError {
return String.localizedStringWithFormat(NSLocalizedString("Invalid JitStreamer attach URL:\n%@", comment: "UTMData"), urlString)
case .notImplemented:
return NSLocalizedString("This functionality is not yet implemented.", comment: "UTMData")
case .reconnectFailed:
return NSLocalizedString("Failed to reconnect to the server.", comment: "UTMData")
}
}
}
Expand Down Expand Up @@ -1154,6 +1158,8 @@ struct UTMCapabilities: OptionSet, Codable {
}

#if WITH_REMOTE
private let kReconnectTimeoutSeconds: UInt64 = 5

@MainActor
class UTMRemoteData: UTMData {
/// Remote access client
Expand All @@ -1170,14 +1176,46 @@ class UTMRemoteData: UTMData {

override func listRefresh() async {
busyWorkAsync {
if let capabilities = await self.remoteClient.server.capabilities {
UTMCapabilities.current = capabilities
}
try await self.listRefreshFromRemote()
}
}

func reconnect(to server: UTMRemoteClient.State.SavedServer) async throws {
var reconnectTask: Task<UTMRemoteClient.Remote, any Error>?
let timeoutTask = Task {
try await Task.sleep(nanoseconds: kReconnectTimeoutSeconds * NSEC_PER_SEC)
reconnectTask?.cancel()
}
reconnectTask = busyWorkAsync { [self] in
do {
try await remoteClient.connect(server)
} catch is CancellationError {
throw UTMDataError.reconnectFailed
}
timeoutTask.cancel()
try await listRefreshFromRemote()
return await remoteClient.server
}
// make all active sessions wait on the reconnect
for session in VMSessionState.allActiveSessions.values {
let vm = session.vm as! UTMRemoteSpiceVirtualMachine
Task {
do {
try await vm.reconnectServer {
try await reconnectTask!.value
}
} catch {
session.stop()
}
}
}
_ = try await reconnectTask!.value
}

private func listRefreshFromRemote() async throws {
if let capabilities = await self.remoteClient.server.capabilities {
UTMCapabilities.current = capabilities
}
let ids = try await remoteClient.server.listVirtualMachines()
let items = try await remoteClient.server.getVirtualMachineInformation(for: ids)
await loadVirtualMachines(items.map({ VMRemoteData(fromRemoteItem: $0) }))
Expand Down
22 changes: 13 additions & 9 deletions Remote/UTMRemoteClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,20 @@ actor UTMRemoteClient {
}

func connect(_ server: State.SavedServer) async throws {
var server = server
var isSuccessful = false
let endpoint = server.endpoint ?? NWEndpoint.hostPort(host: .init(server.hostname), port: .init(integerLiteral: UInt16(server.port ?? 0)))
try await keyManager.load()
let connection = try await Connection(endpoint: endpoint, connectionQueue: connectionQueue, identity: keyManager.identity)
let connection = try await Connection(endpoint: endpoint, connectionQueue: connectionQueue, identity: keyManager.identity) { connection, error in
Task {
do {
try await self.local.data.reconnect(to: server)
} catch {
// reconnect failed
await self.state.setConnected(false)
await self.state.showErrorAlert(error.localizedDescription)
}
}
}
defer {
if !isSuccessful {
connection.close()
Expand Down Expand Up @@ -121,6 +130,7 @@ actor UTMRemoteClient {
}
}
self.server = remote
var server = server
await state.setConnected(true)
if !server.shouldSavePassword {
server.password = nil
Expand Down Expand Up @@ -260,7 +270,7 @@ extension UTMRemoteClient {
class Local: LocalInterface {
typealias M = UTMRemoteMessageClient

private let data: UTMRemoteData
fileprivate let data: UTMRemoteData

init(data: UTMRemoteData) {
self.data = data
Expand All @@ -283,12 +293,6 @@ extension UTMRemoteClient {
}
}

func handle(error: Error) {
Task {
await data.showErrorAlert(message: error.localizedDescription)
}
}

private func _handshake(parameters: M.ClientHandshake.Request) async throws -> M.ClientHandshake.Reply {
return .init(version: UTMRemoteMessageClient.version, capabilities: .current)
}
Expand Down
33 changes: 22 additions & 11 deletions Remote/UTMRemoteServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,19 @@ actor UTMRemoteServer {
}
let port = serverPort > 0 ? NWEndpoint.Port(integerLiteral: UInt16(serverPort)) : .any
for try await connection in Connection.advertise(on: port, forServiceType: service, txtRecord: metadata, connectionQueue: connectionQueue, identity: keyManager.identity) {
if let connection = try? await Connection(connection: connection, connectionQueue: connectionQueue) {
let connection = try? await Connection(connection: connection, connectionQueue: connectionQueue) { connection, error in
Task {
guard let fingerprint = connection.fingerprint else {
return
}
if !(error is NWError) {
// connection errors are too noisy
await self.notifyError(error)
}
await self.state.disconnect(fingerprint)
}
}
if let connection = connection {
await newRemoteConnection(connection)
}
}
Expand All @@ -174,7 +186,7 @@ actor UTMRemoteServer {

private func newRemoteConnection(_ connection: Connection) async {
let remoteAddress = connection.connection.endpoint.hostname ?? "\(connection.connection.endpoint)"
guard let fingerprint = connection.peerCertificateChain.first?.fingerprint() else {
guard let fingerprint = connection.fingerprint else {
connection.close()
return
}
Expand Down Expand Up @@ -222,7 +234,7 @@ actor UTMRemoteServer {
}

private func establishConnection(_ connection: Connection) async {
guard let fingerprint = connection.peerCertificateChain.first?.fingerprint() else {
guard let fingerprint = connection.fingerprint else {
connection.close()
return
}
Expand Down Expand Up @@ -282,9 +294,8 @@ actor UTMRemoteServer {
while !group.isEmpty {
switch await group.nextResult() {
case .failure(let error):
if case BroadcastError.connectionError(let error, let fingerprint) = error {
if case BroadcastError.connectionError(_, let fingerprint) = error {
// disconnect any clients who failed to respond
await notifyError(error)
await state.disconnect(fingerprint)
} else {
logger.error("client returned error on broadcast: \(error)")
Expand Down Expand Up @@ -646,12 +657,6 @@ extension UTMRemoteServer {
}
}

func handle(error: Error) {
Task {
await server.notifyError(error)
}
}

@MainActor
private func findVM(withId id: UUID) throws -> VMData {
let vm = data.virtualMachines.first(where: { $0.id == id })
Expand Down Expand Up @@ -940,3 +945,9 @@ extension UTMRemoteServer {
}
}
}

extension Connection {
var fingerprint: [UInt8]? {
return peerCertificateChain.first?.fingerprint()
}
}
16 changes: 12 additions & 4 deletions Remote/UTMRemoteSpiceVirtualMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ final class UTMRemoteSpiceVirtualMachine: UTMSpiceVirtualMachine {

static let capabilities = Capabilities()

private let server: UTMRemoteClient.Remote
private var server: UTMRemoteClient.Remote

init(packageUrl: URL, configuration: UTMQemuConfiguration, isShortcut: Bool) throws {
throw UTMVirtualMachineError.notImplemented
Expand Down Expand Up @@ -142,6 +142,12 @@ final class UTMRemoteSpiceVirtualMachine: UTMSpiceVirtualMachine {
func changeUuid(to uuid: UUID, name: String?, copyingEntry entry: UTMRegistryEntry?) {
// not needed
}

func reconnectServer(_ body: () async throws -> UTMRemoteClient.Remote) async throws {
try await _state.operation(during: .resuming) {
self.server = try await body()
}
}
}

extension UTMRemoteSpiceVirtualMachine {
Expand Down Expand Up @@ -306,12 +312,14 @@ extension UTMRemoteSpiceVirtualMachine {
try await operation(before: [before], during: during, after: after, body: body)
}

func operation(before: Set<UTMVirtualMachineState>, during: UTMVirtualMachineState, after: UTMVirtualMachineState? = nil, body: () async throws -> Void) async throws {
func operation(before: Set<UTMVirtualMachineState>? = nil, during: UTMVirtualMachineState, after: UTMVirtualMachineState? = nil, body: () async throws -> Void) async throws {
while isInOperation {
await Task.yield()
}
guard before.contains(state) else {
throw VMError.operationInProgress
if let before = before {
guard before.contains(state) else {
throw VMError.operationInProgress
}
}
isInOperation = true
remoteState = nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
"location" : "https://github.com/utmapp/SwiftConnect",
"state" : {
"branch" : "main",
"revision" : "04ee4b5625653e11c00ee15fe12b46846e02cb95"
"revision" : "4f2241d2ad4e1d99bee6344422ca5c44018f4046"
}
},
{
Expand Down

0 comments on commit 4dca247

Please sign in to comment.