Skip to content

Commit 46b8ab8

Browse files
authoredOct 24, 2024··
wasm: Add compile target option when creating slang session (shader-slang#5403)
* wasm: Add compile target option when creating slang session Also add a new interface to return spirv code which is binary, because 'std::string ComponentType::getEntryPointCode' is not suitable for returning the binary data. We use a more standard way that wrap the binary data by using emscripten::val as the return type. * Add target of metal
1 parent ee709cf commit 46b8ab8

File tree

3 files changed

+91
-4
lines changed

3 files changed

+91
-4
lines changed
 

‎source/slang-wasm/slang-wasm-bindings.cpp

+15-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ EMSCRIPTEN_BINDINGS(slang)
1717
"getLastError",
1818
&slang::wgsl::getLastError);
1919

20+
function(
21+
"getCompileTargets",
22+
&slang::wgsl::getCompileTargets,
23+
return_value_policy::take_ownership());
24+
2025
class_<slang::wgsl::GlobalSession>("GlobalSession")
2126
.function(
2227
"createSession",
@@ -40,7 +45,10 @@ EMSCRIPTEN_BINDINGS(slang)
4045
return_value_policy::take_ownership())
4146
.function(
4247
"getEntryPointCode",
43-
&slang::wgsl::ComponentType::getEntryPointCode);
48+
&slang::wgsl::ComponentType::getEntryPointCode)
49+
.function(
50+
"getEntryPointCodeSpirv",
51+
&slang::wgsl::ComponentType::getEntryPointCodeSpirv);
4452

4553
class_<slang::wgsl::Module, base<slang::wgsl::ComponentType>>("Module")
4654
.function(
@@ -59,5 +67,11 @@ EMSCRIPTEN_BINDINGS(slang)
5967

6068
class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint");
6169

70+
class_<slang::wgsl::CompileTargets>("CompileTargets")
71+
.function(
72+
"findCompileTarget",
73+
&slang::wgsl::CompileTargets::findCompileTarget,
74+
return_value_policy::take_ownership());
75+
6276
register_vector<slang::wgsl::ComponentType*>("ComponentTypeList");
6377
}

‎source/slang-wasm/slang-wasm.cpp

+61-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace wgsl
1414
{
1515

1616
Error g_error;
17+
CompileTargets g_compileTargets;
1718

1819
Error getLastError()
1920
{
@@ -22,6 +23,11 @@ Error getLastError()
2223
return currentError;
2324
}
2425

26+
CompileTargets* getCompileTargets()
27+
{
28+
return &g_compileTargets;
29+
}
30+
2531
GlobalSession* createGlobalSession()
2632
{
2733
IGlobalSession* globalSession = nullptr;
@@ -38,15 +44,41 @@ GlobalSession* createGlobalSession()
3844
return new GlobalSession(globalSession);
3945
}
4046

41-
Session* GlobalSession::createSession()
47+
CompileTargets::CompileTargets()
48+
{
49+
#define MAKE_PAIR(x) { #x, SLANG_##x }
50+
51+
m_compileTargetMap = {
52+
MAKE_PAIR(GLSL),
53+
MAKE_PAIR(HLSL),
54+
MAKE_PAIR(WGSL),
55+
MAKE_PAIR(SPIRV),
56+
MAKE_PAIR(METAL),
57+
};
58+
}
59+
60+
int CompileTargets::findCompileTarget(const std::string& name)
61+
{
62+
auto res = m_compileTargetMap.find(name);
63+
if ( res != m_compileTargetMap.end())
64+
{
65+
return res->second;
66+
}
67+
else
68+
{
69+
return SLANG_TARGET_UNKNOWN;
70+
}
71+
}
72+
73+
Session* GlobalSession::createSession(int compileTarget)
4274
{
4375
ISession* session = nullptr;
4476
{
4577
SessionDesc sessionDesc = {};
4678
sessionDesc.structureSize = sizeof(sessionDesc);
4779
constexpr SlangInt targetCount = 1;
4880
TargetDesc target = {};
49-
target.format = SLANG_WGSL;
81+
target.format = (SlangCompileTarget)compileTarget;
5082
sessionDesc.targets = &target;
5183
sessionDesc.targetCount = targetCount;
5284
SlangResult result = m_interface->createSession(sessionDesc, &session);
@@ -202,5 +234,32 @@ std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetInde
202234
return {};
203235
}
204236

237+
// Since spirv code is binary, we can't return it as a string, we will need to use emscripten::val
238+
// to wrap it and return it to the javascript side.
239+
emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int targetIndex)
240+
{
241+
Slang::ComPtr<IBlob> kernelBlob;
242+
Slang::ComPtr<ISlangBlob> diagnosticBlob;
243+
SlangResult result = interface()->getEntryPointCode(
244+
entryPointIndex,
245+
targetIndex,
246+
kernelBlob.writeRef(),
247+
diagnosticBlob.writeRef());
248+
if (result != SLANG_OK)
249+
{
250+
g_error.type = std::string("USER");
251+
g_error.result = result;
252+
g_error.message = std::string(
253+
(char*)diagnosticBlob->getBufferPointer(),
254+
(char*)diagnosticBlob->getBufferPointer() +
255+
diagnosticBlob->getBufferSize());
256+
return {};
257+
}
258+
259+
const uint8_t* ptr = (uint8_t*)kernelBlob->getBufferPointer();
260+
return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(),
261+
ptr));
262+
}
263+
205264
} // namespace wgsl
206265
} // namespace slang

‎source/slang-wasm/slang-wasm.h

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

33
#include <slang.h>
4+
#include <unordered_map>
5+
#include <emscripten/val.h>
46

57
namespace slang
68
{
@@ -20,6 +22,17 @@ class Error
2022

2123
Error getLastError();
2224

25+
class CompileTargets
26+
{
27+
public:
28+
CompileTargets();
29+
int findCompileTarget(const std::string& name);
30+
private:
31+
std::unordered_map<std::string, SlangCompileTarget> m_compileTargetMap;
32+
};
33+
34+
CompileTargets* getCompileTargets();
35+
2336
class ComponentType
2437
{
2538
public:
@@ -30,6 +43,7 @@ class ComponentType
3043
ComponentType* link();
3144

3245
std::string getEntryPointCode(int entryPointIndex, int targetIndex);
46+
emscripten::val getEntryPointCodeSpirv(int entryPointIndex, int targetIndex);
3347

3448
slang::IComponentType* interface() const {return m_interface;}
3549

@@ -93,7 +107,7 @@ class GlobalSession
93107
GlobalSession(slang::IGlobalSession* interface)
94108
: m_interface(interface) {}
95109

96-
Session* createSession();
110+
Session* createSession(int compileTarget);
97111

98112
slang::IGlobalSession* interface() const {return m_interface;}
99113

0 commit comments

Comments
 (0)
Please sign in to comment.