-
-
Notifications
You must be signed in to change notification settings - Fork 49
/
collections.go
30 lines (26 loc) · 714 Bytes
/
collections.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
package tensor
import "github.com/pkg/errors"
// densesToTensors widens every *Dense in a to the Tensor interface.
// The result is a freshly allocated slice with the same length and order as a.
// It is never nil, even when a is nil or empty.
//
// A nil *Dense element is stored as a non-nil Tensor interface value that
// wraps a nil pointer, which is standard Go interface-conversion behaviour.
func densesToTensors(a []*Dense) []Tensor {
	out := make([]Tensor, 0, len(a))
	for _, d := range a {
		out = append(out, d)
	}
	return out
}
// densesToDenseTensors widens every *Dense in a to the DenseTensor interface.
// The result is a freshly allocated slice with the same length and order as a.
// It is never nil, even when a is nil or empty.
//
// A nil *Dense element is stored as a non-nil DenseTensor interface value
// that wraps a nil pointer.
func densesToDenseTensors(a []*Dense) []DenseTensor {
	out := make([]DenseTensor, len(a))
	for idx := range a {
		out[idx] = a[idx]
	}
	return out
}
// tensorsToDenseTensors type-asserts every element of a to DenseTensor.
//
// On success it returns a new slice with the same length and order as a,
// plus a nil error. The first element that does not satisfy DenseTensor
// stops the conversion, and the function returns (nil, err). That includes
// a nil interface value, which %T renders as "<nil>". The error reports the
// element's dynamic type and its index in a.
func tensorsToDenseTensors(a []Tensor) ([]DenseTensor, error) {
	converted := make([]DenseTensor, len(a))
	for idx := range a {
		dt, isDense := a[idx].(DenseTensor)
		if !isDense {
			return nil, errors.Errorf("can only convert Tensors of the same type to DenseTensors. Trying to convert %T (#%d in slice)", a[idx], idx)
		}
		converted[idx] = dt
	}
	return converted, nil
}