This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Using [Differentiable] without DiffTensorView #12

Open
bprb opened this issue Nov 15, 2023 · 2 comments

bprb commented Nov 15, 2023

Hello!

By mistake, I added [Differentiable] to the square sample but I didn't change TensorView to DiffTensorView.

The result was a square.fwd() that simply calculates square(). I was wondering if there is ever a valid use case for doing this, and if not, whether the compiler could catch the mistake. Otherwise, if I also pass a single input and a single output by mistake, there are no errors, but the result is definitely not the forward gradient :)

In detail, this shader...

[AutoPyBindCUDA]
[CUDAKernel]
[Differentiable]
void square(TensorView<float> input, TensorView<float> output)
{
	// ... dispatchIdx setup and bounds check, same as in the full example below ...
	output[dispatchIdx.x] = input[dispatchIdx.x] * input[dispatchIdx.x];
}

... compiles without error, but fwd is not the gradient:

__device__ void s_fwd_square_0(TensorView input_2, TensorView output_2)
{
    uint _S14 = (((blockIdx)) * ((blockDim)) + ((threadIdx))).x;
    uint _S15 = ((input_2).sizes[(0U)]);
    if(_S14 >= _S15)
    {
        return;
    }
    float _S16 = ((input_2).load<float>((_S14)));
    float _S17 = ((input_2).load<float>((_S14)));
    (output_2).store<float>((_S14), (_S16 * _S17));   // ?!
    return;
}

extern "C" {
__global__ void __kernel__square_fwd_diff(TensorView _S18, TensorView _S19)
{
    s_fwd_square_0(_S18, _S19);
    return;
}
} // extern "C"

So in this Python code, forward comes out the same as the squared numbers:

inputs = torch.tensor( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), dtype=torch.float).cuda()

squared = torch.zeros_like(inputs).cuda()
forward = torch.ones_like(inputs).cuda()

m.square(input=inputs, output=squared).launchRaw(blockSize=(32, 1, 1), gridSize=(64,1,1))
m.square.fwd(input=inputs, output=forward).launchRaw(blockSize=(32, 1, 1), gridSize=(64,1,1))

print(inputs)
print(squared)
print(forward)

Output:

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], device='cuda:0')
tensor([  1.,   4.,   9.,  16.,  25.,  36.,  49.,  64.,  81., 100.], device='cuda:0')
tensor([  1.,   4.,   9.,  16.,  25.,  36.,  49.,  64.,  81., 100.], device='cuda:0')    # oh noes

Everything is fine with the correct signature in the shader...

[AutoPyBindCUDA]
[CUDAKernel]
[Differentiable]
void square(DiffTensorView<float> input, DiffTensorView<float> output)
{
	uint3 dispatchIdx = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
	if (dispatchIdx.x >= input.size(0))
		return;

	output[dispatchIdx.x] = input[dispatchIdx.x] * input[dispatchIdx.x];
}

... and called as ...

inputs = torch.tensor( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), dtype=torch.float).cuda()
input_grad = torch.ones_like(inputs).cuda()

squared = torch.zeros_like(inputs).cuda()
forward = torch.ones_like(inputs).cuda()

m.square(input=inputs, output=squared).launchRaw(blockSize=(32, 1, 1), gridSize=(64,1,1))
m.square.fwd(input=(inputs, input_grad), output=(squared, forward)).launchRaw(blockSize=(32, 1, 1), gridSize=(64,1,1))
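
As a sanity check: the derivative of x² is 2x, so with input_grad set to all ones, forward should now come back holding twice the inputs:

tensor([ 2.,  4.,  6.,  8., 10., 12., 14., 16., 18., 20.], device='cuda:0')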

So I was wondering if this user error could be caught? Perhaps it's also useful to add a fwd example to the docs to make it very clear that, just like bwd, it expects (tensor, gradient) pairs? A sketch of the bwd convention follows below.
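
Here is that sketch, a minimal and hypothetical bwd call, assuming it mirrors the fwd pair convention (names reused from above; outputs_grad carries the upstream gradient, inputs_grad is what bwd fills in):

inputs = torch.tensor( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), dtype=torch.float).cuda()
inputs_grad = torch.zeros_like(inputs).cuda()    # bwd accumulates d(loss)/d(input) here

outputs = torch.zeros_like(inputs).cuda()
outputs_grad = torch.ones_like(inputs).cuda()    # upstream gradient d(loss)/d(output)

m.square.bwd(input=(inputs, inputs_grad), output=(outputs, outputs_grad)).launchRaw(blockSize=(32, 1, 1), gridSize=(64,1,1))

# For y = x*x with an all-ones upstream gradient, inputs_grad should
# come back as 2 * inputs.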

Thanks!
bert

(Edit: rewrote for clarity)

bprb changed the title from "No error, bad output, from broken fwd signature" to "Using [Differentiable] without DiffTensorView" on Nov 15, 2023
saipraveenb25 (Collaborator) commented

Thank you for bringing it to our attention!

This is unfortunately a tricky situation. Since we do want to allow regular TensorView types for passing non-differentiable data, it would be confusing to throw errors/warnings when TensorView is used with a [Differentiable] function. The autodiff pass simply treats the contents of a plain TensorView as constants, which is why the generated fwd kernel above just recomputes the primal square().

As you suggested, for now we'll add an example with .fwd() to the documentation to make the interface clearer, i.e., that .fwd() computes both the regular output and its derivative.


bprb commented Nov 19, 2023

Hi Sai, thank you for the quick reply! I see, this is a valid thing to do. We'll just have to be careful then :)
Thanks for taking the suggestion about the docs!
