@@ -264,6 +264,298 @@ extension TensorView<float>
264
264
void InterlockedCompareExchange(vector< uint , N> index, float compare, float val);
265
265
}
266
266
267
+ interface IDiffTensorWrapper
268
+ {
269
+ __generic < T : __BuiltinFloatingPointType>
270
+ T load_forward(uint offset);
271
+
272
+ __generic < T : __BuiltinFloatingPointType>
273
+ T load_forward_2(uint2 offset);
274
+
275
+ __generic < T : __BuiltinFloatingPointType>
276
+ T load_forward_3(uint3 offset);
277
+
278
+ __generic < T : __BuiltinFloatingPointType>
279
+ T load_forward_4(uint4 offset);
280
+
281
+ __generic < T : __BuiltinFloatingPointType>
282
+ void load_backward(uint offset, T dOut);
283
+
284
+ __generic < T : __BuiltinFloatingPointType>
285
+ void load_backward_2(uint2 offset, T dOut);
286
+
287
+ __generic < T : __BuiltinFloatingPointType>
288
+ void load_backward_3(uint3 offset, T dOut);
289
+
290
+ __generic < T : __BuiltinFloatingPointType>
291
+ void load_backward_4(uint4 offset, T dOut);
292
+
293
+ __generic < T : __BuiltinFloatingPointType>
294
+ void store_forward(uint offset, T dx);
295
+
296
+ __generic < T : __BuiltinFloatingPointType>
297
+ void store_forward_2(uint2 offset, T dx);
298
+
299
+ __generic < T : __BuiltinFloatingPointType>
300
+ void store_forward_3(uint3 offset, T dx);
301
+
302
+ __generic < T : __BuiltinFloatingPointType>
303
+ void store_forward_4(uint4 offset, T dx);
304
+
305
+ __generic < T : __BuiltinFloatingPointType>
306
+ T store_backward(uint offset);
307
+
308
+ __generic < T : __BuiltinFloatingPointType>
309
+ T store_backward_2(uint2 offset);
310
+
311
+ __generic < T : __BuiltinFloatingPointType>
312
+ T store_backward_3(uint3 offset);
313
+
314
+ __generic < T : __BuiltinFloatingPointType>
315
+ T store_backward_4(uint4 offset);
316
+ };
317
+
318
+ struct AtomicAdd : IDiffTensorWrapper
319
+ {
320
+ TensorView< float > diff;
321
+
322
+ __generic < T : __BuiltinFloatingPointType>
323
+ T load_forward(uint i)
324
+ {
325
+ return __realCast< T, float > (diff .load (i));
326
+ }
327
+
328
+ __generic < T : __BuiltinFloatingPointType>
329
+ T load_forward_2(uint2 i)
330
+ {
331
+ return __realCast< T, float > (diff .load (i .x , i .y ));
332
+ }
333
+
334
+ __generic < T : __BuiltinFloatingPointType>
335
+ T load_forward_3(uint3 i)
336
+ {
337
+ return __realCast< T, float > (diff .load (i .x , i .y , i .z ));
338
+ }
339
+
340
+ __generic < T : __BuiltinFloatingPointType>
341
+ T load_forward_4(uint4 i)
342
+ {
343
+ return __realCast< T, float > (diff .load (i .x , i .y , i .z , i .w ));
344
+ }
345
+
346
+ __generic < T : __BuiltinFloatingPointType>
347
+ void load_backward(uint i, T dOut)
348
+ {
349
+ float oldVal;
350
+ diff .InterlockedAdd (i, __realCast< float , T> (dOut), oldVal);
351
+ }
352
+
353
+ __generic < T : __BuiltinFloatingPointType>
354
+ void load_backward_2(uint2 i, T dOut)
355
+ {
356
+ float oldVal;
357
+ diff .InterlockedAdd (i, __realCast< float , T> (dOut), oldVal);
358
+ }
359
+
360
+ __generic < T : __BuiltinFloatingPointType>
361
+ void load_backward_3(uint3 i, T dOut)
362
+ {
363
+ float oldVal;
364
+ diff .InterlockedAdd (i, __realCast< float , T> (dOut), oldVal);
365
+ }
366
+
367
+ __generic < T : __BuiltinFloatingPointType>
368
+ void load_backward_4(uint4 i, T dOut)
369
+ {
370
+ float oldVal;
371
+ diff .InterlockedAdd (i, __realCast< float , T> (dOut), oldVal);
372
+ }
373
+
374
+ __generic < T : __BuiltinFloatingPointType>
375
+ void store_forward(uint i, T dx)
376
+ {
377
+ diff .store (i, __realCast< float , T> (dx));
378
+ }
379
+
380
+ __generic < T : __BuiltinFloatingPointType>
381
+ void store_forward_2(uint2 i, T dx)
382
+ {
383
+ diff .store (i .x , i .y , __realCast< float , T> (dx));
384
+ }
385
+
386
+ __generic < T : __BuiltinFloatingPointType>
387
+ void store_forward_3(uint3 i, T dx)
388
+ {
389
+ diff .store (i .x , i .y , i .z , __realCast< float , T> (dx));
390
+ }
391
+
392
+ __generic < T : __BuiltinFloatingPointType>
393
+ void store_forward_4(uint4 i, T dx)
394
+ {
395
+ diff .store (i .x , i .y , i .z , i .w , __realCast< float , T> (dx));
396
+ }
397
+
398
+ __generic < T : __BuiltinFloatingPointType>
399
+ T store_backward(uint i)
400
+ {
401
+ float oldVal;
402
+ diff .InterlockedExchange (i, (float )0 , oldVal);
403
+ return __realCast< T, float > (oldVal);
404
+ }
405
+
406
+ __generic < T : __BuiltinFloatingPointType>
407
+ T store_backward_2(uint2 i)
408
+ {
409
+ float oldVal;
410
+ diff .InterlockedExchange (i, (float )0 , oldVal);
411
+ return __realCast< T, float > (oldVal);
412
+ }
413
+
414
+ __generic < T : __BuiltinFloatingPointType>
415
+ T store_backward_3(uint3 i)
416
+ {
417
+ float oldVal;
418
+ diff .InterlockedExchange (i, (float )0 , oldVal);
419
+ return __realCast< T, float > (oldVal);
420
+ }
421
+
422
+ __generic < T : __BuiltinFloatingPointType>
423
+ T store_backward_4(uint4 i)
424
+ {
425
+ float oldVal;
426
+ diff .InterlockedExchange (i, (float )0 , oldVal);
427
+ return __realCast< T, float > (oldVal);
428
+ }
429
+ };
430
+
431
+ __generic < T: __BuiltinFloatingPointType = float , A : IDiffTensorWrapper = AtomicAdd>
432
+ struct DiffTensorView
433
+ {
434
+ TensorView< T> primal;
435
+ A diff;
436
+
437
+ uint size(uint i)
438
+ {
439
+ return primal .size (i);
440
+ }
441
+
442
+ [BackwardDerivative(load_backward)]
443
+ [ForwardDerivative(load_forward)]
444
+ T load(uint i) { return primal .load (i); }
445
+
446
+ [BackwardDerivative(load_backward)]
447
+ [ForwardDerivative(load_forward)]
448
+ T load(uint2 i) { return primal .load (i .x , i .y ); }
449
+
450
+ [BackwardDerivative(load_backward)]
451
+ [ForwardDerivative(load_forward)]
452
+ T load(uint3 i) { return primal .load (i .x , i .y , i .z ); }
453
+
454
+ [BackwardDerivative(load_backward)]
455
+ [ForwardDerivative(load_forward)]
456
+ T load(uint4 i) { return primal .load (i .x , i .y , i .z , i .w ); }
457
+
458
+ DifferentialPair< T> load_forward(uint x)
459
+ {
460
+ return diffPair(primal .load (x), reinterpret< T .Differential , T> (diff .load_forward < T> (x)));
461
+ }
462
+
463
+ DifferentialPair< T> load_forward(uint2 x)
464
+ {
465
+ return diffPair(primal .load (x .x , x .y ), reinterpret< T .Differential , T> (diff .load_forward_2 < T> (x)));
466
+ }
467
+
468
+ DifferentialPair< T> load_forward(uint3 x)
469
+ {
470
+ return diffPair(primal .load (x .x , x .y , x .z ), reinterpret< T .Differential , T> (diff .load_forward_3 < T> (x)));
471
+ }
472
+
473
+ DifferentialPair< T> load_forward(uint4 x)
474
+ {
475
+ return diffPair(primal .load (x .x , x .y , x .z , x .w ), reinterpret< T .Differential , T> (diff .load_forward_4 < T> (x)));
476
+ }
477
+
478
+ void load_backward(uint x, T .Differential dOut)
479
+ {
480
+ diff .load_backward < T> (x, reinterpret< T, T .Differential > (dOut));
481
+ }
482
+
483
+ void load_backward(uint2 x, T .Differential dOut)
484
+ {
485
+ diff .load_backward_2 < T> (x, reinterpret< T, T .Differential > (dOut));
486
+ }
487
+
488
+ void load_backward(uint3 x, T .Differential dOut)
489
+ {
490
+ diff .load_backward_3 < T> (x, reinterpret< T, T .Differential > (dOut));
491
+ }
492
+
493
+ void load_backward(uint4 x, T .Differential dOut)
494
+ {
495
+ diff .load_backward_4 < T> (x, reinterpret< T, T .Differential > (dOut));
496
+ }
497
+
498
+ [BackwardDerivative(store_backward)]
499
+ [ForwardDerivative(store_forward)]
500
+ void store(uint x, T val) { primal .store (x, val); }
501
+
502
+ [BackwardDerivative(store_backward)]
503
+ [ForwardDerivative(store_forward)]
504
+ void store(uint2 x, T val) { primal .store (x .x , x .y , val); }
505
+
506
+ [BackwardDerivative(store_backward)]
507
+ [ForwardDerivative(store_forward)]
508
+ void store(uint3 x, T val) { primal .store (x .x , x .y , x .z , val); }
509
+
510
+ [BackwardDerivative(store_backward)]
511
+ [ForwardDerivative(store_forward)]
512
+ void store(uint4 x, T val) { primal .store (x .x , x .y , x .z , x .w , val); }
513
+
514
+ void store_forward(uint x, DifferentialPair< T> dpval)
515
+ {
516
+ primal .store (x, dpval .p );
517
+ diff .store_forward < T> (x, reinterpret< T, T .Differential > (dpval .d ));
518
+ }
519
+
520
+ void store_forward(uint2 x, DifferentialPair< T> dpval)
521
+ {
522
+ primal .store (x .x , x .y , dpval .p );
523
+ diff .store_forward_2 < T> (x, reinterpret< T, T .Differential > (dpval .d ));
524
+ }
525
+
526
+ void store_forward(uint3 x, DifferentialPair< T> dpval)
527
+ {
528
+ primal .store (x .x , x .y , x .z , dpval .p );
529
+ diff .store_forward_3 < T> (x, reinterpret< T, T .Differential > (dpval .d ));
530
+ }
531
+
532
+ void store_forward(uint4 x, DifferentialPair< T> dpval)
533
+ {
534
+ primal .store (x .x , x .y , x .z , x .w , dpval .p );
535
+ diff .store_forward_4 < T> (x, reinterpret< T, T .Differential > (dpval .d ));
536
+ }
537
+
538
+ void store_backward(uint x, inout DifferentialPair< T> dpval)
539
+ {
540
+ dpval = diffPair(dpval .p , reinterpret< T .Differential , T> (diff .store_backward < T> (x)));
541
+ }
542
+
543
+ void store_backward(uint2 x, inout DifferentialPair< T> dpval)
544
+ {
545
+ dpval = diffPair(dpval .p , reinterpret< T .Differential , T> (diff .store_backward_2 < T> (x)));
546
+ }
547
+
548
+ void store_backward(uint3 x, inout DifferentialPair< T> dpval)
549
+ {
550
+ dpval = diffPair(dpval .p , reinterpret< T .Differential , T> (diff .store_backward_3 < T> (x)));
551
+ }
552
+
553
+ void store_backward(uint4 x, inout DifferentialPair< T> dpval)
554
+ {
555
+ dpval = diffPair(dpval .p , reinterpret< T .Differential , T> (diff .store_backward_4 < T> (x)));
556
+ }
557
+ };
558
+
267
559
/// Represents the handle of a Torch tensor object.
268
560
__generic < T>
269
561
__intrinsic_type($(kIROp_TorchTensorType ))
0 commit comments