Skip to content

Commit 739c3a7

Browse files
Added [AutoPyBindCUDA] for automatic kernel binding + [PyExport] for exporting type information (shader-slang#3209)
* Initial: add a DiffTensor impl
* Auto-binding and diff tensor implementations now work
* Refactored diff-tensor implementation + added py-export for struct types
* Cleanup
* Update slang-ir-pytorch-cpp-binding.cpp
* Updated test names
* Update autodiff-data-flow.slang.expected
* Add more versions of load/store & default generic args for DiffTensorView.
* Add diagnostic for default generic arg and more tests
* Add more `[AutoPyBind]` tests
1 parent 359fdc9 commit 739c3a7

35 files changed

+1387
-46
lines changed

source/slang/core.meta.slang

+6
Original file line numberDiff line numberDiff line change
@@ -2263,6 +2263,12 @@ attribute_syntax [CudaHost] : CudaHostAttribute;
22632263
__attributeTarget(FuncDecl)
22642264
attribute_syntax [CudaKernel] : CudaKernelAttribute;
22652265

2266+
// [AutoPyBindCUDA]: argument-less attribute that may be placed on function
// declarations; it is represented in the AST by AutoPyBindCudaAttribute.
// NOTE(review): per the commit title this requests automatic kernel binding --
// the consuming pass is not visible here, confirm before relying on details.
__attributeTarget(FuncDecl)
2267+
attribute_syntax[AutoPyBindCUDA] : AutoPyBindCudaAttribute;
2268+
2269+
// [PyExport(name)]: attribute for aggregate type declarations that takes a
// single String argument `name`; it is represented in the AST by
// PyExportAttribute.
// NOTE(review): per the commit title this exports type information -- how
// `name` is used downstream is not visible here.
__attributeTarget(AggTypeDecl)
2270+
attribute_syntax [PyExport(name: String)] : PyExportAttribute;
2271+
22662272
__attributeTarget(InterfaceDecl)
22672273
attribute_syntax [COM(guid: String)] : ComInterfaceAttribute;
22682274

source/slang/diff.meta.slang

+292
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,298 @@ extension TensorView<float>
264264
void InterlockedCompareExchange(vector<uint, N> index, float compare, float val);
265265
}
266266

267+
/// Interface for the object that holds/accumulates the derivative data of a
/// DiffTensorView. DiffTensorView forwards its custom forward/backward
/// derivative implementations of `load` and `store` to these methods.
///
/// Every operation comes in four variants taking a 1- to 4-component offset
/// (`uint`, `uint2`, `uint3`, `uint4`); the `_2`/`_3`/`_4` suffix names the
/// number of offset components. All methods are generic over a built-in
/// floating-point type `T`.
interface IDiffTensorWrapper
268+
{
269+
// load_forward*: return the derivative value associated with `offset`
// (used by the forward-mode derivative of a load).
__generic<T : __BuiltinFloatingPointType>
270+
T load_forward(uint offset);
271+
272+
__generic<T : __BuiltinFloatingPointType>
273+
T load_forward_2(uint2 offset);
274+
275+
__generic<T : __BuiltinFloatingPointType>
276+
T load_forward_3(uint3 offset);
277+
278+
__generic<T : __BuiltinFloatingPointType>
279+
T load_forward_4(uint4 offset);
280+
281+
// load_backward*: receive the output gradient `dOut` of a load at `offset`
// (used by the backward-mode derivative of a load).
__generic<T : __BuiltinFloatingPointType>
282+
void load_backward(uint offset, T dOut);
283+
284+
__generic<T : __BuiltinFloatingPointType>
285+
void load_backward_2(uint2 offset, T dOut);
286+
287+
__generic<T : __BuiltinFloatingPointType>
288+
void load_backward_3(uint3 offset, T dOut);
289+
290+
__generic<T : __BuiltinFloatingPointType>
291+
void load_backward_4(uint4 offset, T dOut);
292+
293+
// store_forward*: receive the derivative `dx` of a value stored at `offset`
// (used by the forward-mode derivative of a store).
__generic<T : __BuiltinFloatingPointType>
294+
void store_forward(uint offset, T dx);
295+
296+
__generic<T : __BuiltinFloatingPointType>
297+
void store_forward_2(uint2 offset, T dx);
298+
299+
__generic<T : __BuiltinFloatingPointType>
300+
void store_forward_3(uint3 offset, T dx);
301+
302+
__generic<T : __BuiltinFloatingPointType>
303+
void store_forward_4(uint4 offset, T dx);
304+
305+
// store_backward*: return the gradient for the value that was stored at
// `offset` (used by the backward-mode derivative of a store).
__generic<T : __BuiltinFloatingPointType>
306+
T store_backward(uint offset);
307+
308+
__generic<T : __BuiltinFloatingPointType>
309+
T store_backward_2(uint2 offset);
310+
311+
__generic<T : __BuiltinFloatingPointType>
312+
T store_backward_3(uint3 offset);
313+
314+
__generic<T : __BuiltinFloatingPointType>
315+
T store_backward_4(uint4 offset);
316+
};
317+
318+
/// IDiffTensorWrapper implementation backed by a `TensorView<float>`.
///
/// Derivative values are always kept as `float` in `diff`; every method
/// converts to/from the caller's floating-point type `T` with `__realCast`.
/// Gradients arriving through `load_backward*` are accumulated with
/// `InterlockedAdd`, and `store_backward*` reads a gradient while resetting
/// the slot to zero with `InterlockedExchange`.
struct AtomicAdd : IDiffTensorWrapper
319+
{
320+
// Tensor holding the derivative values.
TensorView<float> diff;
321+
322+
// --- load_forward*: read the derivative at index `i` and cast it to T. ---
__generic<T : __BuiltinFloatingPointType>
323+
T load_forward(uint i)
324+
{
325+
return __realCast<T, float>(diff.load(i));
326+
}
327+
328+
__generic<T : __BuiltinFloatingPointType>
329+
T load_forward_2(uint2 i)
330+
{
331+
return __realCast<T, float>(diff.load(i.x, i.y));
332+
}
333+
334+
__generic<T : __BuiltinFloatingPointType>
335+
T load_forward_3(uint3 i)
336+
{
337+
return __realCast<T, float>(diff.load(i.x, i.y, i.z));
338+
}
339+
340+
__generic<T : __BuiltinFloatingPointType>
341+
T load_forward_4(uint4 i)
342+
{
343+
return __realCast<T, float>(diff.load(i.x, i.y, i.z, i.w));
344+
}
345+
346+
// --- load_backward*: add `dOut` (cast to float) into diff[i] via
// InterlockedAdd. `oldVal` only satisfies the call's third argument; it is
// never read afterwards. ---
__generic<T : __BuiltinFloatingPointType>
347+
void load_backward(uint i, T dOut)
348+
{
349+
float oldVal;
350+
diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
351+
}
352+
353+
__generic<T : __BuiltinFloatingPointType>
354+
void load_backward_2(uint2 i, T dOut)
355+
{
356+
float oldVal;
357+
diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
358+
}
359+
360+
__generic<T : __BuiltinFloatingPointType>
361+
void load_backward_3(uint3 i, T dOut)
362+
{
363+
float oldVal;
364+
diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
365+
}
366+
367+
__generic<T : __BuiltinFloatingPointType>
368+
void load_backward_4(uint4 i, T dOut)
369+
{
370+
float oldVal;
371+
diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
372+
}
373+
374+
// --- store_forward*: write `dx` (cast to float) to diff[i] with a plain
// store. ---
__generic<T : __BuiltinFloatingPointType>
375+
void store_forward(uint i, T dx)
376+
{
377+
diff.store(i, __realCast<float, T>(dx));
378+
}
379+
380+
__generic<T : __BuiltinFloatingPointType>
381+
void store_forward_2(uint2 i, T dx)
382+
{
383+
diff.store(i.x, i.y, __realCast<float, T>(dx));
384+
}
385+
386+
__generic<T : __BuiltinFloatingPointType>
387+
void store_forward_3(uint3 i, T dx)
388+
{
389+
diff.store(i.x, i.y, i.z, __realCast<float, T>(dx));
390+
}
391+
392+
__generic<T : __BuiltinFloatingPointType>
393+
void store_forward_4(uint4 i, T dx)
394+
{
395+
diff.store(i.x, i.y, i.z, i.w, __realCast<float, T>(dx));
396+
}
397+
398+
// --- store_backward*: exchange diff[i] with 0 and return `oldVal` cast
// to T.
// NOTE(review): presumably `oldVal` receives the value that was in diff[i]
// before the exchange -- confirm against TensorView.InterlockedExchange. ---
__generic<T : __BuiltinFloatingPointType>
399+
T store_backward(uint i)
400+
{
401+
float oldVal;
402+
diff.InterlockedExchange(i, (float)0, oldVal);
403+
return __realCast<T, float>(oldVal);
404+
}
405+
406+
__generic<T : __BuiltinFloatingPointType>
407+
T store_backward_2(uint2 i)
408+
{
409+
float oldVal;
410+
diff.InterlockedExchange(i, (float)0, oldVal);
411+
return __realCast<T, float>(oldVal);
412+
}
413+
414+
__generic<T : __BuiltinFloatingPointType>
415+
T store_backward_3(uint3 i)
416+
{
417+
float oldVal;
418+
diff.InterlockedExchange(i, (float)0, oldVal);
419+
return __realCast<T, float>(oldVal);
420+
}
421+
422+
__generic<T : __BuiltinFloatingPointType>
423+
T store_backward_4(uint4 i)
424+
{
425+
float oldVal;
426+
diff.InterlockedExchange(i, (float)0, oldVal);
427+
return __realCast<T, float>(oldVal);
428+
}
429+
};
430+
431+
/// A tensor view with an attached derivative store.
///
/// `primal` holds the values; `diff` is an IDiffTensorWrapper that handles the
/// derivative side. `load` and `store` operate on `primal` only and are
/// annotated with custom forward/backward derivatives (`load_forward`,
/// `load_backward`, `store_forward`, `store_backward`) that route derivative
/// traffic through `diff`.
///
/// Generic arguments default to `T = float` and `A = AtomicAdd`.
/// Each operation is overloaded for `uint`, `uint2`, `uint3` and `uint4`
/// indices.
__generic<T: __BuiltinFloatingPointType = float, A : IDiffTensorWrapper = AtomicAdd>
432+
struct DiffTensorView
433+
{
434+
// Primal (value) tensor.
TensorView<T> primal;
435+
// Derivative wrapper used by the *_forward / *_backward methods below.
A diff;
436+
437+
// Forwards to primal.size(i).
uint size(uint i)
438+
{
439+
return primal.size(i);
440+
}
441+
442+
// --- load: read from the primal tensor. The attributes name the custom
// derivative implementations defined further down.
// NOTE(review): presumably the overload with the matching index type is
// selected for each -- confirm against the compiler's resolution rules. ---
[BackwardDerivative(load_backward)]
443+
[ForwardDerivative(load_forward)]
444+
T load(uint i) { return primal.load(i); }
445+
446+
[BackwardDerivative(load_backward)]
447+
[ForwardDerivative(load_forward)]
448+
T load(uint2 i) { return primal.load(i.x, i.y); }
449+
450+
[BackwardDerivative(load_backward)]
451+
[ForwardDerivative(load_forward)]
452+
T load(uint3 i) { return primal.load(i.x, i.y, i.z); }
453+
454+
[BackwardDerivative(load_backward)]
455+
[ForwardDerivative(load_forward)]
456+
T load(uint4 i) { return primal.load(i.x, i.y, i.z, i.w); }
457+
458+
// --- load_forward: build a DifferentialPair from the primal value and the
// derivative obtained from diff.load_forward*. `reinterpret` converts the
// wrapper's T result to T.Differential. ---
DifferentialPair<T> load_forward(uint x)
459+
{
460+
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T>(x)));
461+
}
462+
463+
DifferentialPair<T> load_forward(uint2 x)
464+
{
465+
return diffPair(primal.load(x.x, x.y), reinterpret<T.Differential, T>(diff.load_forward_2<T>(x)));
466+
}
467+
468+
DifferentialPair<T> load_forward(uint3 x)
469+
{
470+
return diffPair(primal.load(x.x, x.y, x.z), reinterpret<T.Differential, T>(diff.load_forward_3<T>(x)));
471+
}
472+
473+
DifferentialPair<T> load_forward(uint4 x)
474+
{
475+
return diffPair(primal.load(x.x, x.y, x.z, x.w), reinterpret<T.Differential, T>(diff.load_forward_4<T>(x)));
476+
}
477+
478+
// --- load_backward: hand the output gradient `dOut` (converted from
// T.Differential to T) to diff.load_backward*. The primal tensor is not
// touched. ---
void load_backward(uint x, T.Differential dOut)
479+
{
480+
diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut));
481+
}
482+
483+
void load_backward(uint2 x, T.Differential dOut)
484+
{
485+
diff.load_backward_2<T>(x, reinterpret<T, T.Differential>(dOut));
486+
}
487+
488+
void load_backward(uint3 x, T.Differential dOut)
489+
{
490+
diff.load_backward_3<T>(x, reinterpret<T, T.Differential>(dOut));
491+
}
492+
493+
void load_backward(uint4 x, T.Differential dOut)
494+
{
495+
diff.load_backward_4<T>(x, reinterpret<T, T.Differential>(dOut));
496+
}
497+
498+
// --- store: write `val` to the primal tensor. The attributes name the
// custom derivative implementations defined further down. ---
[BackwardDerivative(store_backward)]
499+
[ForwardDerivative(store_forward)]
500+
void store(uint x, T val) { primal.store(x, val); }
501+
502+
[BackwardDerivative(store_backward)]
503+
[ForwardDerivative(store_forward)]
504+
void store(uint2 x, T val) { primal.store(x.x, x.y, val); }
505+
506+
[BackwardDerivative(store_backward)]
507+
[ForwardDerivative(store_forward)]
508+
void store(uint3 x, T val) { primal.store(x.x, x.y, x.z, val); }
509+
510+
[BackwardDerivative(store_backward)]
511+
[ForwardDerivative(store_forward)]
512+
void store(uint4 x, T val) { primal.store(x.x, x.y, x.z, x.w, val); }
513+
514+
// --- store_forward: store the primal part (`dpval.p`) into `primal` and
// pass the differential part (`dpval.d`, converted to T) to
// diff.store_forward*. ---
void store_forward(uint x, DifferentialPair<T> dpval)
515+
{
516+
primal.store(x, dpval.p);
517+
diff.store_forward<T>(x, reinterpret<T, T.Differential>(dpval.d));
518+
}
519+
520+
void store_forward(uint2 x, DifferentialPair<T> dpval)
521+
{
522+
primal.store(x.x, x.y, dpval.p);
523+
diff.store_forward_2<T>(x, reinterpret<T, T.Differential>(dpval.d));
524+
}
525+
526+
void store_forward(uint3 x, DifferentialPair<T> dpval)
527+
{
528+
primal.store(x.x, x.y, x.z, dpval.p);
529+
diff.store_forward_3<T>(x, reinterpret<T, T.Differential>(dpval.d));
530+
}
531+
532+
void store_forward(uint4 x, DifferentialPair<T> dpval)
533+
{
534+
primal.store(x.x, x.y, x.z, x.w, dpval.p);
535+
diff.store_forward_4<T>(x, reinterpret<T, T.Differential>(dpval.d));
536+
}
537+
538+
// --- store_backward: overwrite `dpval` with a pair that keeps its primal
// part and takes its differential from diff.store_backward* (converted from
// T to T.Differential). The primal tensor is not touched. ---
void store_backward(uint x, inout DifferentialPair<T> dpval)
539+
{
540+
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x)));
541+
}
542+
543+
void store_backward(uint2 x, inout DifferentialPair<T> dpval)
544+
{
545+
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_2<T>(x)));
546+
}
547+
548+
void store_backward(uint3 x, inout DifferentialPair<T> dpval)
549+
{
550+
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_3<T>(x)));
551+
}
552+
553+
void store_backward(uint4 x, inout DifferentialPair<T> dpval)
554+
{
555+
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_4<T>(x)));
556+
}
557+
};
558+
267559
/// Represents the handle of a Torch tensor object.
268560
__generic<T>
269561
__intrinsic_type($(kIROp_TorchTensorType))

source/slang/slang-ast-modifier.h

+12
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,18 @@ class CudaHostAttribute : public Attribute
11341134
SLANG_AST_CLASS(CudaHostAttribute)
11351135
};
11361136

1137+
// AST node for the `[AutoPyBindCUDA]` attribute (bound to this class by the
// `attribute_syntax` declaration in core.meta.slang). Carries no data of its
// own.
class AutoPyBindCudaAttribute : public Attribute
1138+
{
1139+
SLANG_AST_CLASS(AutoPyBindCudaAttribute)
1140+
};
1141+
1142+
// AST node for the `[PyExport(name: String)]` attribute (bound to this class
// by the `attribute_syntax` declaration in core.meta.slang).
class PyExportAttribute : public Attribute
1143+
{
1144+
SLANG_AST_CLASS(PyExportAttribute)
1145+
1146+
// Value of the attribute's single string-literal argument; assigned during
// modifier checking (slang-check-modifier.cpp).
String name;
1147+
};
1148+
11371149
class PreferRecomputeAttribute : public Attribute
11381150
{
11391151
SLANG_AST_CLASS(PreferRecomputeAttribute)

source/slang/slang-check-modifier.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,19 @@ namespace Slang
757757

758758
knownBuiltinAttr->name = name;
759759
}
760+
else if (auto pyExportAttr = as<PyExportAttribute>(attr))
761+
{
762+
// Check name string.
763+
SLANG_ASSERT(attr->args.getCount() == 1);
764+
765+
String name;
766+
if(!checkLiteralStringVal(attr->args[0], &name))
767+
{
768+
return false;
769+
}
770+
771+
pyExportAttr->name = name;
772+
}
760773
else
761774
{
762775
if(attr->args.getCount() == 0)

0 commit comments

Comments
 (0)