Skip to content

Commit

Permalink
Various fixes for autodiff and slangpy. (#2876)
Browse files Browse the repository at this point in the history
* Various fixes for autodiff and slangpy.

* Fix cuda code gen for `select`.

* Fix getBuildTagString().

* Fix.

---------

Co-authored-by: Yong He <yhe@nvidia.com>
  • Loading branch information
csyonghe and Yong He authored May 10, 2023
1 parent 38ed03a commit ddebd60
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 18 deletions.
16 changes: 16 additions & 0 deletions prelude/slang-cpp-host-prelude.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@
# include <stdint.h>
#endif // SLANG_LLVM

// SLANG_PRELUDE_SHARED_LIB_EXPORT: attribute placed on symbols that must be
// exported from a shared library built from generated code.
//  - When _MSC_VER is defined it expands to `__declspec(dllexport)`.
//  - Otherwise it expands to the `visibility("default")` attribute.
#if defined(_MSC_VER)
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
#else
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default")))
#endif

// SLANG_PRELUDE_EXTERN_C*: C-linkage helpers.
//  - In C++ translation units they expand to `extern "C"` (single declaration)
//    or to an opening/closing `extern "C" { ... }` pair (START/END).
//  - In plain C they expand to nothing, so the same declarations compile unchanged.
#ifdef __cplusplus
# define SLANG_PRELUDE_EXTERN_C extern "C"
# define SLANG_PRELUDE_EXTERN_C_START extern "C" {
# define SLANG_PRELUDE_EXTERN_C_END }
#else
# define SLANG_PRELUDE_EXTERN_C
# define SLANG_PRELUDE_EXTERN_C_START
# define SLANG_PRELUDE_EXTERN_C_END
#endif

#include "slang-cpp-scalar-intrinsics.h"

Expand Down
17 changes: 17 additions & 0 deletions prelude/slang-cpp-types-core.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,23 @@ struct Vector<T, 4>

};

// Component-wise select for vectors.
// For every component index i in [0, N) the result takes v0[i] when
// condition[i] is true and v1[i] otherwise.
template<typename T, int N>
SLANG_FORCE_INLINE Vector<T, N> _slang_select(Vector<bool, N> condition, Vector<T, N> v0, Vector<T, N> v1)
{
    Vector<T, N> picked;
    int lane = 0;
    while (lane < N)
    {
        if (condition[lane])
        {
            picked[lane] = v0[lane];
        }
        else
        {
            picked[lane] = v1[lane];
        }
        ++lane;
    }
    return picked;
}

// Scalar select: yields v0 when `condition` is true, v1 otherwise.
template<typename T>
SLANG_FORCE_INLINE T _slang_select(bool condition, T v0, T v1)
{
    if (condition)
    {
        return v0;
    }
    return v1;
}

template<typename T, int N>
SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index)
{
Expand Down
32 changes: 32 additions & 0 deletions prelude/slang-cuda-prelude.h
Original file line number Diff line number Diff line change
Expand Up @@ -743,11 +743,43 @@ SLANG_FLOAT_MATRIX_OPS(__half)
#undef SLANG_MATRIX_INT_NEG_OP
#undef SLANG_FLOAT_MATRIX_MOD

// SLANG_SELECT_IMPL(T, N) defines a component-wise `_slang_select` overload for
// Vector<T, N>. The condition parameter type is produced by token pasting
// (`bool##N` -> bool2 / bool3 / bool4). Each result component i is taken from
// v0 when component i of the condition is true, and from v1 otherwise.
// Components are read and written through the `_slang_vector_get_element` /
// `_slang_vector_get_element_ptr` helpers, which are defined elsewhere in the prelude.
#define SLANG_SELECT_IMPL(T, N)\
SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector<T, N> _slang_select(bool##N condition, Vector<T, N> v0, Vector<T, N> v1) \
{ \
Vector<T, N> result; \
for (int i = 0; i < N; i++) \
{ \
*_slang_vector_get_element_ptr(&result, i) = _slang_vector_get_element(condition, i) ? _slang_vector_get_element(v0, i) : _slang_vector_get_element(v1, i); \
} \
return result; \
}
// SLANG_SELECT_T(T) stamps out the 2-, 3- and 4-component overloads for element type T.
// NOTE: neither macro is #undef'd here; SLANG_SELECT_T is used again further down
// for __half when SLANG_CUDA_ENABLE_HALF is set.
#define SLANG_SELECT_T(T)\
SLANG_SELECT_IMPL(T, 2)\
SLANG_SELECT_IMPL(T, 3)\
SLANG_SELECT_IMPL(T, 4)

// Vector select overloads for the integer and floating-point element types listed below.
SLANG_SELECT_T(int)
SLANG_SELECT_T(uint)
SLANG_SELECT_T(short)
SLANG_SELECT_T(ushort)
SLANG_SELECT_T(char)
SLANG_SELECT_T(uchar)
SLANG_SELECT_T(float)
SLANG_SELECT_T(double)

// Scalar select: yields v0 when `condition` is true, v1 otherwise.
// Generic fallback used when the condition is a single bool rather than a bool vector.
template<typename T>
SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_select(bool condition, T v0, T v1)
{
    if (condition)
    {
        return v0;
    }
    return v1;
}

//
// Half support
//

#if SLANG_CUDA_ENABLE_HALF
SLANG_SELECT_T(__half)

// Convenience functions ushort -> half

SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __ushort_as_half(const ushort2& i) { return __halves2half2(__ushort_as_half(i.x), __ushort_as_half(i.y)); }
Expand Down
17 changes: 17 additions & 0 deletions prelude/slang-torch-prelude.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@
# include <stdint.h>
#endif // SLANG_LLVM

// SLANG_PRELUDE_SHARED_LIB_EXPORT: attribute placed on symbols that must be
// exported from a shared library built from generated code.
//  - When _MSC_VER is defined it expands to `__declspec(dllexport)`.
//  - Otherwise it expands to the `visibility("default")` attribute.
#if defined(_MSC_VER)
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
#else
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default")))
#endif

// SLANG_PRELUDE_EXTERN_C*: C-linkage helpers.
//  - In C++ translation units they expand to `extern "C"` (single declaration)
//    or to an opening/closing `extern "C" { ... }` pair (START/END).
//  - In plain C they expand to nothing, so the same declarations compile unchanged.
#ifdef __cplusplus
# define SLANG_PRELUDE_EXTERN_C extern "C"
# define SLANG_PRELUDE_EXTERN_C_START extern "C" {
# define SLANG_PRELUDE_EXTERN_C_END }
#else
# define SLANG_PRELUDE_EXTERN_C
# define SLANG_PRELUDE_EXTERN_C_START
# define SLANG_PRELUDE_EXTERN_C_END
#endif


#define SLANG_PRELUDE_NAMESPACE

Expand Down
18 changes: 17 additions & 1 deletion source/slang/diff.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -246,52 +246,66 @@ __intrinsic_type($(kIROp_TorchTensorType))
struct TorchTensor
{
__intrinsic_op($(kIROp_TorchTensorGetView))
[CudaHost]
TensorView<T> getView();

__target_intrinsic(cuda, "$0.dims()")
__target_intrinsic(cpp, "$0.dims()")
[__readNone]
[CudaHost]
uint dims();

__target_intrinsic(cuda, "$0.size($1)")
__target_intrinsic(cpp, "$0.size($1)")
[__readNone]
[CudaHost]
uint size(uint i);

__target_intrinsic(cuda, "$0.stride($1)")
__target_intrinsic(cpp, "$0.stride($1)")
[__readNone]
[CudaHost]
uint stride(uint i);

__target_intrinsic(cuda, "$0.data_ptr<$G0>()")
__target_intrinsic(cpp, "$0.data_ptr<$G0>()")
[__readNone]
[CudaHost]
Ptr<T> data_ptr();

__intrinsic_op($(kIROp_AllocateTorchTensor))
[CudaHost]
static TorchTensor<T> alloc(uint x);

__intrinsic_op($(kIROp_AllocateTorchTensor))
[CudaHost]
static TorchTensor<T> alloc(uint x, uint y);

__intrinsic_op($(kIROp_AllocateTorchTensor))
[CudaHost]
static TorchTensor<T> alloc(uint x, uint y, uint z);

__intrinsic_op($(kIROp_AllocateTorchTensor))
[CudaHost]
static TorchTensor<T> alloc(uint x, uint y, uint z, uint w);

__intrinsic_op($(kIROp_AllocateTorchTensor))
[CudaHost]
static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4);

__intrinsic_op($(kIROp_AllocateTorchTensor))
[CudaHost]
static TorchTensor<T> emptyLike(TorchTensor<T> other);

__target_intrinsic(cpp, "$0.zero_()")
[CudaHost]
void fillZero();

__target_intrinsic(cpp, "$0.fill_($1)")
[CudaHost]
void fillValue(T val);

[CudaHost]
static TorchTensor<T> zerosLike(TorchTensor<T> other)
{
var result = emptyLike(other);
Expand Down Expand Up @@ -854,8 +868,10 @@ T detach(T x);

// Square of `x`. The argument is expanded twice, so it should be free of side effects.
#define SLANG_SQR(x) ((x)*(x))

// Sign of `x`, built from nested `select` calls:
//   x > 0   -> ReturnType(T(1.0))
//   x == 0  -> ReturnType(T(0.0))
//   otherwise -> ReturnType(T(-1.0))
// `T` and `ReturnType` are not defined by this macro; they must be in scope
// wherever it is expanded.
#define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0))))

// Absolute value
UNARY_DERIVATIVE_IMPL(abs, select(dpx.p > T(0.0), dpx.d, ReturnType.dmul(T(-1.0), dpx.d)), (ReturnType.dmul(__slang_noop_cast<ReturnType>(sign(dpx.p)), dOut)))
UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut))
// Saturate
UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut))
// frac
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-diagnostic-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,9 @@ DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can

DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.")
DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.")

// Reported when a value of TorchTensor type appears in a function that is not a host-side function.
DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.")

//
// 5xxxx - Target code generation.
//
Expand Down
1 change: 0 additions & 1 deletion source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2215,7 +2215,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO

case kIROp_Select:
{

auto prec = getInfo(EmitOp::Conditional);
needClose = maybeEmitParens(outerPrec, prec);

Expand Down
11 changes: 11 additions & 0 deletions source/slang/slang-emit-cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,17 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
m_writer->emit(")");
return true;
}
case kIROp_Select:
{
m_writer->emit("_slang_select(");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit(",");
emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
m_writer->emit(")");
return true;
}
}
}

Expand Down
46 changes: 46 additions & 0 deletions source/slang/slang-ir-check-differentiability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,54 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
}
}

// Returns false when `type` is, or transitively contains, a TorchTensor type;
// returns true otherwise.
//
// Attribute wrappers are stripped first. Array types are checked through their
// element type, and struct types through each of their fields. Any other type
// is accepted.
bool checkType(IRInst* type)
{
    type = unwrapAttributedType(type);

    if (as<IRTorchTensorType>(type))
        return false;

    if (auto arrayType = as<IRArrayTypeBase>(type))
        return checkType(arrayType->getElementType());

    auto structType = as<IRStructType>(type);
    if (!structType)
        return true;

    // A struct is valid only if every one of its fields is valid.
    for (auto field : structType->getFields())
    {
        if (!checkType(field->getFieldType()))
            return false;
    }
    return true;
}
// Diagnoses use of TorchTensor-typed values inside functions that are not
// host-side functions.
//
// Functions (or their outer generic) decorated as CUDA host functions or as
// torch entry points are exempt. For any other function, every instruction in
// every block has its data type checked; the first offending instruction is
// reported and checking stops. If the instruction has no valid source
// location, the function's location is used instead.
void checkForInvalidHostTypeUsage(IRGlobalValueWithCode* funcInst)
{
    auto outerFuncInst = maybeFindOuterGeneric(funcInst);

    // Host-side code may freely use TorchTensor.
    const bool isHostSide =
        outerFuncInst->findDecoration<IRCudaHostDecoration>() ||
        outerFuncInst->findDecoration<IRTorchEntryPointDecoration>();
    if (isHostSide)
        return;

    // Everything else is treated as device/kernel code, where TorchTensor is not allowed.
    for (auto block : funcInst->getBlocks())
    {
        for (auto inst : block->getChildren())
        {
            if (checkType(inst->getDataType()))
                continue;

            auto loc = inst->sourceLoc.isValid() ? inst->sourceLoc : funcInst->sourceLoc;
            sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc);
            return;
        }
    }
}

void processFunc(IRGlobalValueWithCode* funcInst)
{
checkForInvalidHostTypeUsage(funcInst);

if (!_isFuncMarkedForAutoDiff(funcInst))
return;
if (!funcInst->getFirstBlock())
Expand Down
28 changes: 12 additions & 16 deletions source/slang/slang-ir-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,6 @@ bool isPtrLikeOrHandleType(IRInst* type)
return true;
if (as<IRPseudoPtrType>(type))
return true;
if (as<IRMeshOutputType>(type))
return true;
if (as<IRHLSLOutputPatchType>(type))
return true;
switch (type->getOp())
{
case kIROp_ComPtrType:
Expand Down Expand Up @@ -824,30 +820,30 @@ bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst*
if (!isPtrLikeOrHandleType(type))
return false;

if (root)
{
if (as<IRParameterGroupType>(root->getDataType()))
{
return false;
}
}

switch (root->getOp())
{
case kIROp_GlobalVar:
return true;
case kIROp_GlobalParam:
case kIROp_GlobalConstant:
case kIROp_Var:
case kIROp_Param:
break;
case kIROp_Call:
return true;
default:
// The inst is defined by an unknown inst.
return true;
}

if (root)
{
if (as<IRParameterGroupType>(root->getDataType()))
{
return false;
}
auto addrInstParent = getParentFunc(root);
return (addrInstParent != parentFunc);
}
return false;
auto addrInstParent = getParentFunc(root);
return (addrInstParent != parentFunc);
}

struct GenericChildrenMigrationContextImpl
Expand Down
8 changes: 8 additions & 0 deletions source/slang/slang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ namespace Slang {

// Returns a string identifying this build.
//
// Normally this is SLANG_TAG_VERSION. When that tag is the literal "unknown",
// the timestamp of the shared library containing `spCreateSession` is used
// instead, so that the result still identifies the build. The timestamp string
// is computed once and kept in a function-local static, which keeps the
// returned pointer valid after the call.
const char* getBuildTagString()
{
    const bool hasKnownTag = !(UnownedStringSlice(SLANG_TAG_VERSION) == "unknown");
    if (hasKnownTag)
    {
        return SLANG_TAG_VERSION;
    }

    static String timeStampString = String(SharedLibraryUtils::getSharedLibraryTimestamp((void*)spCreateSession));
    return timeStampString.getBuffer();
}

Expand Down

0 comments on commit ddebd60

Please sign in to comment.