Skip to content

Commit

Permalink
Merge pull request #16 from EleutherAI/reward
Browse files Browse the repository at this point in the history
Reward mod
  • Loading branch information
Rabbidon committed Dec 26, 2022
2 parents 9277755 + 9fe9207 commit 477457e
Show file tree
Hide file tree
Showing 27 changed files with 39 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ log/*

# Python
**/__pycache__/
__pycache__/

newworld/
## Non-static Minetest directories or symlinks to these
Expand Down
Binary file removed 03-12-2022 20-29-12.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-13.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-14.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-15.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-16.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-17.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-18.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-19.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-20.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-21.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-22.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-23.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-24.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-25.png
Binary file not shown.
Binary file removed 03-12-2022 20-29-26.png
Binary file not shown.
Binary file not shown.
8 changes: 4 additions & 4 deletions hacking_testing/minetest_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"SLOT7",
"SLOT8",
# these keys open the inventory/menu
"ESC",
# "ESC",
"INVENTORY",
# "AUX1",
# these keys lead to errors:
Expand Down Expand Up @@ -311,8 +311,8 @@ def _unpack_pb_obs(self, received_obs: str):
pb_obs.width,
3,
)
# TODO receive rewards etc.
rew = 0.0
rew = pb_obs.reward
# TODO receive done flag, info, etc.
done = False
info = {}
return obs, rew, done, info
Expand Down Expand Up @@ -361,7 +361,7 @@ def step(self, action: Dict[str, Any]):
byte_obs = self.socket.recv()
next_obs, rew, done, info = self._unpack_pb_obs(byte_obs)
self.last_obs = next_obs
logging.debug("Received obs: {}".format(next_obs.shape))
logging.debug(f"Received obs - {next_obs.shape}; reward - {rew}")
return next_obs, rew, done, info

def render(self, render_mode: str = "human"):
Expand Down
1 change: 1 addition & 0 deletions hacking_testing/test_loop.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python3
from minetest_env import Minetest

seed = 42
Expand Down
2 changes: 2 additions & 0 deletions minetest.conf
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
name = MinetestAgent
menu_last_game = minetest
enable_client_modding = true
1 change: 1 addition & 0 deletions proto/client/dumb_outputs.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ message OutputObservation {
int32 width = 1;
int32 height = 2;
bytes data = 3;
float reward = 4;
}
11 changes: 11 additions & 0 deletions src/client/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ void Client::loadMods()
// Don't load mods twice.
// If client scripting is disabled by the client, don't load builtin or
// client-provided mods.
g_settings->setBool("enable_client_modding", true); // TODO don't hardcode this
if (m_mods_loaded || !g_settings->getBool("enable_client_modding"))
return;

Expand Down Expand Up @@ -1909,13 +1910,23 @@ OutputObservation Client::getSendableData(core::position2di cursorPosition, bool
const irr::video::SColor color = irr::video::SColor(255, 255, 255, 255);
cursorImage->copyToWithAlpha(image, cursorPosition, sourceRect, color, nullptr, true);
}

float reward = 0.0;
ClientScripting *scr = getScript();
if(scr) {
lua_State *L = scr->getStack();
lua_getglobal(L, "reward");
reward = (float)lua_tonumber(L, lua_gettop(L));
lua_pop(L, 1);
}

auto dim = image->getDimension();
std::string imageData = std::string((char*)image->getData(), image->getImageDataSizeInBytes());
OutputObservation data;
data.set_data(imageData);
data.set_width(dim.Width);
data.set_height(dim.Height);
data.set_reward(reward);
image->drop();
return data;
}
Expand Down
1 change: 1 addition & 0 deletions src/client/recorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ with this program; if not, write to the Free Software Foundation, Inc.,
#include "client/recorder.h"
#include "client/dumb_outputs.pb.h"

// TODO: move OutputObservation creation outside the function
void Recorder::sendDataOut(bool isMenuActive, irr::video::IImage* cursorImage, Client *client, InputHandler *input) {
OutputObservation data = client->getSendableData(input->getMousePos(), isMenuActive, cursorImage);
std::string msg = data.SerializeAsString();
Expand Down
6 changes: 6 additions & 0 deletions src/client/render/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ void RenderingCore::initialize()
createPipeline();
}

/*
void RenderingCore::savetex(video::ITexture *texture, std::string filename, video::IVideoDriver* videoDriver) {
video::IImage* image = videoDriver->createImageFromData(
texture->getColorFormat(),
Expand All @@ -55,6 +56,7 @@ void RenderingCore::savetex(video::ITexture *texture, std::string filename, vide
videoDriver->writeImageToFile(image, io::path(filename.c_str()));
texture->unlock();
}
*/

void RenderingCore::draw(video::SColor _skycolor, bool _show_hud, bool _show_minimap,
bool _draw_wield_tool, bool _draw_crosshair)
Expand All @@ -68,26 +70,30 @@ void RenderingCore::draw(video::SColor _skycolor, bool _show_hud, bool _show_min
context.show_hud = _show_hud;
context.show_minimap = _show_minimap;

/*
TextureBuffer *buffer = pipeline->createOwned<TextureBuffer>();
buffer->setTexture(0, v2f(1.0f, 1.0f), "idk_lol", video::ECF_A8R8G8B8);
auto tex = new TextureBufferOutput(buffer, 0);
pipeline->setRenderTarget(tex);
for (auto &step: pipeline->m_pipeline)
step->setRenderTarget(tex);
*/

pipeline->reset(context);
pipeline->run(context);

auto t = std::time(nullptr);
auto tm = *std::localtime(&t);

/*
std::ostringstream oss;
oss << std::put_time(&tm, "%d-%m-%Y %H-%M-%S");
auto s = oss.str();
const std::string out = s + ".png";
savetex(tex->buffer->getTexture(0), out, device->getVideoDriver());
*/
}

v2u32 RenderingCore::getVirtualSize() const
Expand Down
2 changes: 1 addition & 1 deletion src/client/render/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class RenderingCore
RenderingCore &operator=(RenderingCore &&) = delete;


void savetex(video::ITexture *texture, std::string name, video::IVideoDriver* videoDriver);
// void savetex(video::ITexture *texture, std::string name, video::IVideoDriver* videoDriver);

void initialize();
void draw(video::SColor _skycolor, bool _show_hud, bool _show_minimap,
Expand Down
21 changes: 11 additions & 10 deletions src/script/cpp_api/s_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,7 @@ class ScriptApiBase : protected LuaHelper {
// Check things that should be set by the builtin mod.
void checkSetByBuiltin();

protected:
friend class LuaABM;
friend class LuaLBM;
friend class InvRef;
friend class ObjectRef;
friend class NodeMetaRef;
friend class ModApiBase;
friend class ModApiEnvMod;
friend class LuaVoxelManip;

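// TODO: make getStack() protected again and grant access through a friend declaration instead of exposing it publicly (presumably for the client-side reward lookup; confirm)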
/*
Subtle edge case with coroutines: If for whatever reason you have a
method in a subclass that's called from existing lua_CFunction
Expand All @@ -140,6 +131,16 @@ class ScriptApiBase : protected LuaHelper {
lua_State* getStack()
{ return m_luastack; }

protected:
friend class LuaABM;
friend class LuaLBM;
friend class InvRef;
friend class ObjectRef;
friend class NodeMetaRef;
friend class ModApiBase;
friend class ModApiEnvMod;
friend class LuaVoxelManip;

// Checks that stack size is sane
void realityCheck();
// Takes an error from lua_pcall and throws it as a LuaError
Expand Down
Binary file removed temp.png
Binary file not shown.

0 comments on commit 477457e

Please sign in to comment.