forked from xlang-ai/OSWorld
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initailize aws support * Add README for the VM server * Refactor OSWorld for supporting more cloud services. * Initialize vmware and aws implementation v1, waiting for verification * Initlize files for azure, gcp and virtualbox support * Debug on the VMware provider * Fix on aws interface mapping * Fix instance type * Refactor * Clean * Add Azure provider * hk region; debug * Fix lock * Remove print * Remove key_name requirements when allocating aws vm * Clean README * Fix reset * Fix bugs * Add VirtualBox and Azure providers * Add VirtualBox OVF link * Raise exception on macOS host * Init RAEDME for VBox * Update VirtualBox VM download link * Update requirements and setup.py; Improve robustness on Windows * Fix network adapter * Go through on Windows machine * Add default adapter option * Fix minor error --------- Co-authored-by: Timothyxxx <[email protected]> Co-authored-by: XinyuanWangCS <[email protected]> Co-authored-by: Tianbao Xie <[email protected]>
- Loading branch information
1 parent
30050d4
commit 1910646
Showing
12 changed files
with
864 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import os | ||
import threading | ||
import boto3 | ||
import psutil | ||
|
||
import logging | ||
|
||
from desktop_env.providers.base import VMManager | ||
|
||
logger = logging.getLogger("desktopenv.providers.azure.AzureVMManager") | ||
logger.setLevel(logging.INFO) | ||
|
||
REGISTRY_PATH = '.azure_vms' | ||
|
||
|
||
def _allocate_vm(region): | ||
raise NotImplementedError | ||
|
||
|
||
class AzureVMManager(VMManager): | ||
def __init__(self, registry_path=REGISTRY_PATH): | ||
self.registry_path = registry_path | ||
self.lock = threading.Lock() | ||
self.initialize_registry() | ||
|
||
def initialize_registry(self): | ||
with self.lock: # Locking during initialization | ||
if not os.path.exists(self.registry_path): | ||
with open(self.registry_path, 'w') as file: | ||
file.write('') | ||
|
||
def add_vm(self, vm_path, region): | ||
with self.lock: | ||
with open(self.registry_path, 'r') as file: | ||
lines = file.readlines() | ||
vm_path_at_vm_region = "{}@{}".format(vm_path, region) | ||
new_lines = lines + [f'{vm_path_at_vm_region}|free\n'] | ||
with open(self.registry_path, 'w') as file: | ||
file.writelines(new_lines) | ||
|
||
def occupy_vm(self, vm_path, pid, region): | ||
with self.lock: | ||
new_lines = [] | ||
with open(self.registry_path, 'r') as file: | ||
lines = file.readlines() | ||
for line in lines: | ||
registered_vm_path, _ = line.strip().split('|') | ||
if registered_vm_path == "{}@{}".format(vm_path, region): | ||
new_lines.append(f'{registered_vm_path}|{pid}\n') | ||
else: | ||
new_lines.append(line) | ||
with open(self.registry_path, 'w') as file: | ||
file.writelines(new_lines) | ||
|
||
def check_and_clean(self): | ||
raise NotImplementedError | ||
|
||
def list_free_vms(self, region): | ||
with self.lock: # Lock when reading the registry | ||
free_vms = [] | ||
with open(self.registry_path, 'r') as file: | ||
lines = file.readlines() | ||
for line in lines: | ||
vm_path_at_vm_region, pid_str = line.strip().split('|') | ||
vm_path, vm_region = vm_path_at_vm_region.split("@") | ||
if pid_str == "free" and vm_region == region: | ||
free_vms.append((vm_path, pid_str)) | ||
return free_vms | ||
|
||
def get_vm_path(self, region): | ||
self.check_and_clean() | ||
free_vms_paths = self.list_free_vms(region) | ||
if len(free_vms_paths) == 0: | ||
# No free virtual machine available, generate a new one | ||
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕") | ||
new_vm_path = _allocate_vm(region) | ||
self.add_vm(new_vm_path, region) | ||
self.occupy_vm(new_vm_path, os.getpid(), region) | ||
return new_vm_path | ||
else: | ||
# Choose the first free virtual machine | ||
chosen_vm_path = free_vms_paths[0][0] | ||
self.occupy_vm(chosen_vm_path, os.getpid(), region) | ||
return chosen_vm_path | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
import os | ||
import time | ||
from azure.identity import DefaultAzureCredential | ||
from azure.mgmt.compute import ComputeManagementClient | ||
from azure.mgmt.network import NetworkManagementClient | ||
from azure.core.exceptions import ResourceNotFoundError | ||
|
||
import logging | ||
|
||
from desktop_env.providers.base import Provider | ||
|
||
logger = logging.getLogger("desktopenv.providers.azure.AzureProvider") | ||
logger.setLevel(logging.INFO) | ||
|
||
WAIT_DELAY = 15 | ||
MAX_ATTEMPTS = 10 | ||
|
||
# To use the Azure provider, download azure-cli by https://learn.microsoft.com/en-us/cli/azure/install-azure-cli, | ||
# use "az login" to log into you Azure account, | ||
# and set environment variable "AZURE_SUBSCRIPTION_ID" to your subscription ID. | ||
# Provide your resource group name and VM name in the format "RESOURCE_GROUP_NAME/VM_NAME" and pass as an argument for "-p". | ||
|
||
class AzureProvider(Provider): | ||
def __init__(self, region: str = None): | ||
super().__init__(region) | ||
credential = DefaultAzureCredential() | ||
try: | ||
self.subscription_id = os.environ["AZURE_SUBSCRIPTION_ID"] | ||
except: | ||
logger.error("Azure subscription ID not found. Please set environment variable \"AZURE_SUBSCRIPTION_ID\".") | ||
raise | ||
self.compute_client = ComputeManagementClient(credential, self.subscription_id) | ||
self.network_client = NetworkManagementClient(credential, self.subscription_id) | ||
|
||
def start_emulator(self, path_to_vm: str, headless: bool): | ||
logger.info("Starting Azure VM...") | ||
resource_group_name, vm_name = path_to_vm.split('/') | ||
|
||
vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') | ||
power_state = vm.instance_view.statuses[-1].code | ||
if power_state == "PowerState/running": | ||
logger.info("VM is already running.") | ||
return | ||
|
||
try: | ||
# Start the instance | ||
for _ in range(MAX_ATTEMPTS): | ||
async_vm_start = self.compute_client.virtual_machines.begin_start(resource_group_name, vm_name) | ||
logger.info(f"VM {path_to_vm} is starting...") | ||
# Wait for the instance to start | ||
async_vm_start.wait(timeout=WAIT_DELAY) | ||
vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') | ||
power_state = vm.instance_view.statuses[-1].code | ||
if power_state == "PowerState/running": | ||
logger.info(f"VM {path_to_vm} is already running.") | ||
break | ||
except Exception as e: | ||
logger.error(f"Failed to start the Azure VM {path_to_vm}: {str(e)}") | ||
raise | ||
|
||
def get_ip_address(self, path_to_vm: str) -> str: | ||
logger.info("Getting Azure VM IP address...") | ||
resource_group_name, vm_name = path_to_vm.split('/') | ||
|
||
vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name) | ||
|
||
for interface in vm.network_profile.network_interfaces: | ||
name=" ".join(interface.id.split('/')[-1:]) | ||
sub="".join(interface.id.split('/')[4]) | ||
|
||
try: | ||
thing=self.network_client.network_interfaces.get(sub, name).ip_configurations | ||
|
||
network_card_id = thing[0].public_ip_address.id.split('/')[-1] | ||
public_ip_address = self.network_client.public_ip_addresses.get(resource_group_name, network_card_id) | ||
logger.info(f"VM IP address is {public_ip_address.ip_address}") | ||
return public_ip_address.ip_address | ||
|
||
except Exception as e: | ||
logger.error(f"Cannot get public IP for VM {path_to_vm}") | ||
raise | ||
|
||
def save_state(self, path_to_vm: str, snapshot_name: str): | ||
print("Saving Azure VM state...") | ||
resource_group_name, vm_name = path_to_vm.split('/') | ||
|
||
vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name) | ||
|
||
try: | ||
# Backup each disk attached to the VM | ||
for disk in vm.storage_profile.data_disks + [vm.storage_profile.os_disk]: | ||
# Create a snapshot of the disk | ||
snapshot = { | ||
'location': vm.location, | ||
'creation_data': { | ||
'create_option': 'Copy', | ||
'source_uri': disk.managed_disk.id | ||
} | ||
} | ||
async_snapshot_creation = self.compute_client.snapshots.begin_create_or_update(resource_group_name, snapshot_name, snapshot) | ||
async_snapshot_creation.wait(timeout=WAIT_DELAY) | ||
|
||
logger.info(f"Successfully created snapshot {snapshot_name} for VM {path_to_vm}.") | ||
except Exception as e: | ||
logger.error(f"Failed to create snapshot {snapshot_name} of the Azure VM {path_to_vm}: {str(e)}") | ||
raise | ||
|
||
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): | ||
logger.info(f"Reverting VM to snapshot: {snapshot_name}...") | ||
resource_group_name, vm_name = path_to_vm.split('/') | ||
|
||
vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name) | ||
|
||
# Stop the VM for disk creation | ||
logger.info(f"Stopping VM: {vm_name}") | ||
async_vm_stop = self.compute_client.virtual_machines.begin_deallocate(resource_group_name, vm_name) | ||
async_vm_stop.wait(timeout=WAIT_DELAY) # Wait for the VM to stop | ||
|
||
try: | ||
# Get the snapshot | ||
snapshot = self.compute_client.snapshots.get(resource_group_name, snapshot_name) | ||
|
||
# Get the original disk information | ||
original_disk_id = vm.storage_profile.os_disk.managed_disk.id | ||
disk_name = original_disk_id.split('/')[-1] | ||
if disk_name[-1] in ['0', '1']: | ||
new_disk_name = disk_name[:-1] + str(int(disk_name[-1])^1) | ||
else: | ||
new_disk_name = disk_name + "0" | ||
|
||
# Delete the disk if it exists | ||
self.compute_client.disks.begin_delete(resource_group_name, new_disk_name).wait(timeout=WAIT_DELAY) | ||
|
||
# Make sure the disk is deleted before proceeding to the next step | ||
disk_deleted = False | ||
polling_interval = 10 | ||
attempts = 0 | ||
while not disk_deleted and attempts < MAX_ATTEMPTS: | ||
try: | ||
self.compute_client.disks.get(resource_group_name, new_disk_name) | ||
# If the above line does not raise an exception, the disk still exists | ||
time.sleep(polling_interval) | ||
attempts += 1 | ||
except ResourceNotFoundError: | ||
disk_deleted = True | ||
|
||
if not disk_deleted: | ||
logger.error(f"Disk {new_disk_name} deletion timed out.") | ||
raise | ||
|
||
# Create a new managed disk from the snapshot | ||
snapshot = self.compute_client.snapshots.get(resource_group_name, snapshot_name) | ||
disk_creation = { | ||
'location': snapshot.location, | ||
'creation_data': { | ||
'create_option': 'Copy', | ||
'source_resource_id': snapshot.id | ||
}, | ||
'zones': vm.zones if vm.zones else None # Preserve the original disk's zone | ||
} | ||
async_disk_creation = self.compute_client.disks.begin_create_or_update(resource_group_name, new_disk_name, disk_creation) | ||
restored_disk = async_disk_creation.result() # Wait for the disk creation to complete | ||
|
||
vm.storage_profile.os_disk = { | ||
'create_option': vm.storage_profile.os_disk.create_option, | ||
'managed_disk': { | ||
'id': restored_disk.id | ||
} | ||
} | ||
|
||
async_vm_creation = self.compute_client.virtual_machines.begin_create_or_update(resource_group_name, vm_name, vm) | ||
async_vm_creation.wait(timeout=WAIT_DELAY) | ||
|
||
# Delete the original disk | ||
self.compute_client.disks.begin_delete(resource_group_name, disk_name).wait() | ||
|
||
logger.info(f"Successfully reverted to snapshot {snapshot_name}.") | ||
except Exception as e: | ||
logger.error(f"Failed to revert the Azure VM {path_to_vm} to snapshot {snapshot_name}: {str(e)}") | ||
raise | ||
|
||
def stop_emulator(self, path_to_vm, region=None): | ||
logger.info(f"Stopping Azure VM {path_to_vm}...") | ||
resource_group_name, vm_name = path_to_vm.split('/') | ||
|
||
vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') | ||
power_state = vm.instance_view.statuses[-1].code | ||
if power_state == "PowerState/deallocated": | ||
print("VM is already stopped.") | ||
return | ||
|
||
try: | ||
for _ in range(MAX_ATTEMPTS): | ||
async_vm_deallocate = self.compute_client.virtual_machines.begin_deallocate(resource_group_name, vm_name) | ||
logger.info(f"Stopping VM {path_to_vm}...") | ||
# Wait for the instance to start | ||
async_vm_deallocate.wait(timeout=WAIT_DELAY) | ||
vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView') | ||
power_state = vm.instance_view.statuses[-1].code | ||
if power_state == "PowerState/deallocated": | ||
logger.info(f"VM {path_to_vm} is already stopped.") | ||
break | ||
except Exception as e: | ||
logger.error(f"Failed to stop the Azure VM {path_to_vm}: {str(e)}") | ||
raise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
## 💾 Installation of VirtualBox | ||
|
||
|
||
1. Download the VirtualBox from the [official website](https://www.virtualbox.org/wiki/Downloads). Unfortunately, for Apple chips (M1 chips, M2 chips, etc.), VirtualBox is not supported. You can only use VMware Fusion instead. | ||
2. Install VirtualBox. Just follow the instructions provided by the installer. | ||
For Windows, you also need to append the installation path to the environment variable `PATH` for enabling the `VBoxManage` command. The default installation path is `C:\Program Files\Oracle\VirtualBox`. | ||
3. Verify the successful installation by running the following: | ||
```bash | ||
VBoxManage --version | ||
``` | ||
If the installation along with the environment variable set is successful, you will see the version of VirtualBox installed on your system. |
Oops, something went wrong.