Skip to content

Commit

Permalink
Merge pull request #38 from Photoroom/ben/ml-2347-cleaner-way-to-pass…
Browse files Browse the repository at this point in the history
…-bytes-buffers-around

[gopy] Cleaner bytes buffer transfer
  • Loading branch information
blefaudeux authored Nov 14, 2024
2 parents 9bf7ec7 + ca59fae commit d7a4bfc
Show file tree
Hide file tree
Showing 11 changed files with 15 additions and 46 deletions.
2 changes: 1 addition & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func main() {
// Define flags
config := datago.GetDatagoConfig()

sourceConfig := datago.SourceFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
sourceConfig := datago.SourceFileSystemConfig{RootPath: os.Getenv("DATAGO_TEST_FILESYSTEM")}
sourceConfig.PageSize = 10
sourceConfig.Rank = 0
sourceConfig.WorldSize = 1
Expand Down
6 changes: 2 additions & 4 deletions pkg/architecture.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import "context"

// --- Sample data structures - these will be exposed to the Python world ---------------------------------------------------------------------------------------------------------------------------------------------------------------
type LatentPayload struct {
Data []byte
Len int
DataPtr uintptr
Data []byte
Len int
}

type ImagePayload struct {
Expand All @@ -16,7 +15,6 @@ type ImagePayload struct {
Height int // Useful to decode the current payload
Width int
Channels int
DataPtr uintptr
}

type Sample struct {
Expand Down
1 change: 0 additions & 1 deletion pkg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,6 @@ func (c *DatagoClient) asyncDispatch() {
for _, item := range page.samplesDataPointers {
select {
case <-c.context.Done():
fmt.Println("Metadata fetch goroutine wrapping up")
close(c.chanSampleMetadata)
return
case c.chanSampleMetadata <- item:
Expand Down
1 change: 0 additions & 1 deletion pkg/generator_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ func (f datagoGeneratorDB) generatePages(ctx context.Context, chanPages chan Pag
for {
select {
case <-ctx.Done():
fmt.Println("Pages fetch goroutine wrapping up")
return

default:
Expand Down
2 changes: 1 addition & 1 deletion pkg/generator_filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (c *SourceFileSystemConfig) setDefaults() {
c.Rank = 0
c.WorldSize = 1

c.RootPath = os.Getenv("DATAROOM_TEST_FILESYSTEM")
c.RootPath = os.Getenv("DATAGO_TEST_FILESYSTEM")
}

func GetSourceFileSystemConfig() SourceFileSystemConfig {
Expand Down
2 changes: 0 additions & 2 deletions pkg/serdes.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ func imageFromBuffer(buffer []byte, transform *ARAwareTransform, aspect_ratio fl
Height: height,
Width: width,
Channels: channels,
DataPtr: dataPtrFromSlice(img_bytes),
}

return &img_payload, aspect_ratio, nil
Expand Down Expand Up @@ -243,7 +242,6 @@ func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result
latents[latent.LatentType] = LatentPayload{
latent_payload.content,
len(latent_payload.content),
dataPtrFromSlice(latent_payload.content),
}
}
}
Expand Down
8 changes: 0 additions & 8 deletions pkg/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,8 @@ package datago

import (
"time"
"unsafe"
)

func dataPtrFromSlice(a []uint8) uintptr {
if len(a) == 0 {
return 0
}
return uintptr(unsafe.Pointer(&a[0]))
}

func exponentialBackoffWait(retries int) {
baseDelay := time.Second
maxDelay := 64 * time.Second
Expand Down
3 changes: 2 additions & 1 deletion python/benchmark_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from go_types import go_array_to_pil_image, go_array_to_numpy
import typer


def benchmark(
source: str = typer.Option("SOURCE", help="The source to test out"),
source: str = typer.Option("DATAGO_TEST_SOURCE", help="The source to test out"),
limit: int = typer.Option(2000, help="The number of samples to test on"),
crop_and_resize: bool = typer.Option(
True, help="Crop and resize the images on the fly"
Expand Down
6 changes: 3 additions & 3 deletions python/benchmark_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def benchmark(
root_path: str = typer.Option(
os.getenv("DATAROOM_TEST_FILESYSTEM", ""), help="The source to test out"
os.getenv("DATAGO_TEST_FILESYSTEM", ""), help="The source to test out"
),
limit: int = typer.Option(2000, help="The number of samples to test on"),
crop_and_resize: bool = typer.Option(
Expand All @@ -35,8 +35,8 @@ def benchmark(
"max_aspect_ratio": 2.0,
"pre_encode_images": False,
},
"prefetch_buffer_size": 256,
"samples_buffer_size": 128,
"prefetch_buffer_size": concurrency * 2,
"samples_buffer_size": concurrency * 2,
"concurrency": concurrency,
"limit": limit,
}
Expand Down
28 changes: 5 additions & 23 deletions python/go_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
import io
from typing import Optional
import numpy as np
from datago import go


def uint8_array_to_numpy(go_array):
if go_array.DataPtr == 0:
print("Error: null pointer")
return np.array([], dtype=np.uint8)

# By convention, arrays which are already serialized as jpg or png are not reshaped
# We export them from Go with a Channels dimension of -1 to mark them as dimensionless.
# Anything else is a valid number of channels and will thus lead to a reshape
Expand All @@ -27,25 +24,15 @@ def uint8_array_to_numpy(go_array):

# Wrap the buffer around to create a numpy array. Strangely, shape needs to be passed twice
# This is a zero-copy operation
return np.ctypeslib.as_array(
(ctypes.c_uint8 * length).from_address(go_array.DataPtr), shape=shape
).reshape(shape)
bytes_buffer = bytes(go.Slice_byte(go_array.Data))
return np.frombuffer(bytes_buffer, dtype=np.uint8).reshape(shape)


def go_array_to_numpy(go_array) -> Optional[np.ndarray]:
# Generic numpy-serialized array

if go_array.DataPtr == 0:
# No data in the buffer
return None

np_bytes = np.ctypeslib.as_array(
(ctypes.c_uint8 * go_array.Len).from_address(go_array.DataPtr),
shape=(go_array.Len,),
)
bytes_io = io.BytesIO(np_bytes.tobytes())
bytes_buffer = bytes(go.Slice_byte(go_array.Data))
try:
return np.load(bytes_io, allow_pickle=False)
return np.load(bytes_buffer, allow_pickle=False)
except ValueError:
# Do not try to handle these, return None and we'll handle it in the caller
print("Could not deserialize numpy array")
Expand All @@ -54,11 +41,6 @@ def go_array_to_numpy(go_array) -> Optional[np.ndarray]:

def go_array_to_pil_image(go_array):
# Zero copy conversion of the image buffer from Go to PIL.Image

if go_array.DataPtr == 0:
# No data in the image buffer
return None

np_array = uint8_array_to_numpy(go_array)
if go_array.Channels <= 0:
# Do not try to decode, we have a jpg or png buffer already
Expand Down
2 changes: 1 addition & 1 deletion python/test_datago_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def get_test_source() -> str:
test_source = os.getenv("DATAROOM_TEST_SOURCE", "COYO")
test_source = os.getenv("DATAROOM_TEST_SOURCE")
assert test_source is not None, "Please set DATAROOM_TEST_SOURCE"
return test_source

Expand Down

0 comments on commit d7a4bfc

Please sign in to comment.