Skip to content

Commit

Permalink
Support for several edge cases during model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
SkalskiP committed Dec 20, 2022
1 parent 6135b36 commit 6aa8cc9
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 46 deletions.
24 changes: 20 additions & 4 deletions src/ai/RoboflowAPIObjectDetector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,31 @@ export interface DetectedObject {

export class RoboflowAPIObjectDetector {

public static loadModel(roboflowAPIDetails: RoboflowAPIDetails) {
public static loadModel(
roboflowAPIDetails: RoboflowAPIDetails,
onSuccess?: () => any,
onFailure?: () => any
) {
store.dispatch(updateRoboflowAPIDetails(roboflowAPIDetails));
store.dispatch(updateActiveLabelType(LabelType.RECT));
const activeLabelType: LabelType = LabelsSelector.getActiveLabelType();
if (activeLabelType === LabelType.RECT) {
AIRoboflowAPIObjectDetectionActions.detectRectsForActiveImage();
const activeImageData: ImageData = LabelsSelector.getActiveImageData();

const wrappedOnFailure = () => {
store.dispatch(updateRoboflowAPIDetails({model: '', key: ''}));
onFailure()
}

RoboflowAPIObjectDetector.predict(activeImageData, onSuccess, wrappedOnFailure)
}
}

public static predict(imageData: ImageData, callback?: (predictions: DetectedObject[]) => any) {
public static predict(
imageData: ImageData,
onSuccess?: (predictions: DetectedObject[]) => any,
onFailure?: () => any
) {
const roboflowAPIDetails: RoboflowAPIDetails = AISelector.getRoboflowAPIDetails();
FileUtil.loadImageBase64(imageData.fileData).then((data) => {
axios({
Expand All @@ -65,8 +80,9 @@ export class RoboflowAPIObjectDetector {
class: prediction.class
}
});
callback(predictions)
onSuccess(predictions)
})
.catch(onFailure)
})
}
}
1 change: 1 addition & 0 deletions src/data/enums/Notification.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ export enum Notification {
ANNOTATION_FILE_PARSE_ERROR = 6,
ANNOTATION_IMPORT_ASSERTION_ERROR = 7,
UNSUPPORTED_INFERENCE_SERVER_MESSAGE = 8,
ROBOFLOW_INFERENCE_SERVER_ERROR = 9,
}
7 changes: 6 additions & 1 deletion src/data/info/NotificationsData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ export const NotificationsDataMap: ExportFormatDataMap = {
},
[Notification.UNSUPPORTED_INFERENCE_SERVER_MESSAGE]: {
header: 'Selected inference server is not yet supported',
description: 'Integration with selected inference server is still under construction 🚧. Stay tuned for more ' +
description: 'Integration with selected inference server is still under construction. Stay tuned for more ' +
'updates on our GitHub.'
},
[Notification.ROBOFLOW_INFERENCE_SERVER_ERROR]: {
header: 'Roboflow connection failed',
description: 'Looks like we ware unable to connect to your Roboflow model. Please, make sure that the model ' +
'specification and Roboflow API key, are correct.'
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
@import '../../../settings/Settings';

.right-container {
display: flex;
flex-direction: column;
flex-wrap: nowrap;
align-items: center;
align-content: flex-start;

align-self: stretch;
flex: 1;

.loader {
display: flex;
align-items: center;
align-content: center;
justify-content: center;

align-self: stretch;
flex: 1;
}

.message {
align-self: stretch;
color: white;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import './ConnectInferenceServerPopup.scss'
import { StyledTextField } from '../../Common/StyledTextField/StyledTextField';
import { RoboflowAPIDetails } from '../../../store/ai/types';
import { RoboflowAPIObjectDetector } from '../../../ai/RoboflowAPIObjectDetector';
import { ClipLoader } from 'react-spinners';
import { CSSHelper } from '../../../logic/helpers/CSSHelper';

interface IProps {
roboflowAPIDetails: RoboflowAPIDetails;
Expand All @@ -29,6 +31,7 @@ const ConnectInferenceServerPopup: React.FC<IProps> = (
) => {
// general
const [currentServerType, setCurrentServerType] = useState(InferenceServerType.ROBOFLOW);
const [modelIsLoadingStatus, setModelIsLoadingStatus] = useState(false);

// roboflow
const [roboflowModel, setRoboflowModel] = useState(roboflowAPIDetails.model);
Expand All @@ -47,6 +50,8 @@ const ConnectInferenceServerPopup: React.FC<IProps> = (
}

const disableAcceptButton = () => {
if (modelIsLoadingStatus) return true;

switch(currentServerType) {
case InferenceServerType.ROBOFLOW:
return roboflowModel === '' || roboflowKey === ''
Expand All @@ -56,15 +61,23 @@ const ConnectInferenceServerPopup: React.FC<IProps> = (
}

const onAccept = () => {
if (disableAcceptButton()) {
return;
if (disableAcceptButton()) return;

const onSuccess = () => {
PopupActions.close();
}

const onFailure = () => {
submitNewNotificationAction(NotificationUtil.createErrorNotification(
NotificationsDataMap[Notification.ROBOFLOW_INFERENCE_SERVER_ERROR]));
setModelIsLoadingStatus(false);
}

setModelIsLoadingStatus(true);
RoboflowAPIObjectDetector.loadModel({
model: roboflowModel,
key: roboflowKey
})
PopupActions.close();
}, onSuccess, onFailure)
};

const onReject = () => {
Expand All @@ -79,45 +92,58 @@ const ConnectInferenceServerPopup: React.FC<IProps> = (
setRoboflowKey(event.target.value)
}

const renderLoader = () => {
return(<div className='loader'>
<ClipLoader
size={40}
color={CSSHelper.getLeadingColor()}
loading={true}
/>
</div>)
}

const renderRoboflow = () => {
return <>
<div className='message'>
Provide details of the Roboflow model you want to run over tha API, as well as your API key.
</div>
<div className='details'>
<StyledTextField
variant='standard'
id={'roboflow-model'}
autoComplete={'off'}
autoFocus={true}
type={'text'}
margin={'dense'}
label={'roboflow model'}
value={roboflowModel}
onChange={roboflowModelOnChangeCallback}
style={{ width: 280 }}
InputLabelProps={{ shrink: true }}
/>
<StyledTextField
variant='standard'
id={'roboflow-api- key'}
autoComplete={'off'}
autoFocus={true}
type={'password'}
margin={'dense'}
label={'roboflow api key'}
value={roboflowKey}
onChange={roboflowKeyOnChangeCallback}
style={{ width: 280 }}
InputLabelProps={{ shrink: true }}
/>
</div>
</>;
}

const renderContent = (): JSX.Element => {
if (modelIsLoadingStatus) {
return renderLoader()
}
if (currentServerType === InferenceServerType.ROBOFLOW) {
return <>
<div className='message'>
Provide details of the Roboflow model you want to run over tha API, as well as your API key.
</div>
<div className='details'>
<StyledTextField
variant='standard'
id={'key'}
autoComplete={'off'}
autoFocus={true}
type={'text'}
margin={'dense'}
label={'roboflow model'}
value={roboflowModel}
onChange={roboflowModelOnChangeCallback}
style={{ width: 280 }}
InputLabelProps={{
shrink: true,
}}
/>
<StyledTextField
variant='standard'
id={'key'}
autoComplete={'off'}
autoFocus={true}
type={'password'}
margin={'dense'}
label={'roboflow api key'}
value={roboflowKey}
onChange={roboflowKeyOnChangeCallback}
style={{ width: 280 }}
InputLabelProps={{
shrink: true,
}}
/>
</div>
</>;
return renderRoboflow();
}
return <div className='load-model-popup-content'/>
};
Expand All @@ -136,6 +162,7 @@ const ConnectInferenceServerPopup: React.FC<IProps> = (
/>
})
}

return (
<GenericSideMenuPopup
title={InferenceServerDataMap[currentServerType].name}
Expand Down

0 comments on commit 6aa8cc9

Please sign in to comment.