Commit
Improve Assist logic to make sure it has the correct sample rate (#2788)
bgoncal committed May 24, 2024
1 parent 945690c commit 979b984
Showing 7 changed files with 65 additions and 31 deletions.
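
At a high level, the change drops the recorder's hardcoded 16 kHz assumption in favor of the sample rate the audio session actually reports, and the Assist audio pipeline is only started once the recorder's delegate confirms that rate. A minimal sketch of the pattern (simplified, illustrative names such as RecorderDelegate and SampleRateAwareRecorder, not the app's exact types) could look like this:

import AVFoundation

// Sketch only: the point is the ordering. The session's *actual* sample rate is
// read after activation, handed to the delegate, and only then is the pipeline
// started with that rate.
protocol RecorderDelegate: AnyObject {
    func didStartRecording(with sampleRate: Double)
    func didFailToRecord(error: Error)
}

final class SampleRateAwareRecorder {
    weak var delegate: RecorderDelegate?

    func start() {
        let session = AVAudioSession.sharedInstance()
        do {
            try session.setCategory(.record, mode: .default)
            try session.setPreferredSampleRate(16000) // a preference, not a guarantee
            try session.setActive(true)
            let rate = session.sampleRate // actual hardware rate; may differ from 16000
            guard rate > 0 else { throw NSError(domain: "Recorder", code: -1) }
            delegate?.didStartRecording(with: rate)
        } catch {
            delegate?.didFailToRecord(error: error)
        }
    }
}

The pipeline request then carries the reported rate instead of a constant, which is what the AssistViewModel and AudioRecorder changes below wire up.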
7 changes: 7 additions & 0 deletions Sources/App/Assist/AssistView.swift
@@ -50,6 +50,13 @@ struct AssistView: View {
assistSession.inProgress = false
viewModel.onDisappear()
}
.alert(isPresented: $viewModel.showError) {
.init(
title: Text(L10n.errorLabel),
message: Text(viewModel.errorMessage),
dismissButton: .default(Text(L10n.okLabel))
)
}
}

private var closeButton: some View {
29 changes: 22 additions & 7 deletions Sources/App/Assist/AssistViewModel.swift
@@ -10,7 +10,8 @@ final class AssistViewModel: NSObject, ObservableObject {
@Published var showScreenLoader = false
@Published var inputText = ""
@Published var isRecording = false
@Published var showPipelineErrorAlert = false
@Published var showError = false
@Published var errorMessage = ""

private var server: Server
private var audioRecorder: AudioRecorderProtocol
@@ -84,11 +85,14 @@ final class AssistViewModel: NSObject, ObservableObject {
inputText = ""

audioRecorder.startRecording()
// Wait until the recorder delegate gives the green light via 'didStartRecording'
}

private func startAssistAudioPipeline(audioSampleRate: Double) {
assistService.assist(
source: .audio(
pipelineId: preferredPipelineId,
audioSampleRate: audioRecorder.audioSampleRate
audioSampleRate: audioSampleRate
)
)
}
@@ -109,7 +113,7 @@ final class AssistViewModel: NSObject, ObservableObject {
assistService.fetchPipelines { [weak self] response in
self?.showScreenLoader = false
guard let self, let response else {
self?.showPipelineError()
self?.showError(message: L10n.Assist.Error.pipelinesResponse)
return
}
if preferredPipelineId.isEmpty {
@@ -136,9 +140,10 @@ final class AssistViewModel: NSObject, ObservableObject {
}
}

private func showPipelineError() {
private func showError(message: String) {
DispatchQueue.main.async { [weak self] in
self?.showPipelineErrorAlert = true
self?.errorMessage = message
self?.showError = true
}
}

@@ -151,13 +156,23 @@ final class AssistViewModel: NSObject, ObservableObject {
}

extension AssistViewModel: AudioRecorderDelegate {
func didFailToRecord(error: any Error) {
showError(message: error.localizedDescription)
}

func didOutputSample(data: Data) {
guard canSendAudioData else { return }
assistService.sendAudioData(data)
}

func didStartRecording() {
isRecording = true
func didStartRecording(with sampleRate: Double) {
DispatchQueue.main.async { [weak self] in
self?.isRecording = true
#if DEBUG
self?.appendToChat(.init(content: "didStartRecording(with sampleRate: \(sampleRate))", itemType: .info))
#endif
}
startAssistAudioPipeline(audioSampleRate: sampleRate)
}

func didStopRecording() {
46 changes: 26 additions & 20 deletions Sources/App/Assist/Audio/AudioRecorder.swift
@@ -4,21 +4,27 @@ import Shared

protocol AudioRecorderProtocol {
var delegate: AudioRecorderDelegate? { get set }
var audioSampleRate: Double { get }
var audioSampleRate: Double? { get }
func startRecording()
func stopRecording()
}

protocol AudioRecorderDelegate: AnyObject {
func didOutputSample(data: Data)
func didStartRecording()
func didStartRecording(with sampleRate: Double)
func didStopRecording()
func didFailToRecord(error: Error)
}

enum AudioRecorderError: Error {
case invalidSampleRate
case captureDeviceUnavailable
}

final class AudioRecorder: NSObject, AudioRecorderProtocol {
weak var delegate: AudioRecorderDelegate?

private(set) var audioSampleRate: Double = 16000
private(set) var audioSampleRate: Double?
private var captureSession: AVCaptureSession?

override init() {
@@ -34,8 +40,12 @@ final class AudioRecorder: NSObject, AudioRecorderProtocol {
setupAudioRecorder()
guard let captureSession else { return }
DispatchQueue.global().async { [weak self] in
captureSession.startRunning()
self?.delegate?.didStartRecording()
if let audioSampleRate = self?.audioSampleRate {
captureSession.startRunning()
self?.delegate?.didStartRecording(with: audioSampleRate)
} else {
Current.Log.error("No sample rate available to start recording")
}
}
}

@@ -48,40 +58,39 @@ final class AudioRecorder: NSObject, AudioRecorderProtocol {
let audioSession = AVAudioSession.sharedInstance()
guard let captureDevice = AVCaptureDevice.default(for: .audio) else {
Current.Log.error("Failed to get capture device to record audio for Assist")
delegate?.didFailToRecord(error: AudioRecorderError.captureDeviceUnavailable)
return
}

do {
try audioSession.setActive(false)
try audioSession.setCategory(.record, mode: .default)
try audioSession.setPreferredSampleRate(16000)
try audioSession.setPreferredOutputNumberOfChannels(1)

try audioSession.setPreferredSampleRate(16000.0)
try audioSession.setActive(true)
let audioInput = try AVCaptureDeviceInput(device: captureDevice)

captureSession = AVCaptureSession()
captureSession?.automaticallyConfiguresApplicationAudioSession = false
captureSession?.addInput(audioInput)

Current.Log.info("Audio sample rate: \(audioSession.sampleRate)")
audioSampleRate = audioSession.sampleRate
if audioSession.sampleRate == 0 {
throw AudioRecorderError.invalidSampleRate
} else {
audioSampleRate = audioSession.sampleRate
}

let audioOutput = AVCaptureAudioDataOutput()

audioOutput.setSampleBufferDelegate(self, queue: DispatchQueue.global(qos: .userInteractive))
captureSession?.addOutput(audioOutput)
} catch {
Current.Log.error("Error starting audio streaming: \(error.localizedDescription)")
delegate?.didFailToRecord(error: error)
}
}

private func registerForRecordingNotifications() {
NotificationCenter.default.addObserver(
self,
selector: #selector(sessionDidStartRunning),
name: .AVCaptureSessionDidStartRunning,
object: captureSession
)

NotificationCenter.default.addObserver(
self,
selector: #selector(sessionDidStopRunning),
Expand All @@ -97,10 +106,6 @@ final class AudioRecorder: NSObject, AudioRecorderProtocol {
)
}

@objc private func sessionDidStartRunning(notification: Notification) {
delegate?.didStartRecording()
}

@objc private func sessionDidStopRunning(notification: Notification) {
delegate?.didStopRecording()
}
@@ -109,6 +114,7 @@ final class AudioRecorder: NSObject, AudioRecorderProtocol {
if let error = notification.userInfo?[AVCaptureSessionErrorKey] as? AVError {
let message = "AVCaptureSession runtime error: \(error)"
Current.Log.error(message)
delegate?.didFailToRecord(error: error)
}
delegate?.didStopRecording()
}
5 changes: 3 additions & 2 deletions Sources/App/Assist/Tests/AssistViewModel.test.swift
@@ -77,8 +77,9 @@ final class AssistViewModelTests: XCTestCase {
}

func testDidStartRecording() {
sut.didStartRecording()
XCTAssertTrue(sut.isRecording)
sut.preferredPipelineId = "2"
sut.didStartRecording(with: 16000)
XCTAssertEqual(mockAssistService.assistSource, .audio(pipelineId: "2", audioSampleRate: 16000.0))
}

func testDidStopRecording() {
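
The new failure callback could be covered in the same style; a hypothetical extra test (illustrative only, reusing the sut fixture and the AudioRecorderError type introduced in this commit) might look like:

func testDidFailToRecordShowsError() {
    // Hypothetical test, not part of this commit: drive the new delegate failure
    // callback and check that the view model publishes the error on the main queue.
    sut.didFailToRecord(error: AudioRecorderError.invalidSampleRate)
    let published = expectation(description: "error published")
    DispatchQueue.main.async {
        XCTAssertTrue(self.sut.showError)
        XCTAssertFalse(self.sut.errorMessage.isEmpty)
        published.fulfill()
    }
    wait(for: [published], timeout: 1)
}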
2 changes: 1 addition & 1 deletion Sources/App/Assist/Tests/Mocks/MockAudioRecorder.swift
@@ -2,7 +2,7 @@

final class MockAudioRecorder: AudioRecorderProtocol {
weak var delegate: AudioRecorderDelegate?
var audioSampleRate: Double = 16000
var audioSampleRate: Double?

var startRecordingCalled = false
var stopRecordingCalled = false
3 changes: 2 additions & 1 deletion Sources/App/Resources/en.lproj/Localizable.strings
@@ -853,4 +853,5 @@ Home Assistant is free and open source home automation software with a focus on
"widgets.open_page.description" = "Open a frontend page in Home Assistant.";
"widgets.open_page.not_configured" = "No Pages Available";
"widgets.open_page.title" = "Open Page";
"yes_label" = "Yes";
"assist.error.pipelines_response" = "Failed to obtain Assist pipelines, please check your pipelines configuration.";
"yes_label" = "Yes";
4 changes: 4 additions & 0 deletions Sources/Shared/Resources/Swiftgen/Strings.swift
@@ -283,6 +283,10 @@ public enum L10n {
}

public enum Assist {
public enum Error {
/// Failed to obtain Assist pipelines, please check your pipelines configuration.
public static var pipelinesResponse: String { return L10n.tr("Localizable", "assist.error.pipelines_response") }
}
public enum PipelinesPicker {
/// Assist Pipelines
public static var title: String { return L10n.tr("Localizable", "assist.pipelines_picker.title") }
