Skip to content

Commit 9257176

Browse files
[add] initial support of AI.TENSORGET command (#5)
* [add] initial support of AI.TENSORGET command
1 parent 8e6bb01 commit 9257176

File tree

6 files changed

+167
-19
lines changed

6 files changed

+167
-19
lines changed

src/main/java/com/redislabs/redisai/DataType.java

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@ public List<byte[]> toByteArray(Object obj){
2424

2525
@Override
2626
protected Object toObject(List<byte[]> data) {
27-
// TODO Auto-generated method stub
28-
return null;
27+
int [] values = new int[data.size()];
28+
for (int i = 0; i < data.size(); i++) {
29+
values[i] = Integer.parseInt(SafeEncoder.encode(data.get(i)));
30+
}
31+
return values;
2932
}
3033
},
3134
INT64 {
@@ -41,8 +44,11 @@ public List<byte[]> toByteArray(Object obj){
4144

4245
@Override
4346
protected Object toObject(List<byte[]> data) {
44-
// TODO Auto-generated method stub
45-
return null;
47+
long [] values = new long[data.size()];
48+
for (int i = 0; i < data.size(); i++) {
49+
values[i] = Long.parseLong(SafeEncoder.encode(data.get(i)));
50+
}
51+
return values;
4652
}
4753
},
4854
FLOAT {
@@ -58,14 +64,11 @@ public List<byte[]> toByteArray(Object obj){
5864

5965
@Override
6066
protected Object toObject(List<byte[]> data) {
61-
// float[] values = (float[])obj;
62-
// List<byte[]> res = new ArrayList<>(values.length);
63-
// for(byte[] value : data) {
64-
// res.add(Protocol.to(value));
65-
// }
66-
// return res;
67-
// TODO Auto-generated method stub
68-
return null;
67+
float [] values = new float[data.size()];
68+
for (int i = 0; i < data.size(); i++) {
69+
values[i] = Float.parseFloat(SafeEncoder.encode(data.get(i)));
70+
}
71+
return values;
6972
}
7073
},
7174
DOUBLE {
@@ -81,8 +84,11 @@ public List<byte[]> toByteArray(Object obj){
8184

8285
@Override
8386
protected Object toObject(List<byte[]> data) {
84-
// TODO Auto-generated method stub
85-
return null;
87+
double [] values = new double[data.size()];
88+
for (int i = 0; i < data.size(); i++) {
89+
values[i] = Double.parseDouble(SafeEncoder.encode(data.get(i)));
90+
}
91+
return values;
8692
}
8793
},
8894
STRING {
@@ -98,8 +104,7 @@ public List<byte[]> toByteArray(Object obj){
98104

99105
@Override
100106
protected Object toObject(List<byte[]> data) {
101-
// TODO Auto-generated method stub
102-
return null;
107+
return data;
103108
}
104109
},
105110
BOOL {
@@ -142,6 +147,23 @@ protected Object toObject(List<byte[]> data) {
142147
raw = SafeEncoder.encode(this.name());
143148
}
144149

150+
static DataType getDataTypefromString(String dtypeRaw) {
151+
DataType dt = null;
152+
if (dtypeRaw.equals(DataType.INT32.name())){
153+
dt=DataType.INT32;
154+
}
155+
if (dtypeRaw.equals(DataType.INT64.name())){
156+
dt=DataType.INT64;
157+
}
158+
if (dtypeRaw.equals(DataType.FLOAT.name())){
159+
dt=DataType.FLOAT;
160+
}
161+
if (dtypeRaw.equals(DataType.DOUBLE.name())){
162+
dt=DataType.DOUBLE;
163+
}
164+
return dt;
165+
}
166+
145167
protected abstract List<byte[]> toByteArray(Object obj);
146168
protected abstract Object toObject(List<byte[]> data);
147169

src/main/java/com/redislabs/redisai/Keyword.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
public enum Keyword implements ProtocolCommand{
77

8-
TENSOR, VALUES, INPUTS, OUTPUTS, BLOB, SOURCE;
8+
TENSOR, INPUTS, OUTPUTS, META, VALUES, BLOB, SOURCE;
99

1010
private final byte[] raw;
1111

src/main/java/com/redislabs/redisai/RedisAI.java

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import java.nio.file.Files;
66
import java.nio.file.Paths;
77
import java.util.ArrayList;
8+
import java.util.List;
89
import java.util.stream.Collectors;
910

11+
import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
1012
import redis.clients.jedis.BinaryClient;
1113
import redis.clients.jedis.Jedis;
1214
import redis.clients.jedis.JedisPool;
@@ -94,7 +96,60 @@ public boolean setTensor(String key, Object tensor, int[] dimensions){
9496
throw new RedisAIException(ex);
9597
}
9698
}
97-
99+
100+
/**
101+
* TS.GET key
102+
*
103+
* @param key
104+
* @return Tensor
105+
*/
106+
public Tensor getTensor(String key) throws JRedisAIRunTimeException {
107+
try (Jedis conn = getConnection()) {
108+
List<?> reply = sendCommand(conn, Command.TENSOR_GET, SafeEncoder.encode(key), Keyword.META.getRaw(), Keyword.VALUES.getRaw() ).getObjectMultiBulkReply();
109+
if(reply.isEmpty()) {
110+
return null;
111+
}
112+
DataType dtype = null;
113+
long[] shape = null;
114+
Object values = null;
115+
Tensor tensor = null;
116+
for (int i = 0; i < reply.size(); i+=2) {
117+
String arrayKey = SafeEncoder.encode((byte[]) reply.get(i));
118+
switch(arrayKey)
119+
{
120+
case "dtype":
121+
String dtypeString = SafeEncoder.encode((byte[]) reply.get(i+1));
122+
dtype = DataType.getDataTypefromString(dtypeString);
123+
if (dtype==null){
124+
throw new JRedisAIRunTimeException("Unrecognized datatype: "+dtypeString);
125+
}
126+
break;
127+
case "shape":
128+
List<Long> shapeResp = (List<Long>)reply.get(i+1);
129+
shape = new long[shapeResp.size()];
130+
for (int j = 0; j < shapeResp.size(); j++) {
131+
shape[j] = shapeResp.get(j);
132+
}
133+
break;
134+
case "values":
135+
if (dtype==null){
136+
throw new JRedisAIRunTimeException("Trying to decode values array without previous datatype info");
137+
}
138+
List<byte[]> valuesEncoded = (List<byte[]>) reply.get(i+1);
139+
values = dtype.toObject(valuesEncoded);
140+
break;
141+
default:
142+
break;
143+
}
144+
}
145+
if (dtype!=null && shape!=null && values!=null){
146+
tensor = new Tensor(dtype,shape,values);
147+
}
148+
return tensor;
149+
}
150+
151+
}
152+
98153
/**
99154
* AI.MODELSET model_key backend device [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] model_blob
100155
*/
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package com.redislabs.redisai;
2+
3+
public class Tensor {
4+
private DataType dataType;
5+
private long[] shape;
6+
private Object values;
7+
8+
public Tensor(DataType dataType, long[] shape, Object values ) {
9+
this.shape = shape;
10+
this.values = values;
11+
this.dataType = dataType;
12+
}
13+
14+
public Object getValues() {
15+
return values;
16+
}
17+
18+
public void setValues(Object values) {
19+
this.values = values;
20+
}
21+
22+
public long[] getShape() {
23+
return shape;
24+
}
25+
26+
public DataType getDataType() {
27+
return dataType;
28+
}
29+
30+
public void setDataType(DataType dataType) {
31+
this.dataType = dataType;
32+
}
33+
34+
35+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package com.redislabs.redisai.exceptions;
2+
3+
import redis.clients.jedis.exceptions.JedisDataException;
4+
5+
/**
6+
* An instance of JRedisAIRunTimeException is thrown when RedisAI
7+
* encounters a runtime error during command execution.
8+
*/
9+
public class JRedisAIRunTimeException extends JedisDataException {
10+
public JRedisAIRunTimeException(String message) {
11+
super(message);
12+
}
13+
14+
public JRedisAIRunTimeException(Throwable cause) {
15+
super(cause);
16+
}
17+
18+
public JRedisAIRunTimeException(String message, Throwable cause) {
19+
super(message, cause);
20+
}
21+
}

src/test/java/com/redislabs/redisai/RedisAITest.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,14 @@ public void testRunModel() {
3939
client.setModel("model", Backend.TF, Device.CPU, new String[] {"a", "b"}, new String[] {"mul"}, model);
4040

4141
client.setTensor("a", new float[] {2, 3}, new int[]{2});
42-
client.setTensor("b", new float[] {2, 3}, new int[]{2});
42+
client.setTensor("b", new float[] {3, 5}, new int[]{2});
4343

4444
Assert.assertTrue(client.runModel("model", new String[] {"a", "b"}, new String[] {"c"}));
45+
Tensor tensor = client.getTensor("c");
46+
float[] values = (float[]) tensor.getValues();
47+
float[] expected = new float[] {6, 15};
48+
Assert.assertTrue("Assert same shape of values", values.length==2);
49+
Assert.assertArrayEquals(values,expected, (float) 0.1);
4550
}
4651

4752
@Test
@@ -70,4 +75,14 @@ public void testRunScript() {
7075

7176
Assert.assertTrue(client.runScript("script", "bar", new String[] {"a1", "b1"}, new String[] {"c1"}));
7277
}
78+
79+
@Test
80+
public void testGetTensor() {
81+
Assert.assertTrue(client.setTensor("t1", new float[][] {{1,2},{3,4}}, new int[] {2,2}));
82+
Tensor tensor = client.getTensor("t1");
83+
float[] values = (float[]) tensor.getValues();
84+
Assert.assertTrue("Assert same shape of values", values.length==4);
85+
float[] expected = new float[] {1,2,3,4};
86+
Assert.assertArrayEquals(values,expected, (float) 0.1);
87+
}
7388
}

0 commit comments

Comments
 (0)