Skip to content

Commit

Permalink
merge asset callback
Browse files Browse the repository at this point in the history
  • Loading branch information
AlvaroHG committed Oct 30, 2024
2 parents 9a013c7 + 1a0a95c commit 37a921a
Show file tree
Hide file tree
Showing 14 changed files with 234 additions and 179 deletions.
42 changes: 26 additions & 16 deletions ai2thor/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,8 @@ def __init__(
server_timeout: Optional[float] = 100.0,
server_start_timeout: float = 300.0,
# objaverse_asset_ids=[], TODO add and implement when objaverse.load_thor_objects is available
action_hook_runner=None,
metadata_hook: Optional[MetadataHook] = None,
before_action_callback=None,
metadata_callback: Optional[MetadataHook] = None,
**unity_initialization_parameters,
):
self.receptacle_nearest_pivot_points = {}
Expand Down Expand Up @@ -443,18 +443,28 @@ def __init__(
)
)

self.action_hook_runner = action_hook_runner
self.action_hooks = (
if "action_hook_runner" in unity_initialization_parameters:
raise ValueError(
f"Deprecated argument 'action_hook_runner'. Use 'before_action_callback' instead."
)

if "metadata_hook" in unity_initialization_parameters:
raise ValueError(
f"Deprecated argument 'metadata_hook'. Use 'metadata_callback' instead."
)

self.before_action_callback = before_action_callback
self.action_callbacks = (
{
func
for func in dir(action_hook_runner)
if callable(getattr(action_hook_runner, func)) and not func.startswith("__")
for func in dir(before_action_callback)
if callable(getattr(before_action_callback, func)) and not func.startswith("__")
}
if self.action_hook_runner is not None
if self.before_action_callback is not None
else None
)

self.metadata_hook = metadata_hook
self.metadata_callback = metadata_callback

if self.gpu_device is not None:
# numbers.Integral works for numpy.int32/64 and Python int
Expand Down Expand Up @@ -971,11 +981,11 @@ def multi_step_physics(self, action, timeStep=0.05, max_steps=20):

return events

def run_action_hook(self, action):
if self.action_hooks is not None and action["action"] in self.action_hooks:
def run_before_action_callback(self, action):
if self.action_callbacks is not None and action["action"] in self.action_callbacks:
try:
# print(f"action hooks: {self.action_hooks}")
method = getattr(self.action_hook_runner, action["action"])
method = getattr(self.before_action_callback, action["action"])
event = method(action, self)
if isinstance(event, list):
self.last_event = event[-1]
Expand All @@ -984,18 +994,18 @@ def run_action_hook(self, action):
except AttributeError:
traceback.print_stack()
raise NotImplementedError(
"Action Hook Runner `{}` does not implement method `{}`,"
"Action Callback `{}` does not implement method `{}`,"
" actions hooks are meant to run before an action, make sure that `action_hook_runner`"
" passed to the controller implements a method for the desired action.".format(
self.action_hook_runner.__class__.__name__, action["action"]
self.before_action_callback.__class__.__name__, action["action"]
)
)
return True
return False

def run_metadata_hook(self, metadata: MetadataWrapper) -> bool:
if self.metadata_hook is not None:
out = self.metadata_hook(metadata=metadata, controller=self)
if self.metadata_callback is not None:
out = self.metadata_callback(metadata=metadata, controller=self)
assert (
out is None
), "`metadata_hook` must return `None` and change the metadata in place."
Expand Down Expand Up @@ -1043,7 +1053,7 @@ def step(self, action: Union[str, Dict[str, Any]] = None, **action_args):
# not deleting to allow for older builds to continue to work
# del action[old]

self.run_action_hook(action)
self.run_before_action_callback(action)

self.server.send(action)
try:
Expand Down
54 changes: 53 additions & 1 deletion ai2thor/hooks/procedural_asset_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def create_assets_if_not_exist(
# return evt


class ProceduralAssetHookRunner:
class ProceduralAssetActionCallback:
def __init__(
self,
asset_directory: str,
Expand All @@ -278,6 +278,7 @@ def __init__(
self.target_dir = target_dir
self.extension = extension
self.verbose = verbose
self.last_asset_id_set = set()

def Initialize(self, action, controller):
if self.asset_limit > 0:
Expand All @@ -288,6 +289,10 @@ def Initialize(self, action, controller):
def CreateHouse(self, action, controller):
house = action["house"]
asset_ids = get_all_asset_ids_recursively(house["objects"], [])
asset_ids_set = set(asset_ids)
if not asset_ids_set.issubset(self.last_asset_id_set):
controller.step(action="DeleteLRUFromProceduralCache", assetLimit=0)
self.last_asset_id_set = set(asset_ids)
return create_assets_if_not_exist(
controller=controller,
asset_ids=asset_ids,
Expand Down Expand Up @@ -330,6 +335,53 @@ def GetHouseFromTemplate(self, action, controller):
)


class DownloadObjaverseActionCallback(object):
def __init__(
self,
asset_dataset_version,
asset_download_path,
target_dir="processed_models",
asset_symlink=True,
load_file_in_unity=False,
stop_if_fail=False,
asset_limit=-1,
extension=None,
verbose=True,
):
self.asset_download_path = asset_download_path
self.asset_symlink = asset_symlink
self.stop_if_fail = stop_if_fail
self.asset_limit = asset_limit
self.load_file_in_unity = load_file_in_unity
self.target_dir = target_dir
self.extension = extension
self.verbose = verbose
self.last_asset_id_set = set()
dsc = DatasetSaveConfig(
VERSION=asset_dataset_version,
BASE_PATH=asset_download_path,
)
self.asset_path = load_assets_path(dsc)

def CreateHouse(self, action, controller):
house = action["house"]
asset_ids = get_all_asset_ids_recursively(house["objects"], [])
asset_ids_set = set(asset_ids)
if not asset_ids_set.issubset(self.last_asset_id_set):
controller.step(action="DeleteLRUFromProceduralCache", assetLimit=0)
self.last_asset_id_set = set(asset_ids)
return create_assets_if_not_exist(
controller=controller,
asset_ids=asset_ids,
asset_directory=self.asset_path,
copy_to_dir=os.path.join(controller._build.base_dir, self.target_dir),
asset_symlink=self.asset_symlink,
stop_if_fail=self.stop_if_fail,
load_file_in_unity=self.load_file_in_unity,
extension=self.extension,
verbose=self.verbose,
)

def download_with_progress_bar(save_path: str, url: str, verbose: bool = False):
os.makedirs(os.path.dirname(save_path), exist_ok=True)

Expand Down
2 changes: 1 addition & 1 deletion ai2thor/tests/data/arm-metadata-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@
"type": "number"
},
"isStanding": {
"type": "boolean"
"type": ["boolean", "null"]
},
"inHighFrictionArea": {
"type": "boolean"
Expand Down
7 changes: 4 additions & 3 deletions ai2thor/tests/test_unity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,16 +1418,17 @@ def test_teleport_stretch(controller):
agent = "stretch"

event = controller.reset(agentMode=agent)
assert event.metadata["agent"]["isStanding"] is False, agent + " cannot stand!"
assert event.metadata["agent"]["isStanding"] is None, (
agent + " cannot stand so this should be None/null!"
)

# Only degrees of freedom on the locobot
for action in ["Teleport", "TeleportFull"]:
event = controller.step(
action=action,
position=dict(x=-1.5, y=0.9, z=-1.5),
rotation=dict(x=0, y=90, z=0),
horizon=30,
standing=True,
standing=None,
)

print(f"Error Message: {event.metadata['errorMessage']}")
Expand Down
12 changes: 6 additions & 6 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4727,10 +4727,10 @@ def test_create_prefab(ctx, json_path):
def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""):
import json
import ai2thor.controller
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetHookRunner
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetActionCallback
from objathor.asset_conversion.util import view_asset_in_thor

hook_runner = ProceduralAssetHookRunner(
hook_runner = ProceduralAssetActionCallback(
asset_directory=asset_dir,
asset_symlink=True,
verbose=True,
Expand All @@ -4747,7 +4747,7 @@ def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""):
height=300,
server_class=ai2thor.fifo_server.FifoServer,
visibilityScheme="Distance",
action_hook_runner=hook_runner,
before_action_callback=hook_runner,
)

# TODO bug why skybox is not changing? from just procedural pipeline
Expand Down Expand Up @@ -4817,9 +4817,9 @@ def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""):
def procedural_asset_cache_test(ctx, asset_dir, house_path, asset_ids="", cache_limit=1):
import json
import ai2thor.controller
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetHookRunner
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetActionCallback

hook_runner = ProceduralAssetHookRunner(
hook_runner = ProceduralAssetActionCallback(
asset_directory=asset_dir, asset_symlink=True, verbose=True, asset_limit=1
)
controller = ai2thor.controller.Controller(
Expand All @@ -4834,7 +4834,7 @@ def procedural_asset_cache_test(ctx, asset_dir, house_path, asset_ids="", cache_
height=300,
server_class=ai2thor.wsgi_server.WsgiServer,
visibilityScheme="Distance",
action_hook_runner=hook_runner,
before_action_callback=hook_runner,
)
asset_ids = asset_ids.split(",")
with open(house_path, "r") as f:
Expand Down
2 changes: 1 addition & 1 deletion unity/Assets/Scripts/AgentManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ public void Initialize(ServerAction action) {
: action.dynamicServerAction.agentInitializationParams
);
Debug.Log(
$"Initialize of AgentController. lastActionSuccess: {primaryAgent.lastActionSuccess}, errorMessage: {primaryAgent.errorMessage}, actionReturn: {primaryAgent.actionReturn}, agentState: {primaryAgent.agentState}"
$"Initialize of AgentController.lastAction: {primaryAgent.lastAction} lastActionSuccess: {primaryAgent.lastActionSuccess}, errorMessage: {primaryAgent.errorMessage}, actionReturn: {primaryAgent.actionReturn}, agentState: {primaryAgent.agentState}"
);
Time.fixedDeltaTime = action.fixedDeltaTime.GetValueOrDefault(Time.fixedDeltaTime);
if (action.targetFrameRate > 0) {
Expand Down
74 changes: 74 additions & 0 deletions unity/Assets/Scripts/ArmAgentController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,80 @@ public virtual IEnumerator RotateAgent(
);
}

public override void Teleport(
Vector3? position = null,
Vector3? rotation = null,
float? horizon = null,
bool? standing = null,
bool forceAction = false
) {
//non-high level agents cannot set standing
if (standing != null) {
errorMessage = "Cannot set standing for arm/stretch agent";
actionFinishedEmit(success: false, actionReturn: null, errorMessage: errorMessage);
return;
}

TeleportFull(
position: position,
rotation: rotation,
horizon: horizon,
standing: standing,
forceAction: forceAction
);
}

public override void TeleportFull(
Vector3? position = null,
Vector3? rotation = null,
float? horizon = null,
bool? standing = null,
bool forceAction = false
) {
//non-high level agents cannot set standing
if (standing != null) {
errorMessage = "Cannot set standing for arm/stretch agent";
actionFinishedEmit(success: false, actionReturn: null, errorMessage: errorMessage);
return;
}

//cache old values in case there is a failure
Vector3 oldPosition = transform.position;
Quaternion oldRotation = transform.rotation;
Quaternion oldCameraRotation = m_Camera.transform.localRotation;

try {
base.teleportFull(
position: position,
rotation: rotation,
horizon: horizon,
forceAction: forceAction
);

// add arm value cases
if (!forceAction) {
if (Arm != null && Arm.IsArmColliding()) {
throw new InvalidOperationException(
"Mid Level Arm is actively clipping with some geometry in the environment. TeleportFull fails in this position."
);
} else if (SArm != null && SArm.IsArmColliding()) {
throw new InvalidOperationException(
"Stretch Arm is actively clipping with some geometry in the environment. TeleportFull fails in this position."
);
}
base.assertTeleportedNearGround(targetPosition: position);
}
} catch (InvalidOperationException e) {
transform.position = oldPosition;
transform.rotation = oldRotation;
m_Camera.transform.localRotation = oldCameraRotation;

throw new InvalidOperationException(e.Message);
}

actionFinished(success: true);
}

/*
Rotates the wrist (in a relative fashion) given some input
pitch, yaw, and roll offsets. Easiest to see how this works by
Expand Down
5 changes: 3 additions & 2 deletions unity/Assets/Scripts/DiscreteHidenSeekAgentController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,11 @@ void Update() {
) && PhysicsController.ReadyForCommand
) {
ServerAction action = new ServerAction();
if (this.PhysicsController.isStanding()) {
bool? wasStanding = this.PhysicsController.isStanding();
if (wasStanding == true) {
action.action = "Crouch";
PhysicsController.ProcessControlCommand(action);
} else {
} else if (wasStanding == false) {
action.action = "Stand";
PhysicsController.ProcessControlCommand(action);
}
Expand Down
Loading

0 comments on commit 37a921a

Please sign in to comment.