clip : refactor, add image_manipulation and llava_uhd classes (#13011)

* clip : refactor, add `image_manipulation` and `llava_uhd`

* refactor llava-1.6 preprocessing

* simplify logic for llava-1.5

* missing include
This commit is contained in:
Xuan-Son Nguyen 2025-04-19 09:15:45 +02:00 committed by GitHub
parent 6408210082
commit 37b9f0d29d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -27,6 +27,7 @@
#include <sstream> #include <sstream>
#include <cinttypes> #include <cinttypes>
#include <limits> #include <limits>
#include <array>
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
@ -1680,12 +1681,24 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length
return true; return true;
} }
// Linear interpolation between two points // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
inline float clip_lerp(float s, float e, float t) { static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
return s + (e - s) * t; dst.nx = src.nx;
dst.ny = src.ny;
dst.buf.resize(src.buf.size());
// TODO @ngxson : seems like this could be done more efficiently on cgraph
for (size_t i = 0; i < src.buf.size(); ++i) {
int c = i % 3; // rgb
dst.buf[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c]) / std[c];
}
} }
// Bilinear resize function
static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) { // set of tools to manupulate images
// in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv
struct image_manipulation {
// Bilinear resize function
static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
dst.nx = target_width; dst.nx = target_width;
dst.ny = target_height; dst.ny = target_height;
dst.buf.resize(3 * target_width * target_height); dst.buf.resize(3 * target_width * target_height);
@ -1703,40 +1716,25 @@ static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int ta
float y_lerp = py - y_floor; float y_lerp = py - y_floor;
for (int c = 0; c < 3; c++) { for (int c = 0; c < 3; c++) {
float top = clip_lerp( float top = lerp(
static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]), static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]), static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
x_lerp x_lerp
); );
float bottom = clip_lerp( float bottom = lerp(
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]), static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]), static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
x_lerp x_lerp
); );
dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(clip_lerp(top, bottom, y_lerp)); dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, y_lerp));
} }
} }
} }
}
// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
dst.nx = src.nx;
dst.ny = src.ny;
dst.buf.resize(src.buf.size());
// TODO @ngxson : seems like this could be done more efficiently on cgraph
for (size_t i = 0; i < src.buf.size(); ++i) {
int c = i % 3; // rgb
dst.buf[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c]) / std[c];
} }
}
inline int clip(int x, int lower, int upper) { // Bicubic resize function
return std::max(lower, std::min(x, upper)); // part of image will be cropped if the aspect ratio is different
} static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
const int nx = img.nx; const int nx = img.nx;
const int ny = img.ny; const int ny = img.ny;
@ -1797,12 +1795,14 @@ static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int t
} }
return true; return true;
} }
// llava-1.6 type of resize_and_pad (black) // llava-1.6 type of resize_and_pad
static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair<int, int>& target_resolution) { // if the ratio is not 1:1, padding with pad_color will be applied
int target_width = target_resolution.first; // pad_color is single channel, default is 0 (black)
int target_height = target_resolution.second; static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, std::array<uint8_t, 3> pad_color = {0, 0, 0}) {
int target_width = target_resolution.width;
int target_height = target_resolution.height;
float scale_w = static_cast<float>(target_width) / image.nx; float scale_w = static_cast<float>(target_width) / image.nx;
float scale_h = static_cast<float>(target_height) / image.ny; float scale_h = static_cast<float>(target_height) / image.ny;
@ -1818,13 +1818,19 @@ static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &imag
} }
clip_image_u8 resized_image; clip_image_u8 resized_image;
// bilinear_resize(image, resized_image, new_width, new_height);
bicubic_resize(image, resized_image, new_width, new_height); bicubic_resize(image, resized_image, new_width, new_height);
clip_image_u8 padded_image; clip_image_u8 padded_image;
padded_image.nx = target_width; padded_image.nx = target_width;
padded_image.ny = target_height; padded_image.ny = target_height;
padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black padded_image.buf.resize(3 * target_width * target_height);
// Fill the padded image with the fill color
for (size_t i = 0; i < padded_image.buf.size(); i += 3) {
padded_image.buf[i] = pad_color[0];
padded_image.buf[i + 1] = pad_color[1];
padded_image.buf[i + 2] = pad_color[2];
}
// Calculate padding offsets // Calculate padding offsets
int pad_x = (target_width - new_width) / 2; int pad_x = (target_width - new_width) / 2;
@ -1838,26 +1844,223 @@ static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &imag
} }
} }
} }
image_output = std::move(padded_image); dst = std::move(padded_image);
} }
static void crop_image(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
dst.nx = w;
dst.ny = h;
dst.buf.resize(3 * w * h);
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int src_idx = 3 * ((y + i)*image.nx + (x + j));
int dst_idx = 3 * (i*w + j);
dst.buf[dst_idx] = image.buf[src_idx];
dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
}
}
}
private:
static inline int clip(int x, int lower, int upper) {
return std::max(lower, std::min(x, upper));
}
// Linear interpolation between two points
static inline float lerp(float s, float e, float t) {
return s + (e - s) * t;
}
};
/** /**
* implementation of LLaVA-UHD:
* - https://arxiv.org/pdf/2403.11703
* - https://github.com/thunlp/LLaVA-UHD
* - https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
*
* overview:
* - an image always have a single overview (downscaled image)
* - an image can have 0 or multiple slices, depending on the image size
* - each slice can then be considered as a separate image
*
* for example:
*
* [overview] --> [slice 1] --> [slice 2]
* | |
* +--> [slice 3] --> [slice 4]
*/
struct llava_uhd {
struct slice_coordinates {
int x;
int y;
clip_image_size size;
};
struct slice_instructions {
clip_image_size overview_size; // size of downscaled image
clip_image_size refined_size; // size of image right before slicing (must be multiple of slice size)
clip_image_size grid_size; // grid_size.width * grid_size.height = number of slices
std::vector<slice_coordinates> slices;
bool padding_refined = false; // if true, refine image will be padded to the grid size (e.g. llava-1.6)
};
static int get_max_slices(struct clip_ctx * ctx) {
if (clip_is_minicpmv(ctx)) {
return 9;
}
return 0;
}
static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) {
slice_instructions res;
const int patch_size = clip_get_patch_size(ctx);
const int slice_size = clip_get_image_size(ctx);
const int max_slice_nums = get_max_slices(ctx);
const int original_width = original_size.width;
const int original_height = original_size.height;
const float log_ratio = log((float)original_width / original_height);
const float ratio = (float)original_width * original_height / (slice_size * slice_size);
const int multiple = fmin(ceil(ratio), max_slice_nums);
const bool has_slices = (multiple > 1);
const bool has_pinpoints = !ctx->vision_model.hparams.image_grid_pinpoints.empty();
if (has_pinpoints) {
// has pinpoints, use them to calculate the grid size (e.g. llava-1.6)
auto refine_size = llava_uhd::select_best_resolution(
ctx->vision_model.hparams.image_grid_pinpoints,
original_size);
res.overview_size = clip_image_size{slice_size, slice_size};
res.refined_size = refine_size;
res.grid_size = clip_image_size{0, 0};
res.padding_refined = true;
for (int y = 0; y < refine_size.height; y += slice_size) {
for (int x = 0; x < refine_size.width; x += slice_size) {
slice_coordinates slice;
slice.x = x;
slice.y = y;
slice.size.width = std::min(slice_size, refine_size.width - x);
slice.size.height = std::min(slice_size, refine_size.height - y);
res.slices.push_back(slice);
if (x == 0) {
res.grid_size.width++;
}
}
res.grid_size.height++;
}
return res;
}
// no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
auto best_size = get_best_resize(original_size, slice_size, patch_size, has_slices);
res.overview_size = best_size;
if (!has_slices) {
// skip slicing logic
res.refined_size = clip_image_size{0, 0};
res.grid_size = clip_image_size{0, 0};
} else {
auto best_grid = get_best_grid(max_slice_nums, multiple, log_ratio);
auto refine_size = get_refine_size(original_size, best_grid, slice_size, patch_size, true);
res.grid_size = best_grid;
res.refined_size = refine_size;
int width = refine_size.width;
int height = refine_size.height;
int grid_x = int(width / best_grid.width);
int grid_y = int(height / best_grid.height);
for (int patches_y = 0, ic = 0;
patches_y < refine_size.height && ic < best_grid.height;
patches_y += grid_y, ic += 1) {
for (int patches_x = 0, jc = 0;
patches_x < refine_size.width && jc < best_grid.width;
patches_x += grid_x, jc += 1) {
slice_coordinates slice;
slice.x = patches_x;
slice.y = patches_y;
slice.size.width = grid_x;
slice.size.height = grid_y;
res.slices.push_back(slice);
// LOG_INF("slice %d: %d %d %d %d\n", ic, patches_i, patches_j, grid_x, grid_y);
}
}
}
return res;
}
static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst) {
std::vector<clip_image_u8_ptr> output;
// resize to overview size
clip_image_u8_ptr resized_img(clip_image_u8_init());
image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height);
output.push_back(std::move(resized_img));
if (inst.slices.empty()) {
// no slices, just return the resized image
return output;
}
// resize to refined size
clip_image_u8_ptr refined_img(clip_image_u8_init());
if (inst.padding_refined) {
image_manipulation::resize_and_pad_image(*img, *refined_img, inst.refined_size);
} else {
image_manipulation::bilinear_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height);
}
// create slices
for (const auto & slice : inst.slices) {
int x = slice.x;
int y = slice.y;
int w = slice.size.width;
int h = slice.size.height;
clip_image_u8_ptr img_slice(clip_image_u8_init());
image_manipulation::crop_image(*refined_img, *img_slice, x, y, w, h);
output.push_back(std::move(img_slice));
}
return output;
}
private:
static clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width = original_size.width;
int height = original_size.height;
if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
float r = static_cast<float>(width) / height;
height = static_cast<int>(scale_resolution / std::sqrt(r));
width = static_cast<int>(height * r);
}
clip_image_size res;
res.width = ensure_divide(width, patch_size);
res.height = ensure_divide(height, patch_size);
return res;
}
/**
* Selects the best resolution from a list of possible resolutions based on the original size. * Selects the best resolution from a list of possible resolutions based on the original size.
* *
* @param original_size The original size of the image in the format (width, height). * @param original_size The original size of the image
* @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. * @param possible_resolutions A list of possible resolutions
* @return The best fit resolution in the format (width, height). * @return The best fit resolution
*/ */
static std::pair<int, int> select_best_resolution(const std::pair<int, int> & original_size, const std::vector<std::pair<int, int>> & possible_resolutions) { static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector<clip_image_size> & possible_resolutions) {
int original_width = original_size.first; int original_width = original_size.width;
int original_height = original_size.second; int original_height = original_size.height;
std::pair<int, int> best_fit; clip_image_size best_fit;
int max_effective_resolution = 0; int max_effective_resolution = 0;
int min_wasted_resolution = std::numeric_limits<int>::max(); int min_wasted_resolution = std::numeric_limits<int>::max();
for (const auto& resolution : possible_resolutions) { for (const auto & resolution : possible_resolutions) {
int width = resolution.first; int width = resolution.width;
int height = resolution.second; int height = resolution.height;
float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height); float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
int downscaled_width = static_cast<int>(original_width * scale); int downscaled_width = static_cast<int>(original_width * scale);
int downscaled_height = static_cast<int>(original_height * scale); int downscaled_height = static_cast<int>(original_height * scale);
@ -1872,71 +2075,45 @@ static std::pair<int, int> select_best_resolution(const std::pair<int, int> & or
} }
return best_fit; return best_fit;
} }
static std::vector<clip_image_u8_ptr> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { // used by llava 1.6 with custom list of pinpoints
std::vector<clip_image_u8_ptr> patches; static clip_image_size select_best_resolution(const std::vector<int32_t> & pinpoints, const clip_image_size & original_size) {
int width = image.nx; std::vector<clip_image_size> possible_resolutions;
int height = image.ny; for (size_t i = 0; i < pinpoints.size(); i += 2) {
for (int i = 0; i < height; i += patch_size) { possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
for (int j = 0; j < width; j += patch_size) {
clip_image_u8_ptr patch(clip_image_u8_init());
patch->nx = std::min(patch_size, width - j);
patch->ny = std::min(patch_size, height - i);
patch->buf.resize(3 * patch->nx * patch->ny);
for (int y = 0; y < patch->ny; ++y) {
for (int x = 0; x < patch->nx; ++x) {
for (int c = 0; c < 3; ++c) {
patch->buf[3 * (y * patch->nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c];
} }
return select_best_resolution(original_size, possible_resolutions);
} }
}
patches.push_back(std::move(patch));
}
}
return patches;
}
static int ensure_divide(int length, int patch_size) { static int ensure_divide(int length, int patch_size) {
return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size); return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
}
static std::pair<int, int> uhd_find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width = original_size.first;
int height = original_size.second;
if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
float r = static_cast<float>(width) / height;
height = static_cast<int>(scale_resolution / std::sqrt(r));
width = static_cast<int>(height * r);
} }
int best_width = ensure_divide(width, patch_size);
int best_height = ensure_divide(height, patch_size);
return std::make_pair(best_width, best_height);
}
static std::pair<int, int> uhd_get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) { static clip_image_size get_refine_size(const clip_image_size & original_size, const clip_image_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width, height; int width = original_size.width;
std::tie(width, height) = original_size; int height = original_size.height;
int grid_x, grid_y; int grid_x = grid.width;
std::tie(grid_x, grid_y) = grid; int grid_y = grid.height;
int refine_width = ensure_divide(width, grid_x); int refine_width = ensure_divide(width, grid_x);
int refine_height = ensure_divide(height, grid_y); int refine_height = ensure_divide(height, grid_y);
int grid_width = refine_width / grid_x; clip_image_size grid_size;
int grid_height = refine_height / grid_y; grid_size.width = refine_width / grid_x;
grid_size.height = refine_height / grid_y;
// auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line) auto best_grid_size = get_best_resize(grid_size, scale_resolution, patch_size, allow_upscale);
auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair int best_grid_width = best_grid_size.width;
int best_grid_width, best_grid_height; int best_grid_height = best_grid_size.height;
std::tie(best_grid_width, best_grid_height) = best_grid_size;
// std::pair<int, int> refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line) clip_image_size refine_size;
std::pair<int, int> refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line) refine_size.width = best_grid_width * grid_x;
refine_size.height = best_grid_height * grid_y;
return refine_size; return refine_size;
} }
static std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { static clip_image_size get_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
std::vector<int> candidate_split_grids_nums; std::vector<int> candidate_split_grids_nums;
for (int i : {multiple - 1, multiple, multiple + 1}) { for (int i : {multiple - 1, multiple, multiple + 1}) {
if (i == 1 || i > max_slice_nums) { if (i == 1 || i > max_slice_nums) {
@ -1945,124 +2122,62 @@ static std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int mul
candidate_split_grids_nums.push_back(i); candidate_split_grids_nums.push_back(i);
} }
std::vector<std::pair<int, int>> candidate_grids; std::vector<clip_image_size> candidate_grids;
for (int split_grids_nums : candidate_split_grids_nums) { for (int split_grids_nums : candidate_split_grids_nums) {
int m = 1; int m = 1;
while (m <= split_grids_nums) { while (m <= split_grids_nums) {
if (split_grids_nums % m == 0) { if (split_grids_nums % m == 0) {
candidate_grids.emplace_back(m, split_grids_nums / m); candidate_grids.push_back(clip_image_size{m, split_grids_nums / m});
} }
++m; ++m;
} }
} }
std::pair<int, int> best_grid{1, 1}; clip_image_size best_grid{1, 1};
float min_error = std::numeric_limits<float>::infinity(); float min_error = std::numeric_limits<float>::infinity();
for (const auto& grid : candidate_grids) { for (const auto& grid : candidate_grids) {
float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second)); float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height));
if (error < min_error) { if (error < min_error) {
best_grid = grid; best_grid = grid;
min_error = error; min_error = error;
} }
} }
return best_grid; return best_grid;
}
// inspired from LLaVA-UHD:
// -> https://arxiv.org/pdf/2403.11703
// -> https://github.com/thunlp/LLaVA-UHD
// -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
static std::vector<std::vector<clip_image_u8_ptr>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) {
const std::pair<int, int> original_size={img->nx,img->ny};
const int original_width = img->nx;
const int original_height = img->ny;
const float log_ratio = log(1.0*original_width/original_height);
const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
const int multiple = fmin(ceil(ratio), max_slice_nums);
std::vector<std::vector<clip_image_u8_ptr>> images;
LOG_DBG("%s: multiple %d\n", __func__, multiple);
images.push_back(std::vector<clip_image_u8_ptr>());
if (multiple <= 1) {
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
clip_image_u8_ptr source_image(clip_image_u8_init());
bicubic_resize(*img, *source_image, best_size.first, best_size.second);
// source_image = image.resize(best_size, Image.Resampling.BICUBIC)
images.back().push_back(std::move(source_image));
} }
else if (multiple > 1) { };
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
clip_image_u8_ptr source_image(clip_image_u8_init());
bicubic_resize(*img, *source_image, best_size.first, best_size.second);
// source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second);
images.back().push_back(std::move(source_image));
std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second);
auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
clip_image_u8_ptr refine_image(clip_image_u8_init());
bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second);
LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second);
// split_to_patches
int width = refine_image->nx;
int height = refine_image->ny;
int grid_x = int(width / best_grid.first);
int grid_y = int(height / best_grid.second);
for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
images.push_back(std::vector<clip_image_u8_ptr>());
for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
clip_image_u8_ptr patch(clip_image_u8_init());
patch->nx = grid_x;
patch->ny = grid_y;
patch->buf.resize(3 * patch->nx * patch->ny);
for (int y = patches_i; y < patches_i + grid_y; ++y) {
for (int x = patches_j; x < patches_j + grid_x; ++x) {
const int i = 3 * (y * refine_image->nx + x);
const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j));
patch->buf[j] = refine_image->buf[i];
patch->buf[j+1] = refine_image->buf[i+1];
patch->buf[j+2] = refine_image->buf[i+2];
}
}
images.back().push_back(std::move(patch));
}
}
}
return images;
}
// TODO @ngxson : decprecate the load_image_size singleton pattern
int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
const int max_slice_nums=9; const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size);
const int scale_resolution=448; return inst.grid_size.width;
const int original_width = ctx_clip->load_image_size.width;
const int original_height = ctx_clip->load_image_size.height;
const float log_ratio = log(1.0*original_width/original_height);
const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
const int multiple = fmin(ceil(ratio), max_slice_nums);
std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
return best_grid.first;
} }
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found // res_imgs memory is being allocated here, previous allocations will be freed if found
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
if (!ctx->has_vision_encoder) {
LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
return false;
}
clip_image_size original_size{img->nx, img->ny};
bool pad_to_square = true;
auto & params = ctx->vision_model.hparams;
// The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
pad_to_square = false;
}
if (clip_is_minicpmv(ctx)) { if (clip_is_minicpmv(ctx)) {
int max_slice_nums = 9; auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
std::vector<std::vector<clip_image_u8_ptr>> imgs = uhd_slice_image(img, max_slice_nums); std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
for (size_t i = 0; i < imgs.size(); ++i) { for (size_t i = 0; i < imgs.size(); ++i) {
for (size_t j = 0; j < imgs[i].size(); ++j) { // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny);
clip_image_f32_ptr res(clip_image_f32_init()); clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std); normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(res)); res_imgs->entries.push_back(std::move(res));
} }
}
return true; return true;
} }
else if (ctx->has_qwen2vl_merger) { else if (ctx->has_qwen2vl_merger) {
@ -2070,7 +2185,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
auto patch_size = clip_get_patch_size(ctx) * 2; auto patch_size = clip_get_patch_size(ctx) * 2;
int nx = ceil((float)img->nx / patch_size) * patch_size; int nx = ceil((float)img->nx / patch_size) * patch_size;
int ny = ceil((float)img->ny / patch_size) * patch_size; int ny = ceil((float)img->ny / patch_size) * patch_size;
bicubic_resize(*img, resized, nx, ny); image_manipulation::bicubic_resize(*img, resized, nx, ny);
clip_image_f32_ptr img_f32(clip_image_f32_init()); clip_image_f32_ptr img_f32(clip_image_f32_init());
// clip_image_f32_ptr res(clip_image_f32_init()); // clip_image_f32_ptr res(clip_image_f32_init());
@ -2082,8 +2197,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
clip_image_u8 resized_image; clip_image_u8 resized_image;
int32_t sz=ctx->vision_model.hparams.image_size; int sz = params.image_size;
bicubic_resize(*img, resized_image,sz,sz); image_manipulation::bicubic_resize(*img, resized_image, sz, sz);
clip_image_f32_ptr img_f32(clip_image_f32_init()); clip_image_f32_ptr img_f32(clip_image_f32_init());
//clip_image_save_to_bmp(resized_image, "resized.bmp"); //clip_image_save_to_bmp(resized_image, "resized.bmp");
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
@ -2091,156 +2206,47 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
return true; return true;
} }
bool pad_to_square = true;
if (!ctx->has_vision_encoder) {
LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
return false;
}
auto & params = ctx->vision_model.hparams;
// The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
pad_to_square = false;
}
// free the previous res_imgs if any set
res_imgs->entries.clear();
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
if (pad_to_square && img->nx != img->ny) {
int longer_side = std::max(img->nx, img->ny); if (pad_to_square) {
// for llava-1.5, we resize image to a square, and pad the shorter side with a background color
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
const int longer_side = std::max(img->nx, img->ny);
temp->nx = longer_side; temp->nx = longer_side;
temp->ny = longer_side; temp->ny = longer_side;
temp->buf.resize(3 * longer_side * longer_side); temp->buf.resize(3 * longer_side * longer_side);
const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
// fill with background color // background color in RGB from LLaVA (this is the mean rgb color * 255)
for (size_t i = 0; i < temp->buf.size(); i++) { const std::array<uint8_t, 3> pad_color = {122, 116, 104};
temp->buf[i] = bc[i % 3];
}
// copy from the input image // resize the image to the target_size
for (int y = 0; y < img->ny; y++) { image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color);
for (int x = 0; x < img->nx; x++) {
const int i = 3 * (y * img->nx + x); clip_image_f32_ptr res(clip_image_f32_init());
const int j = 3 * (y * temp->nx + x); normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
temp->buf[j] = img->buf[i]; res_imgs->entries.push_back(std::move(res));
temp->buf[j+1] = img->buf[i+1]; return true;
temp->buf[j+2] = img->buf[i+2];
} } else if (!params.image_grid_pinpoints.empty()) {
}
} else {
if (!params.image_grid_pinpoints.empty()) {
// "spatial_unpad" with "anyres" processing for llava-1.6 // "spatial_unpad" with "anyres" processing for llava-1.6
std::vector<std::pair<int, int>> possible_resolutions; auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) { std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
}
std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
// clip_image_save_to_bmp(*img, "input.bmp");
resize_and_pad_image(*img, *temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6
// clip_image_save_to_bmp(*temp, "resized.bmp");
// visually verify normalized image:
// normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
// {
// clip_image_u8 * temp2 = clip_image_u8_init();
// clip_image_convert_f32_to_u8(*res, *temp2);
// clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp");
// clip_image_u8_free(temp2);
// }
std::vector<clip_image_u8_ptr> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) for (size_t i = 0; i < imgs.size(); ++i) {
// clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
clip_image_u8_ptr image_original_resize(clip_image_u8_init());
// bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
patches.insert(patches.begin(), std::move(image_original_resize));
for (auto & patch : patches) {
clip_image_f32_ptr res(clip_image_f32_init()); clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std); normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(res)); res_imgs->entries.push_back(std::move(res));
} }
return true; return true;
} else {
temp->nx = img->nx;
temp->ny = img->ny;
temp->buf.resize(img->buf.size());
memcpy(temp->buf.data(), img->buf.data(), temp->buf.size());
}
} }
const int nx = temp->nx; GGML_ASSERT(false && "Unknown image preprocessing type");
const int ny = temp->ny;
// clip_image_save_to_bmp(*temp, "resized_vanilla.bmp");
const int nx2 = ctx->vision_model.hparams.image_size;
const int ny2 = ctx->vision_model.hparams.image_size;
clip_image_f32_ptr res(clip_image_f32_init());
res->nx = nx2;
res->ny = ny2;
res->buf.resize(3 * nx2 * ny2);
const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size;
const int nx3 = int(nx / scale + 0.5f);
const int ny3 = int(ny / scale + 0.5f);
const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
const auto & s3 = ctx->image_std; // {0.26862954f, 0.26130258f, 0.27577711f};
for (int y = 0; y < ny3; y++) {
for (int x = 0; x < nx3; x++) {
for (int c = 0; c < 3; c++) {
// linear interpolation
const float sx = (x + 0.5f) * scale - 0.5f;
const float sy = (y + 0.5f) * scale - 0.5f;
const int x0 = std::max(0, (int)std::floor(sx));
const int y0 = std::max(0, (int)std::floor(sy));
const int x1 = std::min(x0 + 1, nx - 1);
const int y1 = std::min(y0 + 1, ny - 1);
const float dx = sx - x0;
const float dy = sy - y0;
const int j00 = 3 * (y0 * nx + x0) + c;
const int j01 = 3 * (y0 * nx + x1) + c;
const int j10 = 3 * (y1 * nx + x0) + c;
const int j11 = 3 * (y1 * nx + x1) + c;
const float v00 = temp->buf[j00];
const float v01 = temp->buf[j01];
const float v10 = temp->buf[j10];
const float v11 = temp->buf[j11];
const float v0 = v00 * (1.0f - dx) + v01 * dx;
const float v1 = v10 * (1.0f - dx) + v11 * dx;
const float v = v0 * (1.0f - dy) + v1 * dy;
const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
const int i = 3 * (y * nx3 + x) + c;
res->buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
}
}
}
// {
// clip_image_u8 * temp2 = clip_image_u8_init();
// clip_image_convert_f32_to_u8(*res, *temp2);
// clip_image_save_to_bmp(*temp2, "resized_normalized_f32_vanilla.bmp");
// clip_image_u8_free(temp2);
// }
// res_imgs.push_back(res);
res_imgs->entries.push_back(std::move(res));
return true;
} }
ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) { ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {