rqlite java sdk 對于sqlite-vec 支持的bug
sqlite-vec 查詢返回的 distance 是 real 類型的,但是 rqlite java sdk 對于類型做了 check,如果類型沒有在代碼里邊定義,會直接拋出異常
解決方法
實際上 real 與包含精度的 float 類型是類似的,解決方法就比較簡單了,在類型映射里邊增加對 real 的兼容處理就可以了
參考示例代碼
具體需要修改的地方比較多,可以參考完整 pr 代碼
/**
 * Maps an rqlite column type name to the matching JDBC type code.
 *
 * <p>The name is matched case-insensitively and any length/precision suffix such as
 * {@code (255)} or {@code (10,2)} is ignored, so {@code "varchar(255)"} and
 * {@code "VARCHAR (255)"} both resolve to the same code.
 *
 * <p>NOTE(review): the returned constants (INTEGER, FLOAT, ...) are presumably
 * {@code java.sql.Types} values brought in by a static import outside this snippet.
 *
 * @param rqliteType the type name reported by rqlite, optionally with a size suffix
 * @return the JDBC type code, or {@code -1} when the type name is not recognised
 * @throws IllegalArgumentException if {@code rqliteType} is {@code null}
 */
public static int getJdbcType(String rqliteType) {
    if (rqliteType == null) {
        throw new IllegalArgumentException("type cannot be null");
    }
    // Locale.ROOT keeps the upper-casing locale-independent; with the default locale a
    // Turkish JVM would turn "integer" into "İNTEGER" and miss every case label.
    var parts = rqliteType.trim().toUpperCase(java.util.Locale.ROOT).split("[(),]");
    if (parts.length == 0) {
        // Input consisted only of delimiters such as "(" or ","; split() drops all the
        // empty tokens, so there is no type name to look up.
        return -1;
    }
    // Trim again so a space before the size suffix ("VARCHAR (255)") still matches.
    var rqType = parts[0].trim();
    switch (rqType) {
        case RQ_INTEGER: return INTEGER;
        case RQ_NUMERIC: return NUMERIC;
        case RQ_BOOLEAN: return BOOLEAN;
        case RQ_TINYINT: return TINYINT;
        case RQ_SMALLINT: return SMALLINT;
        case RQ_BIGINT: return BIGINT;
        case RQ_FLOAT: return FLOAT;
        case RQ_DOUBLE: return DOUBLE;
        case RQ_TEXT:
        case RQ_VARCHAR: return VARCHAR;
        case RQ_DATE: return DATE;
        case RQ_TIME: return TIME;
        case RQ_TIMESTAMP: return TIMESTAMP;
        case RQ_DATALINK: return DATALINK;
        case RQ_CLOB: return CLOB;
        case RQ_NCLOB: return NCLOB;
        case RQ_NVARCHAR: return NVARCHAR;
        case RQ_BLOB: return BLOB;
        case RQ_NULL: return NULL;
        // REAL is treated as an alias for FLOAT so that REAL-typed result columns
        // are accepted instead of being reported as unknown.
        case RQ_REAL: return FLOAT;
        default: return -1;
    }
}
說明
以上是類型映射修改的一個示例,下面是參考的 jdbc 操作代碼
package com.dalong;
import java.sql.*;
import java.util.Arrays;
import java.util.UUID;
/**
 * Small JDBC demo: inserts one vector row into {@code dalongrong_vec} and then runs a
 * nearest-neighbour query that returns the cosine distance for each match.
 *
 * <p>Every JDBC resource is opened in a try-with-resources block so statements and the
 * result set are closed even when a call fails.
 */
public class App {

    /** Inserts a single row; the vector is bound as its textual form, e.g. {@code [1, 2, 4]}. */
    private static final String INSERT_SQL =
            "insert into dalongrong_vec (id, vector, user_id, type, version) values (?,?,?,?,?)";

    /** Vector search returning id, user_id, vector and the computed distance, nearest first. */
    private static final String SEARCH_SQL =
            "select id, user_id,vector, vec_distance_cosine(vector,?) distance from dalongrong_vec where vector match ? and type = ? and version = ? and k = 10 order by distance";

    /**
     * Runs the insert followed by the vector search and prints the results to stdout.
     *
     * @param args unused
     * @throws RuntimeException wrapping any {@link SQLException} raised by the driver
     */
    public static void main(String[] args) {
        var url = "jdbc:sqlite:http://localhost:8080";
        var vector = new int[] {1,2,4};
        // The same textual form is used for the stored value and for both query parameters.
        var vectorText = Arrays.toString(vector);
        System.out.println("vector to " + vectorText);
        try (Connection conn = DriverManager.getConnection(url)) {
            try (PreparedStatement insert = conn.prepareStatement(INSERT_SQL)) {
                insert.setString(1, UUID.randomUUID().toString());
                insert.setString(2, vectorText);
                insert.setString(3, "user123");
                insert.setString(4, "text");
                insert.setString(5, "v1");
                int result = insert.executeUpdate();
                System.out.println("Insert result: " + result);
            }
            try (PreparedStatement search = conn.prepareStatement(SEARCH_SQL)) {
                search.setString(1, vectorText);
                search.setString(2, vectorText);
                search.setString(3, "text");
                search.setString(4, "v1");
                try (ResultSet rs = search.executeQuery()) {
                    while (rs.next()) {
                        System.out.println("ID: " + rs.getString("id")
                                + ", vector: " + rs.getString("vector")
                                + ", user_id: " + rs.getString("user_id")
                                + ", distance:" + rs.getFloat("distance"));
                    }
                }
            }
        } catch (SQLException e) {
            // Rethrow with the cause attached; the JVM prints the full chain when main exits.
            throw new RuntimeException(e);
        }
    }
}
浙公網安備 33010602011771號