Skip to content

Commit

Permalink
VirtualBox (xlang-ai#46)
Browse files Browse the repository at this point in the history
* Initailize aws support

* Add README for the VM server

* Refactor OSWorld for supporting more cloud services.

* Initialize vmware and aws implementation v1, waiting for verification

* Initlize files for azure, gcp and virtualbox support

* Debug on the VMware provider

* Fix on aws interface mapping

* Fix instance type

* Refactor

* Clean

* Add Azure provider

* hk region; debug

* Fix lock

* Remove print

* Remove key_name requirements when allocating aws vm

* Clean README

* Fix reset

* Fix bugs

* Add VirtualBox and Azure providers

* Add VirtualBox OVF link

* Raise exception on macOS host

* Init RAEDME for VBox

* Update VirtualBox VM download link

* Update requirements and setup.py; Improve robustness on Windows

* Fix network adapter

* Go through on Windows machine

* Add default adapter option

* Fix minor error

---------

Co-authored-by: Timothyxxx <[email protected]>
Co-authored-by: XinyuanWangCS <[email protected]>
Co-authored-by: Tianbao Xie <[email protected]>
  • Loading branch information
4 people authored Jun 17, 2024
1 parent 30050d4 commit 1910646
Show file tree
Hide file tree
Showing 12 changed files with 864 additions and 15 deletions.
9 changes: 6 additions & 3 deletions desktop_env/desktop_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class DesktopEnv(gym.Env):

def __init__(
self,
provider_name: str = "vmware",
provider_name: str = "virtualbox",
region: str = None,
path_to_vm: str = None,
snapshot_name: str = "init_state",
Expand Down Expand Up @@ -55,8 +55,11 @@ def __init__(
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)

# Initialize environment variables
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) if path_to_vm else \
self.manager.get_vm_path(region)
if path_to_vm:
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \
if provider_name in {"vmware", "virtualbox"} else path_to_vm
else:
self.path_to_vm = self.manager.get_vm_path(region)

self.snapshot_name = snapshot_name
self.cache_dir_base: str = cache_dir
Expand Down
9 changes: 8 additions & 1 deletion desktop_env/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from desktop_env.providers.vmware.provider import VMwareProvider
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider

from desktop_env.providers.azure.manager import AzureVMManager
from desktop_env.providers.azure.provider import AzureProvider
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider

def create_vm_manager_and_provider(provider_name: str, region: str):
"""
Expand All @@ -12,7 +15,11 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
provider_name = provider_name.lower().strip()
if provider_name == "vmware":
return VMwareVMManager(), VMwareProvider(region)
elif provider_name == "virtualbox":
return VirtualBoxVMManager(), VirtualBoxProvider(region)
elif provider_name in ["aws", "amazon web services"]:
return AWSVMManager(), AWSProvider(region)
elif provider_name == "azure":
return AzureVMManager(), AzureProvider(region)
else:
raise NotImplementedError(f"{provider_name} not implemented!")
8 changes: 5 additions & 3 deletions desktop_env/providers/aws/AWS_GUIDELINE.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# README for AWS VM Management
# ☁ Configuration of AWS

---

Welcome to the AWS VM Management documentation. Before you proceed with using the code to manage AWS services, please ensure the following variables are set correctly according to your AWS environment.

Expand All @@ -9,8 +11,8 @@ You need to assign values to several variables crucial for the operation of thes
- Example: `'.aws_vms'`
- **`DEFAULT_REGION`**: Default AWS region where your instances will be launched.
- Example: `"us-east-1"`
- **`IMAGE_ID_MAP`**: Dictionary mapping regions to specific AMI IDs that should be used for instance creation.
- Example:
- **`IMAGE_ID_MAP`**: Dictionary mapping regions to specific AMI IDs that should be used for instance creation. Here we already set the AMI id to the official OSWorld image of Ubuntu supported by us.
- Formatted as follows:
```python
IMAGE_ID_MAP = {
"us-east-1": "ami-09bab251951b4272c",
Expand Down
85 changes: 85 additions & 0 deletions desktop_env/providers/azure/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
import threading
import boto3
import psutil

import logging

from desktop_env.providers.base import VMManager

logger = logging.getLogger("desktopenv.providers.azure.AzureVMManager")
logger.setLevel(logging.INFO)

REGISTRY_PATH = '.azure_vms'


def _allocate_vm(region):
raise NotImplementedError


class AzureVMManager(VMManager):
def __init__(self, registry_path=REGISTRY_PATH):
self.registry_path = registry_path
self.lock = threading.Lock()
self.initialize_registry()

def initialize_registry(self):
with self.lock: # Locking during initialization
if not os.path.exists(self.registry_path):
with open(self.registry_path, 'w') as file:
file.write('')

def add_vm(self, vm_path, region):
with self.lock:
with open(self.registry_path, 'r') as file:
lines = file.readlines()
vm_path_at_vm_region = "{}@{}".format(vm_path, region)
new_lines = lines + [f'{vm_path_at_vm_region}|free\n']
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)

def occupy_vm(self, vm_path, pid, region):
with self.lock:
new_lines = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
registered_vm_path, _ = line.strip().split('|')
if registered_vm_path == "{}@{}".format(vm_path, region):
new_lines.append(f'{registered_vm_path}|{pid}\n')
else:
new_lines.append(line)
with open(self.registry_path, 'w') as file:
file.writelines(new_lines)

def check_and_clean(self):
raise NotImplementedError

def list_free_vms(self, region):
with self.lock: # Lock when reading the registry
free_vms = []
with open(self.registry_path, 'r') as file:
lines = file.readlines()
for line in lines:
vm_path_at_vm_region, pid_str = line.strip().split('|')
vm_path, vm_region = vm_path_at_vm_region.split("@")
if pid_str == "free" and vm_region == region:
free_vms.append((vm_path, pid_str))
return free_vms

def get_vm_path(self, region):
self.check_and_clean()
free_vms_paths = self.list_free_vms(region)
if len(free_vms_paths) == 0:
# No free virtual machine available, generate a new one
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
new_vm_path = _allocate_vm(region)
self.add_vm(new_vm_path, region)
self.occupy_vm(new_vm_path, os.getpid(), region)
return new_vm_path
else:
# Choose the first free virtual machine
chosen_vm_path = free_vms_paths[0][0]
self.occupy_vm(chosen_vm_path, os.getpid(), region)
return chosen_vm_path

205 changes: 205 additions & 0 deletions desktop_env/providers/azure/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import os
import time
from azure.identity import DefaultAzureCredential
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.network import NetworkManagementClient
from azure.core.exceptions import ResourceNotFoundError

import logging

from desktop_env.providers.base import Provider

logger = logging.getLogger("desktopenv.providers.azure.AzureProvider")
logger.setLevel(logging.INFO)

WAIT_DELAY = 15
MAX_ATTEMPTS = 10

# To use the Azure provider, download azure-cli by https://learn.microsoft.com/en-us/cli/azure/install-azure-cli,
# use "az login" to log into you Azure account,
# and set environment variable "AZURE_SUBSCRIPTION_ID" to your subscription ID.
# Provide your resource group name and VM name in the format "RESOURCE_GROUP_NAME/VM_NAME" and pass as an argument for "-p".

class AzureProvider(Provider):
def __init__(self, region: str = None):
super().__init__(region)
credential = DefaultAzureCredential()
try:
self.subscription_id = os.environ["AZURE_SUBSCRIPTION_ID"]
except:
logger.error("Azure subscription ID not found. Please set environment variable \"AZURE_SUBSCRIPTION_ID\".")
raise
self.compute_client = ComputeManagementClient(credential, self.subscription_id)
self.network_client = NetworkManagementClient(credential, self.subscription_id)

def start_emulator(self, path_to_vm: str, headless: bool):
logger.info("Starting Azure VM...")
resource_group_name, vm_name = path_to_vm.split('/')

vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView')
power_state = vm.instance_view.statuses[-1].code
if power_state == "PowerState/running":
logger.info("VM is already running.")
return

try:
# Start the instance
for _ in range(MAX_ATTEMPTS):
async_vm_start = self.compute_client.virtual_machines.begin_start(resource_group_name, vm_name)
logger.info(f"VM {path_to_vm} is starting...")
# Wait for the instance to start
async_vm_start.wait(timeout=WAIT_DELAY)
vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView')
power_state = vm.instance_view.statuses[-1].code
if power_state == "PowerState/running":
logger.info(f"VM {path_to_vm} is already running.")
break
except Exception as e:
logger.error(f"Failed to start the Azure VM {path_to_vm}: {str(e)}")
raise

def get_ip_address(self, path_to_vm: str) -> str:
logger.info("Getting Azure VM IP address...")
resource_group_name, vm_name = path_to_vm.split('/')

vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name)

for interface in vm.network_profile.network_interfaces:
name=" ".join(interface.id.split('/')[-1:])
sub="".join(interface.id.split('/')[4])

try:
thing=self.network_client.network_interfaces.get(sub, name).ip_configurations

network_card_id = thing[0].public_ip_address.id.split('/')[-1]
public_ip_address = self.network_client.public_ip_addresses.get(resource_group_name, network_card_id)
logger.info(f"VM IP address is {public_ip_address.ip_address}")
return public_ip_address.ip_address

except Exception as e:
logger.error(f"Cannot get public IP for VM {path_to_vm}")
raise

def save_state(self, path_to_vm: str, snapshot_name: str):
print("Saving Azure VM state...")
resource_group_name, vm_name = path_to_vm.split('/')

vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name)

try:
# Backup each disk attached to the VM
for disk in vm.storage_profile.data_disks + [vm.storage_profile.os_disk]:
# Create a snapshot of the disk
snapshot = {
'location': vm.location,
'creation_data': {
'create_option': 'Copy',
'source_uri': disk.managed_disk.id
}
}
async_snapshot_creation = self.compute_client.snapshots.begin_create_or_update(resource_group_name, snapshot_name, snapshot)
async_snapshot_creation.wait(timeout=WAIT_DELAY)

logger.info(f"Successfully created snapshot {snapshot_name} for VM {path_to_vm}.")
except Exception as e:
logger.error(f"Failed to create snapshot {snapshot_name} of the Azure VM {path_to_vm}: {str(e)}")
raise

def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
logger.info(f"Reverting VM to snapshot: {snapshot_name}...")
resource_group_name, vm_name = path_to_vm.split('/')

vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name)

# Stop the VM for disk creation
logger.info(f"Stopping VM: {vm_name}")
async_vm_stop = self.compute_client.virtual_machines.begin_deallocate(resource_group_name, vm_name)
async_vm_stop.wait(timeout=WAIT_DELAY) # Wait for the VM to stop

try:
# Get the snapshot
snapshot = self.compute_client.snapshots.get(resource_group_name, snapshot_name)

# Get the original disk information
original_disk_id = vm.storage_profile.os_disk.managed_disk.id
disk_name = original_disk_id.split('/')[-1]
if disk_name[-1] in ['0', '1']:
new_disk_name = disk_name[:-1] + str(int(disk_name[-1])^1)
else:
new_disk_name = disk_name + "0"

# Delete the disk if it exists
self.compute_client.disks.begin_delete(resource_group_name, new_disk_name).wait(timeout=WAIT_DELAY)

# Make sure the disk is deleted before proceeding to the next step
disk_deleted = False
polling_interval = 10
attempts = 0
while not disk_deleted and attempts < MAX_ATTEMPTS:
try:
self.compute_client.disks.get(resource_group_name, new_disk_name)
# If the above line does not raise an exception, the disk still exists
time.sleep(polling_interval)
attempts += 1
except ResourceNotFoundError:
disk_deleted = True

if not disk_deleted:
logger.error(f"Disk {new_disk_name} deletion timed out.")
raise

# Create a new managed disk from the snapshot
snapshot = self.compute_client.snapshots.get(resource_group_name, snapshot_name)
disk_creation = {
'location': snapshot.location,
'creation_data': {
'create_option': 'Copy',
'source_resource_id': snapshot.id
},
'zones': vm.zones if vm.zones else None # Preserve the original disk's zone
}
async_disk_creation = self.compute_client.disks.begin_create_or_update(resource_group_name, new_disk_name, disk_creation)
restored_disk = async_disk_creation.result() # Wait for the disk creation to complete

vm.storage_profile.os_disk = {
'create_option': vm.storage_profile.os_disk.create_option,
'managed_disk': {
'id': restored_disk.id
}
}

async_vm_creation = self.compute_client.virtual_machines.begin_create_or_update(resource_group_name, vm_name, vm)
async_vm_creation.wait(timeout=WAIT_DELAY)

# Delete the original disk
self.compute_client.disks.begin_delete(resource_group_name, disk_name).wait()

logger.info(f"Successfully reverted to snapshot {snapshot_name}.")
except Exception as e:
logger.error(f"Failed to revert the Azure VM {path_to_vm} to snapshot {snapshot_name}: {str(e)}")
raise

def stop_emulator(self, path_to_vm, region=None):
logger.info(f"Stopping Azure VM {path_to_vm}...")
resource_group_name, vm_name = path_to_vm.split('/')

vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView')
power_state = vm.instance_view.statuses[-1].code
if power_state == "PowerState/deallocated":
print("VM is already stopped.")
return

try:
for _ in range(MAX_ATTEMPTS):
async_vm_deallocate = self.compute_client.virtual_machines.begin_deallocate(resource_group_name, vm_name)
logger.info(f"Stopping VM {path_to_vm}...")
# Wait for the instance to start
async_vm_deallocate.wait(timeout=WAIT_DELAY)
vm = self.compute_client.virtual_machines.get(resource_group_name, vm_name, expand='instanceView')
power_state = vm.instance_view.statuses[-1].code
if power_state == "PowerState/deallocated":
logger.info(f"VM {path_to_vm} is already stopped.")
break
except Exception as e:
logger.error(f"Failed to stop the Azure VM {path_to_vm}: {str(e)}")
raise
11 changes: 11 additions & 0 deletions desktop_env/providers/virtualbox/INSTALL_VITUALBOX.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
## 💾 Installation of VirtualBox


1. Download the VirtualBox from the [official website](https://www.virtualbox.org/wiki/Downloads). Unfortunately, for Apple chips (M1 chips, M2 chips, etc.), VirtualBox is not supported. You can only use VMware Fusion instead.
2. Install VirtualBox. Just follow the instructions provided by the installer.
For Windows, you also need to append the installation path to the environment variable `PATH` for enabling the `VBoxManage` command. The default installation path is `C:\Program Files\Oracle\VirtualBox`.
3. Verify the successful installation by running the following:
```bash
VBoxManage --version
```
If the installation along with the environment variable set is successful, you will see the version of VirtualBox installed on your system.
Loading

0 comments on commit 1910646

Please sign in to comment.