@@ -10,6 +10,8 @@ import (
1010 "strings"
1111 "time"
1212
13+ "terraform-provider-iterative/iterative/utils"
14+
1315 "github.com/aws/aws-sdk-go-v2/aws"
1416 "github.com/aws/aws-sdk-go-v2/config"
1517 "github.com/aws/aws-sdk-go-v2/service/ec2"
@@ -32,6 +34,8 @@ func ResourceMachineCreate(ctx context.Context, d *schema.ResourceData, m interf
3234 spot := d .Get ("spot" ).(bool )
3335 spotPrice := d .Get ("spot_price" ).(float64 )
3436 instanceProfile := d .Get ("instance_permission_set" ).(string )
37+ subnetId := d .Get ("aws_subnet_id" ).(string )
38+ availabilityZone := GetAvailabilityZone (d .Get ("region" ).(string ))
3539
3640 metadata := map [string ]string {
3741 "Name" : d .Get ("name" ).(string ),
@@ -44,7 +48,6 @@ func ResourceMachineCreate(ctx context.Context, d *schema.ResourceData, m interf
4448 if ami == "" {
4549 ami = "iterative-cml"
4650 }
47-
4851 config , err := awsClient (region )
4952 if err != nil {
5053 return decodeAWSError (region , err )
@@ -188,26 +191,47 @@ func ResourceMachineCreate(ctx context.Context, d *schema.ResourceData, m interf
188191 sgID = * sgDesc .SecurityGroups [0 ].GroupId
189192 vpcID = * sgDesc .SecurityGroups [0 ].VpcId
190193
191- subDesc , err := svc .DescribeSubnets (ctx , & ec2.DescribeSubnetsInput {
194+ // default Subnet selection
195+ subnetOptions := & ec2.DescribeSubnetsInput {
192196 Filters : []types.Filter {
193197 {
194198 Name : aws .String ("vpc-id" ),
195199 Values : []string {vpcID },
196200 },
197201 },
198- })
202+ }
203+ // use availability zone from user
204+ if availabilityZone != "" && subnetId == "" {
205+ subnetOptions .Filters = append (subnetOptions .Filters , types.Filter {
206+ Name : aws .String ("availability-zone" ),
207+ Values : []string {availabilityZone },
208+ })
209+ }
210+ // use exact subnet-id from user
211+ if subnetId != "" {
212+ subnetOptions .Filters = append (subnetOptions .Filters , types.Filter {
213+ Name : aws .String ("subnet-id" ),
214+ Values : []string {subnetId },
215+ })
216+ }
217+ subDesc , err := svc .DescribeSubnets (ctx , subnetOptions )
199218 if err != nil {
200219 return decodeAWSError (region , err )
201220 }
202221 if len (subDesc .Subnets ) == 0 {
203- return errors .New ("no subnets found" )
222+ return errors .New ("no Subnet found" )
204223 }
205224 var subnetID string
206- for _ , subnet := range subDesc .Subnets {
207- if * subnet .AvailableIpAddressCount > 0 && * subnet .MapPublicIpOnLaunch {
208- subnetID = * subnet .SubnetId
209- break
225+ // bypass with user provided ID
226+ if subnetId == "" {
227+ for _ , subnet := range subDesc .Subnets {
228+ if * subnet .AvailableIpAddressCount > 0 && * subnet .MapPublicIpOnLaunch {
229+ subnetID = * subnet .SubnetId
230+ break
231+ }
210232 }
233+ } else {
234+ subnetID = subnetId
211235 }
212236 if subnetID == "" {
213237 return errors .New ("No subnet found with public IPs available or able to create new public IPs on creation" )
@@ -299,7 +323,7 @@ func ResourceMachineCreate(ctx context.Context, d *schema.ResourceData, m interf
299323 MinCount : aws .Int32 (1 ),
300324 MaxCount : aws .Int32 (1 ),
301325 SecurityGroupIds : []string {sgID },
302- SubnetId : aws .String (* subDesc . Subnets [ 0 ]. SubnetId ),
326+ SubnetId : aws .String (subnetID ),
303327 BlockDeviceMappings : blockDeviceMappings ,
304328 TagSpecifications : resourceTagSpecifications (types .ResourceTypeInstance , metadata ),
305329 })
@@ -332,7 +356,13 @@ func ResourceMachineCreate(ctx context.Context, d *schema.ResourceData, m interf
332356 }
333357
334358 instanceDesc := descResult .Reservations [0 ].Instances [0 ]
335- d .Set ("instance_ip" , instanceDesc .PublicIpAddress )
359+ var instanceIP string
360+ if instanceDesc .PublicIpAddress != nil {
361+ instanceIP = * instanceDesc .PublicIpAddress
362+ } else {
363+ instanceIP = * instanceDesc .PrivateIpAddress
364+ }
365+ d .Set ("instance_ip" , instanceIP )
336366 d .Set ("instance_launch_time" , instanceDesc .LaunchTime .Format (time .RFC3339 ))
337367 d .Set ("image" , * imagesRes .Images [0 ].Name )
338368
@@ -387,6 +417,14 @@ func awsClient(region string) (aws.Config, error) {
387417 )
388418}
389419
420+ func GetAvailabilityZone (region string ) string {
421+ lastChar := region [len (region )- 1 ]
422+ if lastChar >= 'a' && lastChar <= 'z' {
423+ return region
424+ }
425+ return ""
426+ }
427+
390428//GetRegion maps region to real cloud regions
391429func GetRegion (region string ) string {
392430 instanceRegions := make (map [string ]string )
@@ -398,7 +436,7 @@ func GetRegion(region string) string {
398436 return val
399437 }
400438
401- return region
439+ return utils . StripAvailabilityZone ( region )
402440}
403441
404442func getInstanceType (instanceType string , instanceGPU string ) string {
0 commit comments