Skip to content

Commit

Permalink
fix javadoc
Browse files Browse the repository at this point in the history
  • Loading branch information
haifengl committed Mar 30, 2024
1 parent 6ef9c63 commit 89a43ed
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 19 deletions.
2 changes: 2 additions & 0 deletions deep/src/main/java/smile/deep/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ public void train(int epochs, Optimizer optimizer, Loss loss, Dataset train) {
* @param train the training data.
* @param eval optional evaluation data.
* @param checkpoint optional checkpoint file path.
* @param metrics the evaluation metrics.
*/
public void train(int epochs, Optimizer optimizer, Loss loss, Dataset train, Dataset eval, String checkpoint, Metric... metrics) {
train(); // training mode
Expand Down Expand Up @@ -184,6 +185,7 @@ public void train(int epochs, Optimizer optimizer, Loss loss, Dataset train, Dat
/**
* Evaluates the model accuracy on a test dataset.
* @param dataset the test dataset.
* @param metrics the evaluation metrics.
* @return the accuracy.
*/
public Map<String, Double> eval(Dataset dataset, Metric... metrics) {
Expand Down
7 changes: 5 additions & 2 deletions deep/src/main/java/smile/deep/layer/Layer.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ public interface Layer {
*/
Tensor forward(Tensor input);

/** Returns the PyTorch Module object. */
/**
* Returns the PyTorch Module object.
* @return the PyTorch Module object.
*/
Module asTorch();

/**
Expand Down Expand Up @@ -225,7 +228,7 @@ static FullyConnectedLayer hardShrink(int in, int out) {
* @param size the window size.
* @return a convolutional layer.
*/
static Conv2dLayer conv2d(int in, int out, int size, int pool) {
static Conv2dLayer conv2d(int in, int out, int size) {
return new Conv2dLayer(in, out, size, 1, 1, 1, true);
}

Expand Down
8 changes: 6 additions & 2 deletions deep/src/main/java/smile/deep/metric/Metric.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
* @author Haifeng Li
*/
public interface Metric {
/** Returns the name of metric. */
/**
* Returns the name of metric.
* @return the name of metric.
*/
String name();

/**
Expand All @@ -39,9 +42,10 @@ public interface Metric {
void update(Tensor output, Tensor target);

/**
* Computes the metric values from the metric state, which are updated by
* Computes the metric value from the metric state, which are updated by
* previous update() calls. The compute frequency can be less than the
* update frequency.
* @return the metric value.
*/
double compute();

Expand Down
29 changes: 24 additions & 5 deletions deep/src/main/java/smile/deep/tensor/Device.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,26 @@ public String toString() {
return value.str().getString();
}

/** Returns true if the device is CUDA. */
/**
* Returns true if the device is CUDA.
* @return true if the device is CUDA.
*/
public boolean isCUDA() {
return value.is_cuda();
}

/** Returns true if the device is CPU. */
/**
* Returns true if the device is CPU.
* @return true if the device is CPU.
*/
public boolean isCPU() {
return value.is_cpu();
}

/** Returns true if the device is MPS. */
/**
* Returns true if the device is MPS.
* @return true if the device is MPS.
*/
public boolean isMPS() {
return value.is_mps();
}
Expand All @@ -77,12 +86,18 @@ public void emptyCache() {
}
}

/** Returns the PyTorch device object. */
/**
* Returns the PyTorch device object.
* @return the PyTorch device object.
*/
public org.bytedeco.pytorch.Device asTorch() {
return this.value;
}

/** Returns the preferred (most powerful) device. */
/**
* Returns the preferred (most powerful) device.
* @return the preferred (most powerful) device.
*/
public static Device preferredDevice() {
if (torch.cuda_is_available()) {
return Device.CUDA();
Expand Down Expand Up @@ -145,6 +160,10 @@ public void setDefaultDevice() {
torch.device(value);
}

/**
* Returns the device type.
* @return the device type.
*/
public DeviceType type() {
return type;
}
Expand Down
2 changes: 2 additions & 0 deletions deep/src/main/java/smile/deep/tensor/DeviceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
* @author Haifeng Li
*/
public enum DeviceType {
/** CPU */
CPU(torch.DeviceType.CPU),
/** NVIDIA GPU */
CUDA(torch.DeviceType.CUDA),
Expand All @@ -41,6 +42,7 @@ public enum DeviceType {
/**
* Returns the byte value of device type,
* which is compatible with PyTorch.
* @return the byte value of device type.
*/
public byte value() {
return value.value;
Expand Down
53 changes: 43 additions & 10 deletions deep/src/main/java/smile/deep/tensor/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ public Tensor clone() {
return Tensor.of(value.to());
}

/** Returns the PyTorch tensor object. */
/**
* Returns the PyTorch tensor object.
* @return the PyTorch tensor object.
*/
public org.bytedeco.pytorch.Tensor asTorch() {
return this.value;
}
Expand Down Expand Up @@ -148,6 +151,7 @@ public long size(long dim) {
* but with the specified shape. This method returns a view
* if shape is compatible with the current shape.
*
* @param shape the new shape of tensor.
* @return the tensor with the specified shape.
*/
public Tensor reshape(long... shape) {
Expand Down Expand Up @@ -322,22 +326,34 @@ public double getDouble(long... index) {
return x.item_double();
}

/** Returns the int value when the tensor holds a single value. */
/**
* Returns the int value when the tensor holds a single value.
* @return the int value when the tensor holds a single value.
*/
public int toInt() {
return value.item_int();
}

/** Returns the long value when the tensor holds a single value. */
/**
* Returns the long value when the tensor holds a single value.
* @return the long value when the tensor holds a single value.
*/
public long toLong() {
return value.item_long();
}

/** Returns the float value when the tensor holds a single value. */
/**
* Returns the float value when the tensor holds a single value.
* @return the float value when the tensor holds a single value.
*/
public float toFloat() {
return value.item_float();
}

/** Returns the double value when the tensor holds a single value. */
/**
* Returns the double value when the tensor holds a single value.
* @return the double value when the tensor holds a single value.
*/
public double toDouble() {
return value.item_double();
}
Expand Down Expand Up @@ -935,6 +951,7 @@ public Tensor asin_() {

/**
* Returns logical AND of two boolean tensors.
* @param other another tensor.
* @return a new tensor of logical and results.
*/
public Tensor and(Tensor other) {
Expand All @@ -943,6 +960,7 @@ public Tensor and(Tensor other) {

/**
* Returns logical AND of two boolean tensors.
* @param other another tensor.
* @return this tensor.
*/
public Tensor and_(Tensor other) {
Expand All @@ -952,6 +970,7 @@ public Tensor and_(Tensor other) {

/**
* Returns logical OR of two boolean tensors.
* @param other another tensor.
* @return a new tensor of logical and results.
*/
public Tensor or(Tensor other) {
Expand All @@ -960,6 +979,7 @@ public Tensor or(Tensor other) {

/**
* Returns logical OR of two boolean tensors.
* @param other another tensor.
* @return this tensor.
*/
public Tensor or_(Tensor other) {
Expand All @@ -968,7 +988,7 @@ public Tensor or_(Tensor other) {
}

/**
* Randomly zeroes some of the elements of the input tensor
* Randomly zeroes some elements of the input tensor
* with probability p.
*
* @param p the probability of an element to be zeroed.
Expand All @@ -979,7 +999,7 @@ public Tensor dropout(double p) {
}

/**
* Randomly zeroes some of the elements in place
* Randomly zeroes some elements in place
* with probability p.
*
* @param p the probability of an element to be zeroed.
Expand Down Expand Up @@ -1261,7 +1281,11 @@ public Options() {
this.value = new TensorOptions();
}

/** Sets the data type of the elements stored in the tensor. */
/**
* Sets the data type of the elements stored in the tensor.
* @param type the data type.
* @return this options object.
*/
public Options dtype(ScalarType type) {
value = value.dtype(new ScalarTypeOptional(type.value));
return this;
Expand All @@ -1277,13 +1301,22 @@ public Options device(Device device) {
return this;
}

/** Sets strided (dense) or sparse tensor. */
/**
* Sets strided (dense) or sparse tensor.
* @param layout the tensor layout.
* @return this options object.
*/
public Options layout(Layout layout) {
value = value.layout(new LayoutOptional(layout.value));
return this;
}

/** Set true if gradients need to be computed for this Tensor. */
/**
* Set true if gradients need to be computed for this tensor.
* @param required the flag indicating if gradients need to be
* computed for this tensor.
* @return this options object.
*/
public Options requireGradients(boolean required) {
value = value.requires_grad(new BoolOptional(required));
return this;
Expand Down

0 comments on commit 89a43ed

Please sign in to comment.