Skip to content

Commit 0432907

Browse files
authored
More wasm binding for playground. (#5420)
1 parent a3276e2 commit 0432907

File tree

4 files changed

+130
-14
lines changed

4 files changed

+130
-14
lines changed

source/slang-wasm/slang-wasm-bindings.cpp

+21-4
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,14 @@ EMSCRIPTEN_BINDINGS(slang)
4747
"getEntryPointCode",
4848
&slang::wgsl::ComponentType::getEntryPointCode)
4949
.function(
50-
"getEntryPointCodeSpirv",
51-
&slang::wgsl::ComponentType::getEntryPointCodeSpirv);
50+
"getEntryPointCodeBlob",
51+
&slang::wgsl::ComponentType::getEntryPointCodeBlob)
52+
.function(
53+
"getTargetCodeBlob",
54+
&slang::wgsl::ComponentType::getTargetCodeBlob)
55+
.function(
56+
"getTargetCode",
57+
&slang::wgsl::ComponentType::getTargetCode);
5258

5359
class_<slang::wgsl::Module, base<slang::wgsl::ComponentType>>("Module")
5460
.function(
@@ -58,14 +64,25 @@ EMSCRIPTEN_BINDINGS(slang)
5864
.function(
5965
"findAndCheckEntryPoint",
6066
&slang::wgsl::Module::findAndCheckEntryPoint,
61-
return_value_policy::take_ownership());
67+
return_value_policy::take_ownership())
68+
.function(
69+
"getDefinedEntryPoint",
70+
&slang::wgsl::Module::getDefinedEntryPoint,
71+
return_value_policy::take_ownership())
72+
.function(
73+
"getDefinedEntryPointCount",
74+
&slang::wgsl::Module::getDefinedEntryPointCount);
6275

6376
value_object<slang::wgsl::Error>("Error")
6477
.field("type", &slang::wgsl::Error::type)
6578
.field("result", &slang::wgsl::Error::result)
6679
.field("message", &slang::wgsl::Error::message);
6780

68-
class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint");
81+
class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint")
82+
.function(
83+
"getName",
84+
&slang::wgsl::EntryPoint::getName,
85+
allow_raw_pointers());
6986

7087
class_<slang::wgsl::CompileTargets>("CompileTargets")
7188
.function(

source/slang-wasm/slang-wasm.cpp

+90-6
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,15 @@ Session* GlobalSession::createSession(int compileTarget)
9494
return new Session(session);
9595
}
9696

97-
Module* Session::loadModuleFromSource(const std::string& slangCode)
97+
Module* Session::loadModuleFromSource(const std::string& slangCode, const std::string& name, const std::string& path)
9898
{
9999
Slang::ComPtr<IModule> module;
100100
{
101-
const char * name = "";
102-
const char * path = "";
103101
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
104102
Slang::ComPtr<ISlangBlob> slangCodeBlob = Slang::RawBlob::create(
105103
slangCode.c_str(), slangCode.size());
106104
module = m_interface->loadModuleFromSource(
107-
name, path, slangCodeBlob, diagnosticsBlob.writeRef());
105+
name.c_str(), path.c_str(), slangCodeBlob, diagnosticsBlob.writeRef());
108106
if (!module)
109107
{
110108
g_error.type = std::string("USER");
@@ -161,6 +159,38 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage)
161159
return new EntryPoint(entryPoint);
162160
}
163161

162+
int Module::getDefinedEntryPointCount()
163+
{
164+
return moduleInterface()->getDefinedEntryPointCount();
165+
}
166+
167+
EntryPoint* Module::getDefinedEntryPoint(int index)
168+
{
169+
if (moduleInterface()->getDefinedEntryPointCount() <= index)
170+
return nullptr;
171+
172+
Slang::ComPtr<IEntryPoint> entryPoint;
173+
{
174+
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
175+
SlangResult result = moduleInterface()->getDefinedEntryPoint(index, entryPoint.writeRef());
176+
if (!SLANG_SUCCEEDED(result))
177+
{
178+
g_error.type = std::string("USER");
179+
g_error.result = result;
180+
181+
if (diagnosticsBlob->getBufferSize())
182+
{
183+
char* diagnostics = (char*)diagnosticsBlob->getBufferPointer();
184+
g_error.message = std::string(diagnostics);
185+
}
186+
return nullptr;
187+
}
188+
}
189+
190+
return new EntryPoint(entryPoint);
191+
}
192+
193+
164194
ComponentType* Session::createCompositeComponentType(
165195
const std::vector<ComponentType*>& components)
166196
{
@@ -235,9 +265,9 @@ std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetInde
235265
return {};
236266
}
237267

238-
// Since spirv code is binary, we can't return it as a string, we will need to use emscripten::val
268+
// Since result code is binary, we can't return it as a string, we will need to use emscripten::val
239269
// to wrap it and return it to the javascript side.
240-
emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int targetIndex)
270+
emscripten::val ComponentType::getEntryPointCodeBlob(int entryPointIndex, int targetIndex)
241271
{
242272
Slang::ComPtr<IBlob> kernelBlob;
243273
Slang::ComPtr<ISlangBlob> diagnosticBlob;
@@ -262,6 +292,60 @@ emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int t
262292
ptr));
263293
}
264294

295+
std::string ComponentType::getTargetCode(int targetIndex)
296+
{
297+
{
298+
Slang::ComPtr<IBlob> kernelBlob;
299+
Slang::ComPtr<ISlangBlob> diagnosticBlob;
300+
SlangResult result = interface()->getTargetCode(
301+
targetIndex,
302+
kernelBlob.writeRef(),
303+
diagnosticBlob.writeRef());
304+
if (result != SLANG_OK)
305+
{
306+
g_error.type = std::string("USER");
307+
g_error.result = result;
308+
g_error.message = std::string(
309+
(char*)diagnosticBlob->getBufferPointer(),
310+
(char*)diagnosticBlob->getBufferPointer() +
311+
diagnosticBlob->getBufferSize());
312+
return "";
313+
}
314+
std::string targetCode = std::string(
315+
(char*)kernelBlob->getBufferPointer(),
316+
(char*)kernelBlob->getBufferPointer() + kernelBlob->getBufferSize());
317+
return targetCode;
318+
}
319+
320+
return {};
321+
}
322+
323+
// Since result code is binary, we can't return it as a string, we will need to use emscripten::val
324+
// to wrap it and return it to the javascript side.
325+
emscripten::val ComponentType::getTargetCodeBlob(int targetIndex)
326+
{
327+
Slang::ComPtr<IBlob> kernelBlob;
328+
Slang::ComPtr<ISlangBlob> diagnosticBlob;
329+
SlangResult result = interface()->getTargetCode(
330+
targetIndex,
331+
kernelBlob.writeRef(),
332+
diagnosticBlob.writeRef());
333+
if (result != SLANG_OK)
334+
{
335+
g_error.type = std::string("USER");
336+
g_error.result = result;
337+
g_error.message = std::string(
338+
(char*)diagnosticBlob->getBufferPointer(),
339+
(char*)diagnosticBlob->getBufferPointer() +
340+
diagnosticBlob->getBufferSize());
341+
return {};
342+
}
343+
344+
const uint8_t* ptr = (uint8_t*)kernelBlob->getBufferPointer();
345+
return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(),
346+
ptr));
347+
}
348+
265349
namespace lsp
266350
{
267351
Position translate(Slang::LanguageServerProtocol::Position p)

source/slang-wasm/slang-wasm.h

+11-4
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ class ComponentType
4848
ComponentType* link();
4949

5050
std::string getEntryPointCode(int entryPointIndex, int targetIndex);
51-
emscripten::val getEntryPointCodeSpirv(int entryPointIndex, int targetIndex);
51+
emscripten::val getEntryPointCodeBlob(int entryPointIndex, int targetIndex);
52+
std::string getTargetCode(int targetIndex);
53+
emscripten::val getTargetCodeBlob(int targetIndex);
5254

5355
slang::IComponentType* interface() const {return m_interface;}
5456

@@ -62,9 +64,11 @@ class ComponentType
6264
class EntryPoint : public ComponentType
6365
{
6466
public:
65-
6667
EntryPoint(slang::IEntryPoint* interface) : ComponentType(interface) {}
67-
68+
std::string getName() const
69+
{
70+
return entryPointInterface()->getFunctionReflection()->getName();
71+
}
6872
private:
6973

7074
slang::IEntryPoint* entryPointInterface() const {
@@ -80,6 +84,8 @@ class Module : public ComponentType
8084

8185
EntryPoint* findEntryPointByName(const std::string& name);
8286
EntryPoint* findAndCheckEntryPoint(const std::string& name, int stage);
87+
EntryPoint* getDefinedEntryPoint(int index);
88+
int getDefinedEntryPointCount();
8389

8490
slang::IModule* moduleInterface() const {
8591
return static_cast<slang::IModule*>(interface());
@@ -93,7 +99,8 @@ class Session
9399
Session(slang::ISession* interface)
94100
: m_interface(interface) {}
95101

96-
Module* loadModuleFromSource(const std::string& slangCode);
102+
Module* loadModuleFromSource(
103+
const std::string& slangCode, const std::string& name, const std::string& path);
97104

98105
ComponentType* createCompositeComponentType(
99106
const std::vector<ComponentType*>& components);

source/slang/slang.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -5040,13 +5040,21 @@ IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outD
50405040
});
50415041
List<RefPtr<ComponentType>> components;
50425042
components.add(this);
5043+
bool entryPointsDiscovered = false;
50435044
for (auto module : modules)
50445045
{
50455046
for (auto entryPoint : module->getEntryPoints())
50465047
{
50475048
components.add(entryPoint);
5049+
entryPointsDiscovered = true;
50485050
}
50495051
}
5052+
// If no entry points were discovered, then we should return nullptr.
5053+
if (!entryPointsDiscovered)
5054+
{
5055+
return nullptr;
5056+
}
5057+
50505058
RefPtr<CompositeComponentType> composite = new CompositeComponentType(linkage, components);
50515059
ComPtr<IComponentType> linkedComponentType;
50525060
SLANG_RETURN_NULL_ON_FAIL(composite->link(linkedComponentType.writeRef(), outDiagnostics));

0 commit comments

Comments
 (0)