Skip to content

LIT Dalle-Mini demo. #1606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions lit_nlp/examples/dalle_mini/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Dalle_Mini Demo for the Learning Interpretability Tool
=======================================================

This demo showcases how LIT can be used in text-to-image generation mode. It
is based on the min-dalle port of the DALL·E Mini model
(https://www.piwheels.org/project/dalle-mini/).

You will need a standalone virtual environment for the Python libraries, which
you can set up using the following commands from the root of the LIT repo.

```sh
# Create the virtual environment. You may want to use python3 or python3.10,
# depending on how many Python versions you have installed and their aliases.
python -m venv .dalle-mini
source .dalle-mini/bin/activate
# This requirements.txt file will also install the core LIT library deps.
pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt
# The LIT web app still needs to be built in the usual way.
(cd ./lit_nlp && yarn && yarn build)
```

Once your virtual environment is set up, you can launch the demo with the
following command.

```sh
python -m lit_nlp.examples.dalle_mini.demo
```
13 changes: 13 additions & 0 deletions lit_nlp/examples/dalle_mini/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Data loaders for dalle-mini model."""

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types


class DallePrompts(lit_dataset.Dataset):
  """Empty prompt dataset for the Dalle-Mini demo.

  Starts with no examples; prompts are expected to be entered interactively
  through the LIT Datapoint Editor and sent to the model for generation.
  """

  def __init__(self):
    # NOTE(review): lit_dataset.Dataset conventionally exposes `examples` as
    # a read-only property backed by `_examples`; assigning the public name
    # would shadow (or fail to set) that property. Set the backing field
    # instead — confirm against the LIT Dataset API.
    self._examples = []

  def spec(self) -> lit_types.Spec:
    """Input spec: a single free-text prompt field."""
    return {"prompt": lit_types.TextSegment()}
98 changes: 98 additions & 0 deletions lit_nlp/examples/dalle_mini/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
r"""Example for dalle-mini demo model.

To run locally with a small number of examples:
python -m lit_nlp.examples.dalle_mini.demo


Then navigate to localhost:5432 to access the demo UI.
"""

from collections.abc import Sequence
import os
import sys
from typing import Optional

from absl import app
from absl import flags
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import layout
from lit_nlp.examples.dalle_mini import data as dalle_data
from lit_nlp.examples.dalle_mini import model as dalle_model


# NOTE: additional flags defined in server_flags.py
_FLAGS = flags.FLAGS
# Run as a development demo with the custom DALLE layout selected by default.
_FLAGS.set_default("development_demo", True)
_FLAGS.set_default("default_layout", "DALLE_LAYOUT")


# Names (or /path/to/<name> paths) of the models to load; see main() below.
_MODELS = flags.DEFINE_list("models", ["dalle-mini"], "Models to load")


# NOTE(review): this flag is declared but never read in this module — confirm
# whether example truncation is handled elsewhere or the flag is vestigial.
_MAX_EXAMPLES = flags.DEFINE_integer(
"max_examples",
5,
"Maximum number of examples to load from each evaluation set. Set to None "
"to load the full set.",
)


# Custom frontend layout; see api/layout.py.
# Upper row: data table + editor for entering prompts; lower row: generated
# image and text prediction modules.
_modules = layout.LitModuleName
_DALLE_LAYOUT = layout.LitCanonicalLayout(
upper={
"Main": [
_modules.DataTableModule,
_modules.DatapointEditorModule,
]
},
lower={
"Predictions": [
_modules.GeneratedImageModule,
_modules.GeneratedTextModule,
],
},
description="Custom layout for Text to Image models.",
)


# Make the custom layout available alongside the built-in LIT layouts.
CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"DALLE_LAYOUT": _DALLE_LAYOUT}


def get_wsgi_app() -> Optional[dev_server.LitServerType]:
  """Return a WSGI app, for hosting under gunicorn or similar servers."""
  _FLAGS.set_default("server_type", "external")
  _FLAGS.set_default("demo_mode", True)
  # Parse only the flags we recognize, rather than calling app.run(main):
  # gunicorn passes its own command-line arguments, which would otherwise
  # trigger flag-parsing errors.
  remaining_argv = _FLAGS(sys.argv, known_only=True)
  return main(remaining_argv)


def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
  """Load models and datasets per flags, then start the LIT dev server."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # One model per entry in --models. An entry may be a bare name or a
  # /path/to/<model_name> path; in the latter case the model is registered
  # under its basename while the full path is passed to the wrapper.
  models = {
      os.path.basename(name_or_path): dalle_model.DalleMiniModel(
          model_name=name_or_path
      )
      for name_or_path in _MODELS.value
  }

  datasets = {"Prompts": dalle_data.DallePrompts()}

  lit_demo = dev_server.Server(
      models,
      datasets,
      layouts=CUSTOM_LAYOUTS,
      **server_flags.get_flags(),
  )
  return lit_demo.serve()


# Script entry point: parse absl flags and launch the LIT dev server.
if __name__ == "__main__":
  app.run(main)
103 changes: 103 additions & 0 deletions lit_nlp/examples/dalle_mini/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""LIT wrappers for MiniDalleModel."""

from collections.abc import Iterable

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import image_utils
from min_dalle import MinDalle
import numpy as np
from PIL import Image
import torch


class DalleMiniModel(lit_model.Model):
  """LIT model wrapper for the Dalle-Mini text-to-image model.

  This wrapper simplifies the pipeline of using the min-dalle package for
  text-to-image generation.

  The basic flow within this wrapper's predict() function is:

  1. min-dalle processes each text prompt.
  2. A grid of images is generated per prompt and converted to encoded image
     strings for the LIT frontend.
  """

  def __init__(
      self,
      model_name: str = "dalle-mini",
      device: str = "cuda",  # Use "cuda" for GPU or "cpu" for CPU.
      predictions: int = 1,
  ):
    """Initializes the wrapper and loads the min-dalle model.

    Args:
      model_name: Display name for this model instance.
      device: Torch device to run inference on ("cuda" or "cpu").
      predictions: Requested number of predictions per prompt.
        NOTE(review): not currently used by predict(); the number of images
        is fixed by grid_size below — confirm intended behavior.
    """
    super().__init__()
    self.model_name = model_name
    self.device = device
    self.n_predictions = predictions

    # Load the min-dalle model. Bug fix: honor the caller-requested device
    # instead of the previously hardcoded "cuda", so CPU-only hosts work.
    self.model = MinDalle(
        models_root="./pretrained",
        dtype=torch.float32,
        device=device,
        is_mega=True,
        is_reusable=True,
    )

  def max_minibatch_size(self) -> int:
    """Maximum number of examples per predict() minibatch."""
    return 8

  def predict(
      self, inputs: Iterable[lit_types.JsonDict], **unused_kw
  ) -> Iterable[lit_types.JsonDict]:
    """Generates images for each input prompt.

    Args:
      inputs: Examples matching input_spec(), each with a "prompt" field.

    Returns:
      One dict per input with "image" (list of encoded image strings) and
      the echoed "prompt", matching output_spec().
    """

    def tensor_to_pil_image(tensor: torch.Tensor) -> Image.Image:
      """Converts an image tensor to an 8-bit RGB PIL image."""
      img_np = tensor.detach().cpu().numpy()
      img_np = np.squeeze(img_np)
      if img_np.ndim == 2:
        # Grayscale: replicate the single channel across RGB.
        img_np = np.stack([img_np] * 3, axis=-1)
      elif img_np.ndim != 3 or img_np.shape[2] != 3:
        raise ValueError(
            f"Unexpected image shape: {img_np.shape}. Expected (H, W, 3)."
        )

      # Min-max normalize to [0, 255]. Bug fix: guard against a constant
      # image, which would otherwise divide by zero.
      lo, hi = img_np.min(), img_np.max()
      if hi > lo:
        img_np = (img_np - lo) / (hi - lo) * 255
      else:
        img_np = np.zeros_like(img_np)
      img_np = img_np.clip(0, 255).astype(np.uint8)
      return Image.fromarray(img_np)

    prompts = [ex["prompt"] for ex in inputs]
    images = []
    for prompt in prompts:
      # Generate a grid of images for this prompt.
      # NOTE(review): grid_size=4 and the sampling parameters are fixed here
      # rather than derived from self.n_predictions — confirm intended.
      generated_images = self.model.generate_images(
          text=prompt,
          seed=-1,
          grid_size=4,
          is_seamless=False,
          temperature=0.5,
          top_k=256,
          supercondition_factor=32,
          is_verbose=False,
      )
      pil_images = [
          tensor_to_pil_image(img_tensor) for img_tensor in generated_images
      ]
      images.append({
          "image": [
              image_utils.convert_pil_to_image_str(img) for img in pil_images
          ],
          "prompt": prompt,
      })

    return images

  def input_spec(self):
    """Input spec: a single free-text prompt field."""
    return {"prompt": lit_types.TextSegment()}

  def output_spec(self):
    """Output spec: generated image list plus the echoed prompt."""
    return {
        "image": lit_types.ImageBytesList(),
        "prompt": lit_types.TextSegment(),
    }
19 changes: 19 additions & 0 deletions lit_nlp/examples/dalle_mini/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

-r ../../../requirements.txt

# Dalle-Mini dependencies
min_dalle==0.1.5